# train.py
# (Removed copy-paste residue from a repository file viewer — file-size banner,
# "Newer Older" revision links, and bare gutter line numbers — which was not
# valid Python and prevented the module from being imported.)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.dataloading import NeighborSampler, DataLoader
from dgl import apply_each
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm

class HeteroGAT(nn.Module):
    """Three-layer relational GAT.

    Each layer is a ``HeteroGraphConv`` holding one ``GATConv`` per edge
    type; head outputs are flattened so every layer emits ``hid_size``
    features per node.  A final linear layer classifies 'paper' nodes.
    """

    def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
        super().__init__()
        # Input width of each of the three layers; after flattening the
        # n_heads attention heads, every layer outputs hid_size features.
        layer_in_dims = [in_size, hid_size, hid_size]
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                etype: dglnn.GATConv(dim, hid_size // n_heads, n_heads)
                for etype in etypes})
            for dim in layer_in_dims])
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(hid_size, out_size)   # Should be HeteroLinear

    def forward(self, blocks, x):
        h = x
        last = len(self.layers) - 1
        for idx, (conv, block) in enumerate(zip(self.layers, blocks)):
            h = conv(block, h)
            # Flatten (num_nodes, n_heads, head_dim) -> (num_nodes, hid_size).
            # shape[1] * shape[2] is used instead of -1 because a node type
            # may have zero destination nodes in a block, and view(n, -1)
            # cannot infer the width from a zero-row tensor.
            h = apply_each(
                h, lambda t: t.reshape(t.shape[0], t.shape[1] * t.shape[2]))
            if idx != last:
                h = apply_each(h, F.relu)
                h = apply_each(h, self.dropout)
        return self.linear(h['paper'])

def evaluate(model, dataloader, desc):
    """Run inference over *dataloader* and return accuracy on 'paper' nodes.

    Gathers logits and the first label column of each mini-batch on CPU,
    then computes a single accuracy score over the concatenated results.
    """
    all_logits, all_labels = [], []
    with torch.no_grad():
        for _, _, blocks in tqdm.tqdm(dataloader, desc=desc):
            feats = blocks[0].srcdata['feat']
            target = blocks[-1].dstdata['label']['paper'][:, 0]
            logits = model(blocks, feats)
            all_logits.append(logits.cpu())
            all_labels.append(target.cpu())
    return MF.accuracy(torch.cat(all_logits, 0), torch.cat(all_labels, 0))

def train(train_loader, val_loader, test_loader, model):
    """Train *model* for 10 epochs and report validation/test accuracy.

    Args:
        train_loader: DGL DataLoader yielding (input_nodes, output_nodes,
            blocks) mini-batches for training.
        val_loader: DataLoader used for validation accuracy each epoch.
        test_loader: DataLoader used for test accuracy each epoch.
        model: the HeteroGAT module to optimize.

    Bug fix: the loops previously read the module-level globals
    ``train_dataloader`` / ``val_dataloader`` / ``test_dataloader`` instead
    of the parameters, so calling this function from any other context
    raised NameError; it now uses its arguments.
    """
    # loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # training loop
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(tqdm.tqdm(train_loader, desc="Train")):
            x = blocks[0].srcdata['feat']
            # Labels come as a (N, 1) tensor for 'paper'; take the column.
            y = blocks[-1].dstdata['label']['paper'][:, 0]
            y_hat = model(blocks, x)
            loss = loss_fcn(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        model.eval()
        val_acc = evaluate(model, val_loader, 'Val. ')
        test_acc = evaluate(model, test_loader, 'Test ')
        print(f'Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}')

if __name__ == '__main__':
    print(f'Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and preprocess the ogbn-mag dataset.
    print('Loading data')
    dataset = DglNodePropPredDataset('ogbn-mag')
    graph, labels = dataset[0]
    graph.ndata['label'] = labels
    # Add reverse edges in the "cites" relation and reverse edge types for
    # the remaining relations.
    graph = dgl.AddReverse()(graph)
    # Precompute features of the non-paper node types by mean-aggregating
    # neighbor 'feat'.  The order is significant: authors first inherit
    # features from papers (rev_writes) before institutions average over
    # authors (affiliated_with).
    for rel in ('rev_writes', 'has_topic', 'affiliated_with'):
        graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype=rel)

    # Train/val/test node indexes, moved to the compute device per node type.
    split_idx = dataset.get_idx_split()
    train_idx = apply_each(split_idx['train'], lambda x: x.to(device))
    val_idx = apply_each(split_idx['valid'], lambda x: x.to(device))
    test_idx = apply_each(split_idx['test'], lambda x: x.to(device))

    # Build the RGAT model.
    in_size = graph.ndata['feat']['paper'].shape[1]
    out_size = dataset.num_classes
    model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)

    # Samplers: 5 neighbors/hop for training, 10 for evaluation; prefetch
    # node features for every node type and labels for 'paper'.
    prefetch_feats = {ntype: ['feat'] for ntype in graph.ntypes}
    train_sampler = NeighborSampler([5, 5, 5],
                                    prefetch_node_feats=prefetch_feats,
                                    prefetch_labels={'paper': ['label']})
    val_sampler = NeighborSampler([10, 10, 10],
                                  prefetch_node_feats=prefetch_feats,
                                  prefetch_labels={'paper': ['label']})
    # Unified virtual addressing only helps when a GPU is present.
    uva = torch.cuda.is_available()
    train_dataloader = DataLoader(graph, train_idx, train_sampler,
                                  device=device, batch_size=1000, shuffle=True,
                                  drop_last=False, num_workers=0, use_uva=uva)
    val_dataloader = DataLoader(graph, val_idx, val_sampler,
                                device=device, batch_size=1000, shuffle=False,
                                drop_last=False, num_workers=0, use_uva=uva)
    test_dataloader = DataLoader(graph, test_idx, val_sampler,
                                 device=device, batch_size=1000, shuffle=False,
                                 drop_last=False, num_workers=0, use_uva=uva)

    train(train_dataloader, val_dataloader, test_dataloader, model)