gcn.py 10.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import math
import time

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
13
14
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GCN
15
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
16

17
18
# Run-time settings shared by the helpers below. main() declares these three
# names ``global`` and fills them in once the CLI arguments and the dataset
# are known; until then they are placeholders.
device = None
in_feats = None
n_classes = None

# Offset used by cross_entropy() in its log(epsilon + loss) transform.
epsilon = 1.0 - math.log(2.0)
20
21


22
23
24
def gen_model(args):
    """Build a GCN configured from the parsed CLI arguments.

    Reads the module-level ``in_feats`` and ``n_classes``, which main() sets
    after loading the dataset.

    Args:
        args: parsed arguments; ``use_labels``, ``n_hidden``, ``n_layers``,
            ``dropout`` and ``use_linear`` are read.

    Returns:
        A freshly constructed ``GCN`` instance (not yet moved to a device).
    """
    # With --use-labels, add_labels() appends an n_classes-wide one-hot block
    # to every node's features, so the input layer has to be that much wider.
    input_dim = in_feats + n_classes if args.use_labels else in_feats

    return GCN(
        input_dim,
        args.n_hidden,
        n_classes,
        args.n_layers,
        F.relu,
        args.dropout,
        args.use_linear,
    )
44
45
46


def cross_entropy(x, labels):
    """Return the mean log-transformed cross-entropy of ``x`` against ``labels``.

    The per-sample cross-entropy ``ce`` is mapped to
    ``log(epsilon + ce) - log(epsilon)`` before averaging, using the
    module-level ``epsilon``. The transform is 0 when ``ce`` is 0.

    Args:
        x: class scores, one row per sample.
        labels: 2-D label tensor; only column 0 is used as the target class.

    Returns:
        A scalar tensor with the mean transformed loss.
    """
    per_sample = F.cross_entropy(x, labels[:, 0], reduction="none")
    transformed = th.log(epsilon + per_sample) - math.log(epsilon)
    return th.mean(transformed)


52
def compute_acc(pred, labels, evaluator):
    """Return the accuracy reported by ``evaluator`` for the given predictions.

    Args:
        pred: class scores; the arg-max over the last dimension is taken as
            the predicted class (kept as a trailing dimension of size 1).
        labels: ground-truth labels, passed to the evaluator unchanged.
        evaluator: object whose ``eval`` method accepts a dict with
            ``"y_pred"`` and ``"y_true"`` and returns a mapping with ``"acc"``.

    Returns:
        The value stored under ``"acc"`` in the evaluator's result.
    """
    predicted_class = pred.argmax(dim=-1, keepdim=True)
    result = evaluator.eval({"y_pred": predicted_class, "y_true": labels})
    return result["acc"]
56
57


58
59
def add_labels(feat, labels, idx):
    """Append a one-hot label block to the node features.

    Rows listed in ``idx`` get a 1 in the column of their label (column 0 of
    ``labels``); every other row gets all zeros. The block is
    ``n_classes`` wide (module-level global).

    Args:
        feat: node feature matrix, one row per node.
        labels: 2-D label tensor; only column 0 is read.
        idx: indices of the nodes whose labels are revealed.

    Returns:
        ``feat`` with the one-hot block concatenated along the last dimension.
    """
    # Allocate directly on feat's device: th.cat needs both operands on the
    # same device anyway, and this avoids a host allocation plus a copy on
    # every call.
    onehot = th.zeros((feat.shape[0], n_classes), device=feat.device)
    onehot[idx, labels[idx, 0]] = 1
    return th.cat([feat, onehot], dim=-1)


64
65
66
67
68
69
def adjust_learning_rate(optimizer, lr, epoch):
    """Linearly warm up the learning rate during the first 50 epochs.

    For ``epoch <= 50`` every parameter group's ``"lr"`` is set to
    ``lr * epoch / 50``; for later epochs the optimizer is left untouched.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        lr: target learning rate reached at epoch 50.
        epoch: current epoch number.
    """
    if epoch > 50:
        return

    warmup_lr = lr * epoch / 50
    for group in optimizer.param_groups:
        group["lr"] = warmup_lr


70
def train(model, graph, labels, train_idx, optimizer, use_labels):
    """Run one full-graph optimisation step and return its loss and output.

    A random mask with rate 0.5 is drawn over the training nodes each call:

    * with ``use_labels``, the masked nodes have their labels fed in as input
      features (via add_labels) and the loss is taken on the remaining ones;
    * otherwise the loss is taken on the masked nodes only.

    Args:
        model: the network; called as ``model(graph, feat)``.
        graph: graph whose ``ndata["feat"]`` holds the node features.
        labels: 2-D label tensor indexed by node.
        train_idx: indices of the training nodes.
        optimizer: optimizer stepping ``model``'s parameters.
        use_labels: whether to feed part of the training labels as features.

    Returns:
        ``(loss, pred)``: the scalar training loss and the model output for
        all nodes from this step's forward pass.
    """
    model.train()

    feat = graph.ndata["feat"]

    # Both modes draw the same kind of mask; sample it once up front.
    mask_rate = 0.5
    mask = th.rand(train_idx.shape) < mask_rate

    if use_labels:
        # Reveal the masked nodes' labels, predict the rest.
        feat = add_labels(feat, labels, train_idx[mask])
        train_pred_idx = train_idx[~mask]
    else:
        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
    pred = model(graph, feat)
    loss = cross_entropy(pred[train_pred_idx], labels[train_pred_idx])
    loss.backward()
    optimizer.step()

    return loss, pred


@th.no_grad()
def evaluate(
    model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator
):
    """Score the model on the train/validation/test splits without gradients.

    With ``use_labels`` the labels of all training nodes are fed in as input
    features (via add_labels) before the forward pass.

    Args:
        model: the network; called as ``model(graph, feat)``.
        graph: graph whose ``ndata["feat"]`` holds the node features.
        labels: 2-D label tensor indexed by node.
        train_idx, val_idx, test_idx: node indices of the three splits.
        use_labels: whether training labels are used as input features.
        evaluator: passed through to compute_acc.

    Returns:
        A 6-tuple ``(train_acc, val_acc, test_acc, train_loss, val_loss,
        test_loss)``.
    """
    model.eval()

    feat = graph.ndata["feat"]
    if use_labels:
        feat = add_labels(feat, labels, train_idx)

    pred = model(graph, feat)

    splits = (train_idx, val_idx, test_idx)
    split_accs = tuple(
        compute_acc(pred[idx], labels[idx], evaluator) for idx in splits
    )
    split_losses = tuple(
        cross_entropy(pred[idx], labels[idx]) for idx in splits
    )
    return split_accs + split_losses


124
125
126
def _save_curves(
    curves, names, n_epochs, y_major, y_minor_divs, path, y_ticks=None
):
    """Plot one group of per-epoch curves on a single figure and save it.

    Args:
        curves: sequences of per-epoch values, one sequence per curve.
        names: legend label for each curve, in the same order.
        n_epochs: number of epochs; the x axis is ``range(n_epochs)``.
        y_major: spacing of the major y grid lines.
        y_minor_divs: subdivisions between major y ticks.
        path: output image file name.
        y_ticks: optional explicit y tick positions.
    """
    fig = plt.figure(figsize=(24, 24))
    ax = fig.gca()
    ax.set_xticks(np.arange(0, n_epochs, 100))
    if y_ticks is not None:
        ax.set_yticks(y_ticks)
    ax.tick_params(labeltop=True, labelright=True)
    for values, name in zip(curves, names):
        ax.plot(range(n_epochs), values, label=name)
    ax.xaxis.set_major_locator(MultipleLocator(100))
    ax.xaxis.set_minor_locator(AutoMinorLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(y_major))
    ax.yaxis.set_minor_locator(AutoMinorLocator(y_minor_divs))
    ax.grid(which="major", color="red", linestyle="dotted")
    ax.grid(which="minor", color="orange", linestyle="dotted")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    # Release the figure; otherwise every run keeps two large figures alive.
    plt.close(fig)


def run(
    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
    """Train one model from scratch and report its selected accuracies.

    Each epoch applies the warm-up from adjust_learning_rate, performs one
    train() step, evaluates all splits, and steps a ReduceLROnPlateau
    scheduler on the training loss. The reported accuracies are those of the
    epoch with the lowest validation loss.

    Args:
        args: parsed CLI arguments (``lr``, ``wd``, ``n_epochs``, ``n_runs``,
            ``log_every``, ``use_labels``, ``plot_curves`` and the model
            options read by gen_model).
        graph: graph whose ``ndata["feat"]`` holds the node features.
        labels: 2-D label tensor indexed by node.
        train_idx, val_idx, test_idx: node indices of the three splits.
        evaluator: passed through to compute_acc / evaluate.
        n_running: index of this run; used in log lines and plot file names.

    Returns:
        ``(best_val_acc, final_test_acc)``: validation and test accuracy at
        the epoch with the lowest validation loss.
    """
    # define model and optimizer
    model = gen_model(args).to(device)

    optimizer = optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=100,
        verbose=True,
        min_lr=1e-3,
    )

    # training loop
    total_time = 0
    best_val_acc, final_test_acc, best_val_loss = 0, 0, float("inf")

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []
    histories = (
        accs,
        train_accs,
        val_accs,
        test_accs,
        losses,
        train_losses,
        val_losses,
        test_losses,
    )

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        adjust_learning_rate(optimizer, args.lr, epoch)

        loss, pred = train(
            model, graph, labels, train_idx, optimizer, args.use_labels
        )
        acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)

        (
            train_acc,
            val_acc,
            test_acc,
            train_loss,
            val_loss,
            test_loss,
        ) = evaluate(
            model,
            graph,
            labels,
            train_idx,
            val_idx,
            test_idx,
            args.use_labels,
            evaluator,
        )

        lr_scheduler.step(loss)

        toc = time.time()
        total_time += toc - tic

        # Model selection: remember the accuracies at the best validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            final_test_acc = test_acc

        if epoch % args.log_every == 0:
            print(
                f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
                f"Loss: {loss.item():.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}"
            )

        # Store plain floats: the losses are tensors (the training loss still
        # requires grad and they may live on the GPU), which matplotlib cannot
        # convert to NumPy when plotting.
        for history, value in zip(
            histories,
            (
                acc,
                train_acc,
                val_acc,
                test_acc,
                loss,
                train_loss,
                val_loss,
                test_loss,
            ),
        ):
            history.append(float(value))

    print("*" * 50)
    print(f"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}")
    print("*" * 50)

    if args.plot_curves:
        _save_curves(
            [accs, train_accs, val_accs, test_accs],
            ["acc", "train acc", "val acc", "test acc"],
            args.n_epochs,
            y_major=0.01,
            y_minor_divs=2,
            path=f"gcn_acc_{n_running}.png",
            y_ticks=np.linspace(0, 1.0, 101),
        )
        _save_curves(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
            args.n_epochs,
            y_major=0.1,
            y_minor_divs=5,
            path=f"gcn_loss_{n_running}.png",
        )

    return best_val_acc, final_test_acc
265
266


267
268
def count_parameters(args):
    """Return the number of trainable parameters of the configured model.

    A fresh model is built with gen_model(args) just for counting.

    Args:
        args: parsed CLI arguments, forwarded to gen_model.

    Returns:
        Total element count over all parameters with ``requires_grad`` set,
        as a Python ``int``.
    """
    model = gen_model(args)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
272
273


274
def main():
    """Parse CLI options, prepare ogbn-arxiv, and run the experiment.

    Sets the module-level ``device``, ``in_feats`` and ``n_classes``, trains
    ``--n-runs`` independent models via run(), and prints the per-run and
    aggregate validation/test accuracies plus the parameter count.
    """
    global device, in_feats, n_classes

    argparser = argparse.ArgumentParser(
        "GCN on OGBN-Arxiv",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument(
        "--cpu",
        action="store_true",
        help="CPU mode. This option overrides --gpu.",
    )
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument(
        "--n-runs", type=int, default=10, help="running times"
    )
    argparser.add_argument(
        "--n-epochs", type=int, default=1000, help="number of epochs"
    )
    argparser.add_argument(
        "--use-labels",
        action="store_true",
        help="Use labels in the training set as input features.",
    )
    argparser.add_argument(
        "--use-linear", action="store_true", help="Use linear layer."
    )
    argparser.add_argument(
        "--lr", type=float, default=0.005, help="learning rate"
    )
    argparser.add_argument(
        "--n-layers", type=int, default=3, help="number of layers"
    )
    argparser.add_argument(
        "--n-hidden", type=int, default=256, help="number of hidden units"
    )
    argparser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout rate"
    )
    argparser.add_argument("--wd", type=float, default=0, help="weight decay")
    argparser.add_argument(
        "--log-every", type=int, default=20, help="log every LOG_EVERY epochs"
    )
    argparser.add_argument(
        "--plot-curves", action="store_true", help="plot learning curves"
    )
    args = argparser.parse_args()

    if args.cpu:
        device = th.device("cpu")
    else:
        device = th.device("cuda:%d" % args.gpu)

    # load data
    data = DglNodePropPredDataset(name="ogbn-arxiv")
    evaluator = Evaluator(name="ogbn-arxiv")

    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = (
        splitted_idx["train"],
        splitted_idx["valid"],
        splitted_idx["test"],
    )
    graph, labels = data[0]

    # add reverse edges
    srcs, dsts = graph.all_edges()
    graph.add_edges(dsts, srcs)

    # add self-loop (existing ones are dropped first so none is duplicated)
    print(f"Total edges before adding self-loop {graph.number_of_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.number_of_edges()}")

    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    graph.create_formats_()

    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)
    labels = labels.to(device)
    graph = graph.to(device)

    # run
    val_accs = []
    test_accs = []

    for i in range(args.n_runs):
        # run() logs "Run: <n_running>/<n_runs>", so pass a 1-based index;
        # plot file names produced by run() are therefore 1-based as well.
        val_acc, test_acc = run(
            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1
        )
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    print(f"Runned {args.n_runs} times")
    print("Val Accs:", val_accs)
    print("Test Accs:", test_accs)
    print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
    print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")
    print(f"Number of params: {count_parameters(args)}")
375
376
377
378


# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects beyond the definitions above.
if __name__ == "__main__":
    main()