import argparse

import dgl
import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF

import tqdm

from dgl.data import AsNodePredDataset

from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)

from ogb.nodeproppred import DglNodePropPredDataset

class SAGE(nn.Module):
    """Three-layer GraphSAGE (mean aggregation) for node classification."""

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        # Three-layer GraphSAGE-mean: in -> hid -> hid -> out.
        self.layers = nn.ModuleList(
            [
                dglnn.SAGEConv(in_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, hid_size, "mean"),
                dglnn.SAGEConv(hid_size, out_size, "mean"),
            ]
        )
        self.dropout = nn.Dropout(0.5)
        # Kept so inference() can size its per-layer output buffers.
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Apply each layer to its sampled block; ReLU + dropout between layers."""
        last = len(self.layers) - 1
        h = x
        for idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if idx != last:
                h = self.dropout(F.relu(h))
        return h

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings."""
        feat = g.ndata["feat"]
        # One full-neighbor hop at a time: each outer iteration computes one
        # layer's embeddings for every node before moving to the next layer.
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
        all_nodes = torch.arange(g.num_nodes()).to(g.device)
        dataloader = DataLoader(
            g,
            all_nodes,
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        buffer_device = torch.device("cpu")
        # Pin the CPU staging buffer only when compute runs on another device.
        pin_memory = buffer_device != device
        last = len(self.layers) - 1

        for idx, layer in enumerate(self.layers):
            width = self.out_size if idx == last else self.hid_size
            y = torch.empty(
                g.num_nodes(),
                width,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                h = layer(blocks[0], feat[input_nodes])  # len(blocks) = 1
                if idx != last:
                    h = self.dropout(F.relu(h))
                # By design, our output nodes are contiguous, so a slice
                # assignment into the CPU buffer covers exactly this batch.
                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
            # This layer's full output becomes the next layer's input.
            feat = y
        return y


def evaluate(model, graph, dataloader, num_classes):
    """Return multiclass accuracy of `model` over all batches in `dataloader`.

    `graph` is unused here but kept for signature compatibility with callers.
    """
    model.eval()
    ys = []
    y_hats = []
    # Enter no_grad once for the whole evaluation loop instead of re-entering
    # the context manager on every batch; also drop the unused enumerate index
    # and input-node ids.
    with torch.no_grad():
        for _, _, blocks in dataloader:
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


def layerwise_infer(device, graph, nid, model, num_classes, batch_size):
    """Run layer-wise full-graph inference, then score accuracy on `nid` nodes."""
    model.eval()
    with torch.no_grad():
        # inference() returns predictions for every node, staged on the
        # CPU buffer device; select just the requested node ids.
        all_pred = model.inference(graph, device, batch_size)
        pred = all_pred[nid]
        label = graph.ndata["label"][nid].to(pred.device)
        return MF.accuracy(
            pred, label, task="multiclass", num_classes=num_classes
        )


def train(args, device, g, dataset, model, num_classes):
    """Train `model` with neighbor sampling; print loss and val accuracy per epoch."""
    # create sampler & dataloader
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)
    sampler = NeighborSampler(
        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
        prefetch_node_feats=["feat"],
        prefetch_labels=["label"],
    )
    # UVA sampling only applies in CPU-GPU mixed mode.
    use_uva = args.mode == "mixed"
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    # Validation needs no shuffling: accuracy is order-independent, so
    # shuffling only adds cost.
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(10):
        model.train()
        total_loss = 0.0
        # Count batches explicitly rather than relying on the enumerate
        # index leaking out of the loop.
        num_batches = 0
        for _, _, blocks in train_dataloader:
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
            num_batches += 1
        acc = evaluate(model, g, val_dataloader, num_classes)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, total_loss / num_batches, acc.item()
            )
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["cpu", "mixed", "puregpu"],
        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    args = parser.parse_args()
    # Fall back to CPU when no GPU is available, whatever mode was requested.
    if not torch.cuda.is_available():
        args.mode = "cpu"
    print(f"Training in {args.mode} mode.")

    # load and preprocess dataset
    print("Loading data")
    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
    g = dataset[0]
    # The graph lives on GPU only in pure-GPU mode; in cpu/mixed modes it
    # stays in host memory (mixed mode samples via UVA).
    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
    num_classes = dataset.num_classes
    device = torch.device("cpu" if args.mode == "cpu" else "cuda")

    # create GraphSAGE model
    in_size = g.ndata["feat"].shape[1]
    # Reuse num_classes instead of re-reading dataset.num_classes.
    model = SAGE(in_size, 256, num_classes).to(device)

    # model training
    print("Training...")
    train(args, device, g, dataset, model, num_classes)

    # test the model
    print("Testing...")
    acc = layerwise_infer(
        device, g, dataset.test_idx, model, num_classes, batch_size=4096
    )
    print("Test Accuracy {:.4f}".format(acc.item()))