import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl import apply_each

import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm

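# A three-layer heterogeneous GAT (R-GAT) for ogbn-mag: each HeteroGraphConv
# layer runs one GATConv per edge type and, by default, sums the messages
# arriving at the same destination node type.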
class HeteroGAT(nn.Module):
    def __init__(self, etypes, in_feats, n_hidden, n_classes, n_heads=4):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.HeteroGraphConv({
            etype: dglnn.GATConv(in_feats, n_hidden // n_heads, n_heads)
            for etype in etypes}))
        self.layers.append(dglnn.HeteroGraphConv({
            etype: dglnn.GATConv(n_hidden, n_hidden // n_heads, n_heads)
            for etype in etypes}))
        self.layers.append(dglnn.HeteroGraphConv({
            etype: dglnn.GATConv(n_hidden, n_hidden // n_heads, n_heads)
            for etype in etypes}))
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(n_hidden, n_classes)   # Should be HeteroLinear
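        # Sketch of a type-aware alternative with dglnn.HeteroLinear (only the
        # 'paper' logits are consumed in forward(), so nn.Linear also works):
        #     self.linear = dglnn.HeteroLinear({'paper': n_hidden}, n_classes)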

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # layer() may return tensors with zero rows when some node type has
            # no destination nodes in a block; x.view(x.shape[0], -1) would fail
            # there, so flatten the head dimension with explicit sizes instead.
            h = apply_each(h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2]))
            if l != len(self.layers) - 1:
                h = apply_each(h, F.relu)
                h = apply_each(h, self.dropout)
        return self.linear(h['paper'])

dataset = DglNodePropPredDataset('ogbn-mag')

graph, labels = dataset[0]
graph.ndata['label'] = labels
# Preprocess: add reverse edges within the "cites" relation, and add reverse
# edge types for the remaining relations.
graph = dgl.AddReverse()(graph)
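# With default settings, reverse "cites" edges are added to the same relation
# (paper -> paper), while asymmetric relations get new "rev_*" edge types
# (e.g. "rev_writes"), which the feature propagation below depends on.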
# Preprocess: precompute the author, topic, and institution features
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes')
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic')
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with')
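# The three calls give authors, fields of study, and institutions features by
# mean-pooling over their neighbors. Order matters: authors must receive
# features (via rev_writes) before affiliated_with can average them onto
# institutions.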

model = HeteroGAT(graph.etypes, graph.ndata['feat']['paper'].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
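# ogbn-mag indexes nodes per type, so each split is a dict like
# {'paper': tensor([...])}; apply_each moves every tensor in it to the GPU.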
train_idx = apply_each(train_idx, lambda x: x.to('cuda'))
valid_idx = apply_each(valid_idx, lambda x: x.to('cuda'))
test_idx = apply_each(test_idx, lambda x: x.to('cuda'))

train_sampler = dgl.dataloading.NeighborSampler(
        [5, 5, 5],
        prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
        prefetch_labels={'paper': ['label']})
valid_sampler = dgl.dataloading.NeighborSampler(
        [10, 10, 10],   # slightly larger fanouts for evaluation
        prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
        prefetch_labels={'paper': ['label']})
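# use_uva=True pins the graph in host memory and lets the GPU access it
# directly (unified virtual addressing), so sampling can run on the GPU
# without copying the whole graph into GPU memory.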
train_dataloader = dgl.dataloading.DataLoader(
        graph, train_idx, train_sampler,
        device='cuda', batch_size=1000, shuffle=True,
        drop_last=False, num_workers=0, use_uva=True)
valid_dataloader = dgl.dataloading.DataLoader(
        graph, valid_idx, valid_sampler,
        device='cuda', batch_size=1000, shuffle=False,
        drop_last=False, num_workers=0, use_uva=True)
test_dataloader = dgl.dataloading.DataLoader(
        graph, test_idx, valid_sampler,   # reuse the evaluation fanouts
        device='cuda', batch_size=1000, shuffle=False,
        drop_last=False, num_workers=0, use_uva=True)

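# Offline evaluation: run the model over every batch of the given dataloader
# and compute accuracy on the concatenated predictions.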
def evaluate(model, dataloader):
    preds = []
    labels = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
            x = blocks[0].srcdata['feat']
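            # ogbn-mag labels have shape (N, 1); take column 0 to get (N,)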
            y = blocks[-1].dstdata['label']['paper'][:, 0]
            y_hat = model(blocks, x)
            preds.append(y_hat.cpu())
            labels.append(y.cpu())
        preds = torch.cat(preds, 0)
        labels = torch.cat(labels, 0)
        # torchmetrics >= 0.11 requires an explicit task argument
        acc = MF.accuracy(preds, labels, task='multiclass',
                          num_classes=dataset.num_classes)
        return acc

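# Train for 10 epochs, timing each epoch; validation and test accuracy are
# evaluated after every epoch.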
durations = []
for _ in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']['paper'][:, 0]
        y_hat = model(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            acc = MF.accuracy(y_hat, y, task='multiclass',
                              num_classes=dataset.num_classes)
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)

    model.eval()
    valid_acc = evaluate(model, valid_dataloader)
    test_acc = evaluate(model, test_dataloader)
    print('Validation acc:', valid_acc, 'Test acc:', test_acc)
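# Report mean/std of per-epoch time, treating the first 4 epochs as warmup.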
print(np.mean(durations[4:]), np.std(durations[4:]))