# node_classification.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--pure-gpu', action='store_true',
                    help='Perform both sampling and training on GPU.')
args = parser.parse_args()

18
19
20
21
22
23
24
25
class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
26
27
        self.n_hidden = n_hidden
        self.n_classes = n_classes
28
29
30
31
32
33
34
35
36
37

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

38
    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
39
40
        feat = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
41
42
        dataloader = dgl.dataloading.NodeDataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
43
44
45
                batch_size=batch_size, shuffle=False, drop_last=False,
                num_workers=num_workers)

46
47
48
49
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
50
            y = torch.empty(
51
                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
52
53
                device=buffer_device, pin_memory=True)
            feat = feat.to(device)
54
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
55
56
                # use an explicitly contiguous slice
                x = feat[input_nodes]
57
58
59
60
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
61
62
63
64
                # be design, our output nodes are contiguous so we can take
                # advantage of that here
                y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
            feat = y
65
66
        return y

67
68
dataset = DglNodePropPredDataset('ogbn-products')
graph, labels = dataset[0]
69
graph.ndata['label'] = labels.squeeze()
70
71
72
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

73
device = 'cuda'
74
75
train_idx = train_idx.to(device)
valid_idx = valid_idx.to(device)
76
77
78
test_idx = test_idx.to(device)

graph = graph.to('cuda' if args.pure_gpu else 'cpu')
79

80
81
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
82
83

sampler = dgl.dataloading.NeighborSampler(
84
        [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
85
train_dataloader = dgl.dataloading.DataLoader(
86
        graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
87
        drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
88
89
valid_dataloader = dgl.dataloading.NodeDataLoader(
        graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
90
        drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
91
92
93

durations = []
for _ in range(10):
94
    model.train()
95
    t0 = time.time()
96
    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
97
        x = blocks[0].srcdata['feat']
98
        y = blocks[-1].dstdata['label']
99
100
101
102
103
104
105
106
107
108
109
110
        y_hat = model(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            acc = MF.accuracy(y_hat, y)
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)
111
112
113
114
115
116
117
118
119
120
121
122

    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'])
            y_hats.append(model(blocks, x))
    acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
    print('Validation acc:', acc.item())

123
print(np.mean(durations[4:]), np.std(durations[4:]))
124
125
126
127

# Test accuracy and offline inference of all nodes
model.eval()
with torch.no_grad():
128
129
    pred = model.inference(graph, device, 4096, 0, 'cpu')
    pred = pred[test_idx].to(device)
130
131
    label = graph.ndata['label'][test_idx]
    acc = MF.accuracy(pred, label)
132
    print('Test acc:', acc.item())