#!/usr/bin/env python
# coding: utf-8
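
# R-GAT baseline for the OGB-LSC MAG240M node-classification task, built on
# DGL. The script expects a preprocessed graph (--graph-path) and a float16
# memmap holding features for all author, institution, and paper nodes
# (--full-feature-path), with paper nodes stored after authors and
# institutions.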

import argparse

import dgl
import dgl.nn as dglnn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ogb.lsc import MAG240MDataset, MAG240MEvaluator


class RGAT(nn.Module):
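    # One GATConv per edge type per layer: each layer projects the destination
    # nodes with a linear skip connection and sums the per-edge-type GAT
    # outputs on top, followed by batch norm, ELU, and dropout. A final MLP
    # maps the hidden representation to class logits.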
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        num_etypes,
        num_layers,
        num_heads,
        dropout,
        pred_ntype,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()

        self.convs.append(
            nn.ModuleList(
                [
                    dglnn.GATConv(
                        in_channels,
                        hidden_channels // num_heads,
                        num_heads,
                        allow_zero_in_degree=True,
                    )
                    for _ in range(num_etypes)
                ]
            )
        )
        self.norms.append(nn.BatchNorm1d(hidden_channels))
        self.skips.append(nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(
                nn.ModuleList(
                    [
                        dglnn.GATConv(
                            hidden_channels,
                            hidden_channels // num_heads,
                            num_heads,
                            allow_zero_in_degree=True,
                        )
                        for _ in range(num_etypes)
                    ]
                )
            )
            self.norms.append(nn.BatchNorm1d(hidden_channels))
            self.skips.append(nn.Linear(hidden_channels, hidden_channels))

        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels),
        )
        self.dropout = nn.Dropout(dropout)

        self.hidden_channels = hidden_channels
        self.pred_ntype = pred_ntype
        self.num_etypes = num_etypes

    def forward(self, mfgs, x):
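        # `mfgs` is the list of sampled message-flow graphs (blocks), one per
        # layer; `x` holds the features of the first block's source nodes.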
        for i in range(len(mfgs)):
            mfg = mfgs[i]
            x_dst = x[: mfg.num_dst_nodes()]
            # Convert the block to a regular bipartite graph so per-edge-type
            # subgraphs can be carved out below.
            mfg = dgl.block_to_graph(mfg)
            x_skip = self.skips[i](x_dst)
            for j in range(self.num_etypes):
                # Run edge type j's GAT convolution on the subgraph of its
                # edges and accumulate the result onto the skip connection.
                subg = mfg.edge_subgraph(
                    mfg.edata["etype"] == j, relabel_nodes=False
                )
                x_skip += self.convs[i][j](subg, (x, x_dst)).view(
                    -1, self.hidden_channels
                )
            x = self.norms[i](x_skip)
            x = F.elu(x)
            x = self.dropout(x)
        return self.mlp(x)


class ExternalNodeCollator(dgl.dataloading.NodeCollator):
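    # A NodeCollator that also attaches node features and labels, which are
    # kept outside the graph (e.g. in a numpy memmap) so the full feature
    # matrix never has to be loaded into the graph object.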
    def __init__(self, g, idx, sampler, offset, feats, label):
        super().__init__(g, idx, sampler)
        self.offset = offset
        self.feats = feats
        self.label = label

    def collate(self, items):
        input_nodes, output_nodes, mfgs = super().collate(items)
        # Attach input features from external storage and labels (stored in
        # paper index space, hence the offset) to the first and last MFG.
        mfgs[0].srcdata["x"] = torch.FloatTensor(self.feats[input_nodes])
        mfgs[-1].dstdata["y"] = torch.LongTensor(
            self.label[output_nodes - self.offset]
        )
        return input_nodes, output_nodes, mfgs


def train(args, dataset, g, feats, paper_offset):
    print("Loading masks and labels")
    train_idx = torch.LongTensor(dataset.get_idx_split("train")) + paper_offset
    valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
    label = dataset.paper_label

    print("Initializing dataloader...")
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
    train_collator = ExternalNodeCollator(
        g, train_idx, sampler, paper_offset, feats, label
    )
    valid_collator = ExternalNodeCollator(
        g, valid_idx, sampler, paper_offset, feats, label
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_collator.dataset,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        collate_fn=train_collator.collate,
        num_workers=4,
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_collator.dataset,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        collate_fn=valid_collator.collate,
        num_workers=2,
    )

    print("Initializing model...")
    model = RGAT(
        dataset.num_paper_features,  # in_channels
        dataset.num_classes,  # out_channels
        1024,  # hidden_channels
        5,  # num_etypes
        2,  # num_layers
        4,  # num_heads
        0.5,  # dropout
        "paper",  # pred_ntype
    ).cuda()
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
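    # StepLR multiplies the learning rate by 0.25 every 25 epochs.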
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)

    best_acc = 0

    for _ in range(args.epochs):
        model.train()
        with tqdm.tqdm(train_dataloader) as tq:
            for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
                mfgs = [mfg.to("cuda") for mfg in mfgs]
                x = mfgs[0].srcdata["x"]
                y = mfgs[-1].dstdata["y"]
                y_hat = model(mfgs, x)
                loss = F.cross_entropy(y_hat, y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                acc = (y_hat.argmax(1) == y).float().mean()
                tq.set_postfix(
                    {"loss": "%.4f" % loss.item(), "acc": "%.4f" % acc.item()},
                    refresh=False,
                )

        model.eval()
        correct = total = 0
        for i, (input_nodes, output_nodes, mfgs) in enumerate(
            tqdm.tqdm(valid_dataloader)
        ):
            with torch.no_grad():
                mfgs = [mfg.to("cuda") for mfg in mfgs]
                x = mfgs[0].srcdata["x"]
                y = mfgs[-1].dstdata["y"]
                y_hat = model(mfgs, x)
                correct += (y_hat.argmax(1) == y).sum().item()
                total += y_hat.shape[0]
        acc = correct / total
        print("Validation accuracy:", acc)

        sched.step()

        if best_acc < acc:
            best_acc = acc
            print("Updating best model...")
            torch.save(model.state_dict(), args.model_path)


def test(args, dataset, g, feats, paper_offset):
    print("Loading masks and labels...")
    valid_idx = torch.LongTensor(dataset.get_idx_split("valid")) + paper_offset
    test_idx = torch.LongTensor(dataset.get_idx_split("test")) + paper_offset
    label = dataset.paper_label

    print("Initializing data loader...")
    sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
    valid_collator = ExternalNodeCollator(
        g, valid_idx, sampler, paper_offset, feats, label
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_collator.dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        collate_fn=valid_collator.collate,
        num_workers=2,
    )
    test_collator = ExternalNodeCollator(
        g, test_idx, sampler, paper_offset, feats, label
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_collator.dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        collate_fn=test_collator.collate,
        num_workers=4,
    )

    print("Loading model...")
    model = RGAT(
        dataset.num_paper_features,  # in_channels
        dataset.num_classes,  # out_channels
        1024,  # hidden_channels
        5,  # num_etypes
        2,  # num_layers
        4,  # num_heads
        0.5,  # dropout
        "paper",  # pred_ntype
    ).cuda()
    model.load_state_dict(torch.load(args.model_path))

    model.eval()
    correct = total = 0
    for i, (input_nodes, output_nodes, mfgs) in enumerate(
        tqdm.tqdm(valid_dataloader)
    ):
        with torch.no_grad():
            mfgs = [mfg.to("cuda") for mfg in mfgs]
            x = mfgs[0].srcdata["x"]
            y = mfgs[-1].dstdata["y"]
            y_hat = model(mfgs, x)
            correct += (y_hat.argmax(1) == y).sum().item()
            total += y_hat.shape[0]
    acc = correct / total
    print("Validation accuracy:", acc)
    evaluator = MAG240MEvaluator()
    y_preds = []
    for i, (input_nodes, output_nodes, mfgs) in enumerate(
        tqdm.tqdm(test_dataloader)
    ):
        with torch.no_grad():
            # Test labels are hidden in MAG240M, so only inputs are needed.
            mfgs = [mfg.to("cuda") for mfg in mfgs]
            x = mfgs[0].srcdata["x"]
            y_hat = model(mfgs, x)
            y_preds.append(y_hat.argmax(1).cpu())
    evaluator.save_test_submission(
        {"y_pred": torch.cat(y_preds)}, args.submission_path
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rootdir",
        type=str,
        default=".",
        help="Root directory of the OGB dataset.",
    )
    parser.add_argument(
        "--graph-path",
        type=str,
        default="./graph.dgl",
        help="Path to the graph.",
    )
    parser.add_argument(
        "--full-feature-path",
        type=str,
        default="./full.npy",
        help="Path to the features of all nodes.",
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="Number of epochs."
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./model.pt",
        help="Path to store the best model.",
    )
    parser.add_argument(
        "--submission-path",
        type=str,
        default="./results",
        help="Submission directory.",
    )
    args = parser.parse_args()

    dataset = MAG240MDataset(root=args.rootdir)

    print("Loading graph")
    (g,), _ = dgl.load_graphs(args.graph_path)
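    # Restrict the graph to the CSC sparse format, which is what neighbor
    # sampling uses, so other formats are never materialized.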
    g = g.formats(["csc"])

    print("Loading features")
    paper_offset = dataset.num_authors + dataset.num_institutions
    num_nodes = paper_offset + dataset.num_papers
    num_features = dataset.num_paper_features
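    # Memory-map the features so they are read lazily from disk rather than
    # loaded into RAM up front.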
    feats = np.memmap(
        args.full_feature_path,
        mode="r",
        dtype="float16",
        shape=(num_nodes, num_features),
    )

    if args.epochs != 0:
        train(args, dataset, g, feats, paper_offset)
    test(args, dataset, g, feats, paper_offset)
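
# Example usage (paths shown are the argparse defaults; adjust as needed):
#   python train.py --rootdir . --graph-path ./graph.dgl \
#       --full-feature-path ./full.npy --epochs 100
# Passing --epochs 0 skips training and runs inference with a saved model:
#   python train.py --epochs 0 --model-path ./model.pt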