node_classification.py 6.45 KB
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
2
3
4
import argparse

import dgl
import dgl.nn as dglnn
5
6
7
8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
9
import tqdm
10
from dgl.data import AsNodePredDataset
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
11
12
13
14
15
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
16
from ogb.nodeproppred import DglNodePropPredDataset
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
17

18
19

class SAGE(nn.Module):
    """Three-layer GraphSAGE (mean aggregator) for node classification."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # Mean-aggregation SAGE convolutions: in -> hid -> hid -> out.
        self.layers = nn.ModuleList(
            [
                dglnn.SAGEConv(in_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, out_size, "mean"),
            ]
        )
        self.dropout = nn.Dropout(0.5)
        # Stored so inference() can size its per-layer output buffers.
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Run sampled message-flow blocks through the conv stack.

        ReLU + dropout follow every layer except the last.
        """
        hidden = x
        last = len(self.layers) - 1
        for idx, (conv, block) in enumerate(zip(self.layers, blocks)):
            hidden = conv(block, hidden)
            if idx != last:
                hidden = self.dropout(F.relu(hidden))
        return hidden

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings."""
        feat = g.ndata["feat"]
        # One full-neighbor hop at a time: layer l is computed for every
        # node before layer l + 1 starts.
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
        dataloader = DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        # Stage embeddings on CPU; pin the buffer whenever the compute
        # device differs so host<->device copies go faster.
        buffer_device = torch.device("cpu")
        pin_memory = buffer_device != device

        last = len(self.layers) - 1
        for idx, conv in enumerate(self.layers):
            width = self.out_size if idx == last else self.hid_size
            y = torch.empty(
                g.num_nodes(),
                width,
                dtype=feat.dtype,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                h = conv(blocks[0], feat[input_nodes])  # len(blocks) = 1
                if idx != last:
                    h = self.dropout(F.relu(h))
                # by design, our output nodes are contiguous
                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
            feat = y
        return y
76

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
77

78
def evaluate(model, graph, dataloader, num_classes):
    """Return multiclass accuracy of ``model`` over ``dataloader``.

    ``graph`` is unused here; it is kept for signature compatibility
    with existing callers.
    """
    model.eval()
    labels = []
    preds = []
    for input_nodes, output_nodes, blocks in dataloader:
        with torch.no_grad():
            features = blocks[0].srcdata["feat"]
            labels.append(blocks[-1].dstdata["label"])
            preds.append(model(blocks, features))
    return MF.accuracy(
        torch.cat(preds),
        torch.cat(labels),
        task="multiclass",
        num_classes=num_classes,
    )
93

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
94

95
def layerwise_infer(device, graph, nid, model, num_classes, batch_size):
    """Evaluate ``model`` on the nodes in ``nid`` via layer-wise inference."""
    model.eval()
    with torch.no_grad():
        # Full-graph predictions come back on the inference buffer
        # device (CPU); slice out just the nodes we score.
        pred = model.inference(
            graph, device, batch_size
        )  # pred in buffer_device
        pred = pred[nid]
        label = graph.ndata["label"][nid].to(pred.device)
        return MF.accuracy(
            pred, label, task="multiclass", num_classes=num_classes
        )
106

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
107

108
def train(args, device, g, dataset, model, num_classes):
    """Train ``model`` for 10 epochs with 3-hop neighbor sampling.

    Args:
        args: parsed CLI namespace; ``args.mode == "mixed"`` enables UVA
            sampling (graph stays on CPU, sampling runs on GPU).
        device: device minibatches are delivered to.
        g: the (possibly GPU-resident) DGL graph.
        dataset: provides the ``train_idx`` / ``val_idx`` node splits.
        model: the SAGE model being optimized.
        num_classes: label count, forwarded to evaluate().
    """
    # create sampler & dataloader
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)
    sampler = NeighborSampler(
        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
        prefetch_node_feats=["feat"],
        prefetch_labels=["label"],
    )
    use_uva = args.mode == "mixed"
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    # Fix: validation needs no shuffling -- accuracy is invariant to
    # batch order, and shuffling only adds sampling overhead.
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(10):
        model.train()
        total_loss = 0.0
        # Fix: count batches explicitly instead of reusing the loop index
        # afterwards, which raised NameError on an empty dataloader.
        num_batches = 0
        for input_nodes, output_nodes, blocks in train_dataloader:
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, g, val_dataloader, num_classes)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, total_loss / max(num_batches, 1), acc.item()
            )
        )

165

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
166
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["cpu", "mixed", "puregpu"],
        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--dt",
        type=str,
        default="float",
        help="data type(float, bfloat16)",
    )
    args = parser.parse_args()
    # Fall back to CPU when no CUDA device exists, whatever was requested.
    if not torch.cuda.is_available():
        args.mode = "cpu"
    print(f"Training in {args.mode} mode.")

    # load and preprocess dataset
    print("Loading data")
    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
    g = dataset[0]
    # 'puregpu' keeps the whole graph on the GPU; otherwise it stays on
    # the CPU (UVA sampling handles the 'mixed' case).
    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
    num_classes = dataset.num_classes
    device = torch.device("cpu" if args.mode == "cpu" else "cuda")

    # create GraphSAGE model
    feat_size = g.ndata["feat"].shape[1]
    out_size = dataset.num_classes
    model = SAGE(feat_size, 256, out_size).to(device)

    # convert model and graph to bfloat16 if needed
    if args.dt == "bfloat16":
        g = dgl.to_bfloat16(g)
        model = model.to(dtype=torch.bfloat16)

    # model training
    print("Training...")
    train(args, device, g, dataset, model, num_classes)

    # test the model
    print("Testing...")
    acc = layerwise_infer(
        device, g, dataset.test_idx, model, num_classes, batch_size=4096
    )
    print("Test Accuracy {:.4f}".format(acc.item()))