"""GraphSAGE node-classification example on ogbn-products (DGL)."""
import argparse

import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from ogb.nodeproppred import DglNodePropPredDataset

class SAGE(nn.Module):
    """Three-layer GraphSAGE (mean aggregator) for node classification.

    Parameters
    ----------
    in_size : int
        Dimensionality of the input node features.
    hid_size : int
        Hidden dimensionality of the two intermediate layers.
    out_size : int
        Number of output classes (logit dimensionality).
    """

    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        # Stored so inference() can size its per-layer output buffers.
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        """Minibatch forward pass over sampled blocks (one block per layer).

        ``x`` holds the input features of the source nodes of ``blocks[0]``;
        returns logits for the destination nodes of ``blocks[-1]``.
        """
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # No activation/dropout after the final (output) layer.
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size):
        """Conduct layer-wise inference to get all the node embeddings.

        Computes layer l's representation for *every* node before moving to
        layer l+1, so each edge is traversed once per layer instead of once
        per sampled minibatch.
        """
        feat = g.ndata["feat"]
        # Full neighborhood, single layer: the dataloader yields exactly one
        # block per minibatch.
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
        dataloader = DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,
        )
        buffer_device = torch.device("cpu")
        # Pin the CPU-side output buffer when computing on another device so
        # device-to-host copies are faster.
        pin_memory = buffer_device != device
        for l, layer in enumerate(self.layers):
            # Output buffer for this layer: hidden width for intermediate
            # layers, class count for the last one.
            y = torch.empty(
                g.num_nodes(),
                self.hid_size if l != len(self.layers) - 1 else self.out_size,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = feat[input_nodes]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # by design, our output nodes are contiguous
                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
            # This layer's output becomes the next layer's input features.
            feat = y
        return y
def evaluate(model, graph, dataloader):
    """Return the accuracy of ``model`` over all minibatches in ``dataloader``.

    ``graph`` is unused here but kept for call-site compatibility.
    """
    model.eval()
    ys = []
    y_hats = []
    # no_grad hoisted around the whole loop: nothing here needs gradients,
    # and entering the context once is cheaper than once per batch.
    with torch.no_grad():
        for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    # NOTE(review): torchmetrics >= 0.11 requires task=/num_classes= arguments
    # to MF.accuracy — confirm against the pinned torchmetrics version.
    return MF.accuracy(torch.cat(y_hats), torch.cat(ys))
def layerwise_infer(device, graph, nid, model, batch_size):
    """Evaluate ``model`` on the nodes ``nid`` via layer-wise full-graph inference.

    Returns the accuracy of the predictions against ``graph.ndata["label"]``.
    """
    model.eval()
    with torch.no_grad():
        pred = model.inference(
            graph, device, batch_size
        )  # pred in buffer_device
        pred = pred[nid]
        label = graph.ndata["label"][nid].to(pred.device)
        return MF.accuracy(pred, label)
def train(args, device, g, dataset, model):
    """Train ``model`` for 10 epochs with neighbor sampling.

    Prints the running loss and validation accuracy after each epoch.
    """
    # create sampler & dataloader
    train_idx = dataset.train_idx.to(device)
    val_idx = dataset.val_idx.to(device)
    sampler = NeighborSampler(
        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
        prefetch_node_feats=["feat"],
        prefetch_labels=["label"],
    )
    # In mixed mode the graph/features stay on CPU and the GPU samples them
    # through unified virtual addressing (UVA).
    use_uva = args.mode == "mixed"
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=use_uva,
    )

    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            train_dataloader
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        acc = evaluate(model, g, val_dataloader)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, total_loss / (it + 1), acc.item()
            )
        )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["cpu", "mixed", "puregpu"],
        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    args = parser.parse_args()
    if not torch.cuda.is_available():
        # Fall back gracefully when no GPU is present.
        args.mode = "cpu"
    print(f"Training in {args.mode} mode.")

    # load and preprocess dataset
    print("Loading data")
    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
    g = dataset[0]
    # Only 'puregpu' keeps the whole graph on GPU; 'mixed' samples CPU-resident
    # data from the GPU via UVA.
    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
    device = torch.device("cpu" if args.mode == "cpu" else "cuda")

    # create GraphSAGE model
    in_size = g.ndata["feat"].shape[1]
    out_size = dataset.num_classes
    model = SAGE(in_size, 256, out_size).to(device)

    # model training
    print("Training...")
    train(args, device, g, dataset, model)

    # test the model
    print("Testing...")
    acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
    print("Test Accuracy {:.4f}".format(acc.item()))