import dgl
import dgl.function as fn
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl import apply_each
from dgl.dataloading import DataLoader, NeighborSampler
from ogb.nodeproppred import DglNodePropPredDataset


class HeteroGAT(nn.Module):
    def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
        super().__init__()
        self.layers = nn.ModuleList()
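        # Three HeteroGraphConv layers: each applies a separate GATConv per
        # relation and sums the results for destination node types shared by
        # several relations.  Each GATConv outputs
        # (N, n_heads, hid_size // n_heads) per node type, so concatenating
        # the heads restores hid_size.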
        self.layers.append(
            dglnn.HeteroGraphConv(
                {
                    etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
                    for etype in etypes
                }
            )
        )
        for _ in range(2):
            self.layers.append(
                dglnn.HeteroGraphConv(
                    {
                        etype: dglnn.GATConv(
                            hid_size, hid_size // n_heads, n_heads
                        )
                        for etype in etypes
                    }
                )
            )
        self.dropout = nn.Dropout(0.5)
        # a plain Linear suffices since only "paper" nodes are classified;
        # dglnn.HeteroLinear would be needed to project every node type
        self.linear = nn.Linear(hid_size, out_size)

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # Note: h may contain tensors with zero rows when a node type has
            # no destination nodes in this block, in which case
            # x.view(x.shape[0], -1) would fail (-1 is ambiguous for an empty
            # tensor), so the flattened head dimension is spelled out below.
            h = apply_each(
                h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])
            )
            if l != len(self.layers) - 1:
                h = apply_each(h, F.relu)
                h = apply_each(h, self.dropout)
        return self.linear(h["paper"])


def evaluate(model, dataloader, num_classes, desc):
    preds = []
    labels = []
    with torch.no_grad():
        for input_nodes, output_nodes, blocks in tqdm.tqdm(
            dataloader, desc=desc
        ):
            x = blocks[0].srcdata["feat"]
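            # labels are stored as an (N, 1) tensor; take the single column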
            y = blocks[-1].dstdata["label"]["paper"][:, 0]
            y_hat = model(blocks, x)
            preds.append(y_hat.cpu())
            labels.append(y.cpu())
        preds = torch.cat(preds, 0)
        labels = torch.cat(labels, 0)
        # torchmetrics >= 0.11 requires the task type and class count
        acc = MF.accuracy(
            preds, labels, task="multiclass", num_classes=num_classes
        )
        return acc


def train(train_loader, val_loader, test_loader, model, num_classes):
    # loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # training loop
    for epoch in range(10):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(
            tqdm.tqdm(train_loader, desc="Train")
        ):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]["paper"][:, 0]
            y_hat = model(blocks, x)
            loss = loss_fcn(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        model.eval()
        val_acc = evaluate(model, val_loader, num_classes, "Val. ")
        test_acc = evaluate(model, test_loader, num_classes, "Test ")
        print(
            f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}"
        )


if __name__ == "__main__":
    print(
        "Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules"
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # load and preprocess dataset
    print("Loading data")
    dataset = DglNodePropPredDataset("ogbn-mag")
    graph, labels = dataset[0]
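    # labels is a dict keyed by node type; in ogbn-mag only "paper" nodes
    # carry labels (the venue of each paper)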
    graph.ndata["label"] = labels
    # add reverse edges within the "cites" relation and a "rev_*" reverse
    # edge type for every other relation
    graph = dgl.AddReverse()(graph)
    # paper nodes are the only ones with input features, so precompute
    # features for the other node types by averaging their neighbors':
    # authors from their papers (rev_writes), fields of study from their
    # papers (has_topic), and institutions from their authors
    # (affiliated_with).  Order matters: institution features depend on the
    # author features computed first.
    graph.update_all(
        fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="rev_writes"
    )
    graph.update_all(
        fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="has_topic"
    )
    graph.update_all(
        fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="affiliated_with"
    )
    # fetch the train/validation/test indices
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = (
        split_idx["train"],
        split_idx["valid"],
        split_idx["test"],
    )
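    # each split is a dict keyed by node type ({"paper": tensor}), hence
    # apply_each to move every tensor to the target device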
    train_idx = apply_each(train_idx, lambda x: x.to(device))
    val_idx = apply_each(val_idx, lambda x: x.to(device))
    test_idx = apply_each(test_idx, lambda x: x.to(device))

    # create RGAT model
    in_size = graph.ndata["feat"]["paper"].shape[1]
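    # after the mean-aggregation above every node type shares the paper
    # feature dimensionality, so one in_size works for all relations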
    out_size = dataset.num_classes
    model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)

    # dataloader + model training + testing
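    # fanouts are per layer and per edge type: 5 sampled neighbors for
    # training, 10 for evaluation; the prefetch_* arguments make the
    # dataloader slice out the needed features/labels while sampling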
    train_sampler = NeighborSampler(
        [5, 5, 5],
        prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
        prefetch_labels={"paper": ["label"]},
    )
    val_sampler = NeighborSampler(
        [10, 10, 10],
        prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
        prefetch_labels={"paper": ["label"]},
    )
    train_dataloader = DataLoader(
        graph,
        train_idx,
        train_sampler,
        device=device,
        batch_size=1000,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_uva=torch.cuda.is_available(),
    )
    val_dataloader = DataLoader(
        graph,
        val_idx,
        val_sampler,
        device=device,
        batch_size=1000,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_uva=torch.cuda.is_available(),
    )
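    # the test loader reuses the validation sampler (same fanouts)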
    test_dataloader = DataLoader(
        graph,
        test_idx,
        val_sampler,
        device=device,
        batch_size=1000,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_uva=torch.cuda.is_available(),
    )

    train(train_dataloader, val_dataloader, test_dataloader, model, out_size)