# node_classification.py
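#
# Train a three-layer GraphSAGE model on the ogbn-products benchmark using
# GPU-based neighbor sampling over a pinned host graph (UVA), report per-epoch
# timings and validation accuracy, then compute test accuracy with layer-wise
# offline inference over the full graph.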
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm

# Three-layer GraphSAGE with mean aggregation and dropout between layers.
class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
        # Recorded so inference() can size its per-layer output buffers.
        self.n_hidden = n_hidden
        self.n_classes = n_classes

    # 'blocks' is the list of message-flow graphs (MFGs) produced by the
    # sampler, one per layer; 'x' holds the features of the input nodes.
    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

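    # Offline inference: propagate all nodes through one layer at a time with
    # full (non-sampled) neighborhoods, so each layer's output is exact before
    # the next layer consumes it.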
    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        g.ndata['h'] = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        dataloader = dgl.dataloading.DataLoader(
                g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
                batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
                persistent_workers=(num_workers > 0))
        if buffer_device is None:
            buffer_device = device

        for l, layer in enumerate(self.layers):
            y = torch.zeros(
                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata['h']
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # Indices must live on the same device as the output buffer.
                y[output_nodes.to(buffer_device)] = h.to(buffer_device)
            g.ndata['h'] = y
        return y

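# Load the ogbn-products dataset and attach the labels to the graph so the
# sampler can prefetch them.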
dataset = DglNodePropPredDataset('ogbn-products')
graph, labels = dataset[0]
graph.ndata['label'] = labels.squeeze()
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

device = 'cuda'
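# With use_uva=True the seed-node indices are expected to live on the GPU.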
train_idx = train_idx.to(device)
valid_idx = valid_idx.to(device)

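# Three-layer GraphSAGE with 256 hidden units, trained with Adam.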
model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Neighbor sampler with fanouts of 15, 10 and 5 for the three layers; node
# features and labels are prefetched alongside the sampled blocks.
sampler = dgl.dataloading.NeighborSampler(
        [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
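# The training dataloader samples on the GPU directly from the pinned host
# graph (use_uva=True); UVA requires num_workers=0.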
train_dataloader = dgl.dataloading.DataLoader(
        graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
        drop_last=False, num_workers=0, use_uva=True)

# The validation dataloader reuses the same sampler and UVA settings.
valid_dataloader = dgl.dataloading.DataLoader(
        graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
        drop_last=False, num_workers=0, use_uva=True)

# Train for 10 epochs, recording the wall-clock time of each.
durations = []
for _ in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
        # Features and labels were prefetched onto the sampled blocks.
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']
        y_hat = model(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            # torchmetrics >= 0.11 requires an explicit task specification.
            acc = MF.accuracy(y_hat, y, task='multiclass', num_classes=dataset.num_classes)
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)

    # Evaluate on the validation set after each epoch.
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'])
            y_hats.append(model(blocks, x))
    acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys),
                      task='multiclass', num_classes=dataset.num_classes)
    print('Validation acc:', acc.item())

# Mean/stdev of epoch time, skipping the first four epochs as warm-up.
print(np.mean(durations[4:]), np.std(durations[4:]))

# Test accuracy and offline inference of all nodes
model.eval()
with torch.no_grad():
    # Layer-wise inference over all nodes; buffer the outputs on the CPU.
    pred = model.inference(graph, device, 4096, 12, graph.device)
    pred = pred[test_idx]
    label = graph.ndata['label'][test_idx]
    acc = MF.accuracy(pred, label, task='multiclass', num_classes=dataset.num_classes)
    print('Test acc:', acc.item())
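
# Optional follow-up (a sketch, not part of the original script): persist the
# trained weights and the test-set predictions. The file names below are
# placeholder assumptions.
# torch.save(model.state_dict(), 'sage_ogbn_products.pt')
# torch.save(pred, 'test_pred.pt')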