#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import math
import os
import random
import time

import dgl
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

from models import GAT

# Offset used by custom_loss_function: loss = log(epsilon + CE) - log(epsilon).
epsilon = 1 - math.log(2)

# torch.device selected in main() from --cpu / --gpu; None until main() runs.
device = None

# OGB dataset name handed to load_data() by main().
dataset = "ogbn-arxiv"

# Input feature width and number of classes; both are filled in by load_data().
n_node_feats = n_classes = 0


def seed(seed=0):
    """Make a run reproducible by seeding every random number generator in use.

    Seeds Python's ``random``, NumPy, PyTorch (CPU, the current CUDA device and
    all CUDA devices) and DGL, and switches cuDNN to deterministic,
    non-benchmarking mode.
    """
    rng_seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in rng_seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    dgl.random.seed(seed)


def load_data(dataset):
    """Load an OGB node-property-prediction dataset together with its evaluator.

    Side effect: sets the module globals ``n_node_feats`` (width of the "feat"
    node feature) and ``n_classes`` (largest label + 1).

    Returns:
        ``(graph, labels, train_idx, val_idx, test_idx, evaluator)``.
    """
    global n_node_feats, n_classes

    data = DglNodePropPredDataset(name=dataset)
    evaluator = Evaluator(name=dataset)

    split = data.get_idx_split()
    graph, labels = data[0]

    n_node_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()

    return graph, labels, split["train"], split["valid"], split["test"], evaluator


def preprocess(graph):
    """Return a bidirected copy of ``graph`` with exactly one self-loop per node.

    The "feat" node feature is copied onto the bidirected graph explicitly,
    because ``dgl.to_bidirected`` does not carry node data over by default.
    The edge counts before and after the self-loop step are printed.
    """
    node_feat = graph.ndata["feat"]
    bidirected = dgl.to_bidirected(graph)
    bidirected.ndata["feat"] = node_feat

    # Drop any existing self-loops first so each node ends up with exactly one.
    print(f"Total edges before adding self-loop {bidirected.number_of_edges()}")
    bidirected = bidirected.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {bidirected.number_of_edges()}")

    # Materialise the sparse formats up front instead of lazily on first use.
    bidirected.create_formats_()

    return bidirected


def gen_model(args):
    """Build the GAT model described by the command-line ``args``.

    Reads the module globals ``n_node_feats`` and ``n_classes``, so
    ``load_data`` must have run first. With ``args.use_labels`` the input width
    grows by ``n_classes``, because ``add_labels`` appends a one-hot label block
    to the node features.
    """
    in_feats = n_node_feats + n_classes if args.use_labels else n_node_feats

    return GAT(
        in_feats,
        n_classes,
        n_hidden=args.n_hidden,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        activation=F.relu,
        dropout=args.dropout,
        input_drop=args.input_drop,
        attn_drop=args.attn_drop,
        edge_drop=args.edge_drop,
        use_attn_dst=not args.no_attn_dst,
        use_symmetric_norm=args.use_norm,
    )


def custom_loss_function(x, labels):
    """Log-damped cross-entropy, averaged over rows.

    Each per-row cross-entropy ``ce`` is mapped to
    ``log(epsilon + ce) - log(epsilon)``. This is 0 when ``ce`` is 0 and grows
    only logarithmically for large ``ce``. ``epsilon`` is the module global.

    Args:
        x: ``(N, C)`` logits.
        labels: integer class labels; only column 0 (``labels[:, 0]``) is used.

    Returns:
        Scalar tensor with the mean transformed loss.
    """
    ce = F.cross_entropy(x, labels[:, 0], reduction="none")
    return torch.mean(torch.log(epsilon + ce) - math.log(epsilon))


def add_labels(feat, labels, idx):
    """Return ``feat`` with a one-hot label block of width ``n_classes`` appended.

    Rows listed in ``idx`` get the one-hot encoding of ``labels[idx, 0]``; all
    other rows get zeros. ``feat`` itself is not modified (``torch.cat`` builds
    a new tensor).

    The one-hot block is allocated on ``feat``'s own device and dtype rather
    than on the module-global ``device``. That global is ``None`` until
    ``main()`` sets it, and the result must live alongside ``feat`` anyway.
    """
    onehot = torch.zeros([feat.shape[0], n_classes], device=feat.device, dtype=feat.dtype)
    onehot[idx, labels[idx, 0]] = 1
    return torch.cat([feat, onehot], dim=-1)


def adjust_learning_rate(optimizer, lr, epoch):
    """Linear learning-rate warm-up over the first 50 epochs.

    For ``epoch <= 50`` every parameter group's learning rate is set to
    ``lr * epoch / 50``. For later epochs the optimizer is left untouched.
    """
    if epoch > 50:
        return
    scaled_lr = lr * epoch / 50
    for group in optimizer.param_groups:
        group["lr"] = scaled_lr


def train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, evaluator):
    """Run one full-graph optimisation step and return ``(train_acc, loss)``.

    A random mask selects each training node with probability ``args.mask_rate``:

    - with ``args.use_labels``, the masked nodes' labels are fed in as input
      features and the loss is computed on the remaining (unmasked) nodes;
    - otherwise the loss is computed on the masked subset only.

    With ``args.n_label_iters > 0``, the label block of the loss nodes and of
    all validation/test nodes is repeatedly overwritten with the model's
    softmax predictions before the final forward pass.

    Returns:
        ``(evaluator(pred[train_idx], labels[train_idx]), loss.item())``.
    """
    model.train()

    feat = graph.ndata["feat"]
    # Both branches draw the same mask; torch.rand uses the default device, as before.
    mask = torch.rand(train_idx.shape) < args.mask_rate

    if args.use_labels:
        train_labels_idx = train_idx[mask]
        train_pred_idx = train_idx[~mask]
        feat = add_labels(feat, labels, train_labels_idx)
    else:
        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
    pred = model(graph, feat)

    if args.n_label_iters > 0:
        # main() requires --use-labels here, so ``feat`` is the fresh tensor from
        # add_labels and the in-place writes below leave graph.ndata untouched.
        unlabel_idx = torch.cat([train_pred_idx, val_idx, test_idx])
        for _ in range(args.n_label_iters):
            # Only the last forward pass is back-propagated through.
            pred = pred.detach()
            torch.cuda.empty_cache()
            feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1)
            pred = model(graph, feat)

    loss = custom_loss_function(pred[train_pred_idx], labels[train_pred_idx])
    loss.backward()
    optimizer.step()

    return evaluator(pred[train_idx], labels[train_idx]), loss.item()


@torch.no_grad()
def evaluate(args, model, graph, labels, train_idx, val_idx, test_idx, evaluator):
    """Full-graph inference in eval mode; returns accuracies, losses and logits.

    With ``args.use_labels`` the labels of all training nodes are fed in as
    input. With ``args.n_label_iters > 0`` the label block of validation/test
    nodes is repeatedly overwritten with the model's softmax predictions.

    Returns:
        ``(train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred)``.
        The accuracies come from ``evaluator``. The losses are the scalar tensors
        returned by ``custom_loss_function``, not Python floats.
    """
    model.eval()

    feat = graph.ndata["feat"]
    if args.use_labels:
        # At evaluation time every training label is visible.
        feat = add_labels(feat, labels, train_idx)

    pred = model(graph, feat)

    if args.n_label_iters > 0:
        unlabel_idx = torch.cat([val_idx, test_idx])
        for _ in range(args.n_label_iters):
            feat[unlabel_idx, -n_classes:] = F.softmax(pred[unlabel_idx], dim=-1)
            pred = model(graph, feat)

    train_loss = custom_loss_function(pred[train_idx], labels[train_idx])
    val_loss = custom_loss_function(pred[val_idx], labels[val_idx])
    test_loss = custom_loss_function(pred[test_idx], labels[test_idx])

    return (
        evaluator(pred[train_idx], labels[train_idx]),
        evaluator(pred[val_idx], labels[val_idx]),
        evaluator(pred[test_idx], labels[test_idx]),
        train_loss,
        val_loss,
        test_loss,
        pred,
    )


def _save_learning_curves(n_epochs, series, labels, path, y_major, y_minor, y_ticks=None):
    """Plot per-epoch ``series`` (one curve per entry of ``labels``) and save the figure to ``path``.

    The x axis is the 1-based epoch number, with major ticks every 100 epochs.
    ``y_major`` is the y-axis major tick spacing and ``y_minor`` the number of
    minor subdivisions. ``y_ticks`` (optional) is applied first so the y range
    covers those values. The figure is always closed afterwards; otherwise
    every run would leave two large figures open.
    """
    fig = plt.figure(figsize=(24, 24))
    try:
        ax = fig.gca()
        # Explicit ticks first: set_*ticks widens the view limits to cover them,
        # and the locators below then take over the actual tick placement.
        ax.set_xticks(np.arange(0, n_epochs, 100))
        if y_ticks is not None:
            ax.set_yticks(y_ticks)
        ax.tick_params(labeltop=True, labelright=True)
        for y, label in zip(series, labels):
            # Epochs are numbered from 1 in the training loop.
            ax.plot(range(1, len(y) + 1), y, label=label, linewidth=1)
        ax.xaxis.set_major_locator(MultipleLocator(100))
        ax.xaxis.set_minor_locator(AutoMinorLocator(1))
        ax.yaxis.set_major_locator(MultipleLocator(y_major))
        ax.yaxis.set_minor_locator(AutoMinorLocator(y_minor))
        ax.grid(which="major", color="red", linestyle="dotted")
        ax.grid(which="minor", color="orange", linestyle="dotted")
        ax.legend()
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)


def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running):
    """Train one model from scratch and return ``(best_val_acc, final_test_acc)``.

    The model is selected by lowest validation loss. ``best_val_acc`` and
    ``final_test_acc`` are the validation/test accuracies at that epoch. The
    softmax of that epoch's predictions is saved when ``args.save_pred`` is set.
    Learning curves are written when ``args.plot_curves`` is set.
    """

    def evaluator_wrapper(pred, labels):
        # Accuracy of the arg-max prediction, computed by the OGB evaluator.
        return evaluator.eval({"y_pred": pred.argmax(dim=-1, keepdim=True), "y_true": labels})["acc"]

    # define model and optimizer
    model = gen_model(args).to(device)
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # training loop
    total_time = 0
    best_val_acc, final_test_acc, best_val_loss = 0, 0, float("inf")
    final_pred = None

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        adjust_learning_rate(optimizer, args.lr, epoch)

        acc, loss = train(args, model, graph, labels, train_idx, val_idx, test_idx, optimizer, evaluator_wrapper)

        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss, pred = evaluate(
            args, model, graph, labels, train_idx, val_idx, test_idx, evaluator_wrapper
        )

        toc = time.time()
        total_time += toc - tic

        # evaluate() returns the losses as (possibly GPU) scalar tensors. Keep plain
        # floats so the history lists can be plotted and hold no device memory.
        train_loss, val_loss, test_loss = float(train_loss), float(val_loss), float(test_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            final_test_acc = test_acc
            final_pred = pred

        if epoch == args.n_epochs or epoch % args.log_every == 0:
            print(
                f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
                f"Loss: {loss:.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}"
            )

        for history, value in zip(
            [accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses],
            [acc, train_acc, val_acc, test_acc, loss, train_loss, val_loss, test_loss],
        ):
            history.append(value)

    print("*" * 50)
    print(f"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}")
    print("*" * 50)

    # plot learning curves
    if args.plot_curves:
        _save_learning_curves(
            args.n_epochs,
            [accs, train_accs, val_accs, test_accs],
            ["acc", "train acc", "val acc", "test acc"],
            f"gat_acc_{n_running}.png",
            y_major=0.01,
            y_minor=2,
            y_ticks=np.linspace(0, 1.0, 101),
        )
        _save_learning_curves(
            args.n_epochs,
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
            f"gat_loss_{n_running}.png",
            y_major=0.1,
            y_minor=5,
        )

    if args.save_pred:
        os.makedirs("./output", exist_ok=True)
        torch.save(F.softmax(final_pred, dim=1), f"./output/{n_running}.pt")

    return best_val_acc, final_test_acc


def count_parameters(args):
    """Return the number of trainable parameters of a freshly built model for ``args``."""
    model = gen_model(args)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def main():
    """Parse command-line options, load and preprocess the dataset, and run ``args.n_runs`` seeded trainings.

    Run ``i`` (0-based) is seeded with ``args.seed + i``. At the end, the
    per-run validation/test accuracies and their mean ± std are printed.

    Raises:
        ValueError: if ``--n-label-iters`` > 0 without ``--use-labels``.
    """
    # Only ``device`` is assigned here; the other globals are set by load_data().
    global device

    # Pass the text as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the script name in usage.
    argparser = argparse.ArgumentParser(
        description="GAT implementation on ogbn-arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument("--seed", type=int, default=0, help="seed")
    argparser.add_argument("--n-runs", type=int, default=10, help="running times")
    argparser.add_argument("--n-epochs", type=int, default=2000, help="number of epochs")
    argparser.add_argument(
        "--use-labels", action="store_true", help="Use labels in the training set as input features."
    )
    argparser.add_argument("--n-label-iters", type=int, default=0, help="number of label iterations")
    argparser.add_argument("--mask-rate", type=float, default=0.5, help="mask rate")
    argparser.add_argument("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
    argparser.add_argument("--use-norm", action="store_true", help="Use symmetrically normalized adjacency matrix.")
    argparser.add_argument("--lr", type=float, default=0.002, help="learning rate")
    argparser.add_argument("--n-layers", type=int, default=3, help="number of layers")
    argparser.add_argument("--n-heads", type=int, default=3, help="number of heads")
    argparser.add_argument("--n-hidden", type=int, default=250, help="number of hidden units")
    argparser.add_argument("--dropout", type=float, default=0.75, help="dropout rate")
    argparser.add_argument("--input-drop", type=float, default=0.1, help="input drop rate")
    argparser.add_argument("--attn-drop", type=float, default=0.0, help="attention drop rate")
    argparser.add_argument("--edge-drop", type=float, default=0.0, help="edge drop rate")
    argparser.add_argument("--wd", type=float, default=0, help="weight decay")
    argparser.add_argument("--log-every", type=int, default=20, help="log every LOG_EVERY epochs")
    argparser.add_argument("--plot-curves", action="store_true", help="plot learning curves")
    argparser.add_argument("--save-pred", action="store_true", help="save final predictions")
    args = argparser.parse_args()

    # Label iterations overwrite the appended label block, which only exists with --use-labels.
    if not args.use_labels and args.n_label_iters > 0:
        raise ValueError("'--use-labels' must be enabled when n_label_iters > 0")

    device = torch.device("cpu") if args.cpu else torch.device(f"cuda:{args.gpu}")

    # load data & preprocess
    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
    graph = preprocess(graph)

    graph, labels, train_idx, val_idx, test_idx = map(
        lambda x: x.to(device), (graph, labels, train_idx, val_idx, test_idx)
    )

    # run
    val_accs, test_accs = [], []

    for i in range(args.n_runs):
        seed(args.seed + i)
        val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, i + 1)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    print(args)
    print(f"Runned {args.n_runs} times")
    print("Val Accs:", val_accs)
    print("Test Accs:", test_accs)
    print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
    print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")
    print(f"Number of params: {count_parameters(args)}")


# Script entry point: parse the command line and run all training repetitions via main().
if __name__ == "__main__":
    main()


# Namespace(attn_drop=0.0, cpu=False, dropout=0.75, edge_drop=0.1, gpu=0, input_drop=0.1, log_every=20, lr=0.002, n_epochs=2000, n_heads=3, n_hidden=250, n_label_iters=0, n_layers=3, n_runs=10, no_attn_dst=True, plot_curves=True, use_labels=True, use_norm=True, wd=0)
# Runned 10 times
# Val Accs: [0.7492868888217725, 0.7524413570925199, 0.7505620993993087, 0.7500251686298198, 0.7501929594952851, 0.7513003792073559, 0.7516695191113796, 0.7505285412262156, 0.7504949830531226, 0.7515017282459143]
# Test Accs: [0.7366829208073575, 0.7384112091846182, 0.7368886694236981, 0.7345019854741477, 0.7373001666563792, 0.7362508487130424, 0.7352221056313396, 0.736477172191017, 0.7380614365368393, 0.7362919984363105]
# Average val accuracy: 0.7508003624282694 ± 0.0008760483047616948
# Average test accuracy: 0.736608851305475 ± 0.0011192876013651112
# Number of params: 1441580

# Namespace(attn_drop=0.0, cpu=False, dropout=0.75, edge_drop=0.3, gpu=0, input_drop=0.25, log_every=20, lr=0.002, n_epochs=2000, n_heads=3, n_hidden=250, n_label_iters=1, n_layers=3, n_runs=10, no_attn_dst=True, plot_curves=True, use_labels=True, use_norm=True, wd=0)
# Runned 20 times
# Val Accs: [0.7529782878620088, 0.7521393335346823, 0.7521728917077755, 0.7504949830531226, 0.7518037518037518, 0.7518373099768448, 0.7516359609382866, 0.7511325883418907, 0.7509312393033323, 0.7515017282459143, 0.7511325883418907, 0.7514346118997282, 0.7509312393033323, 0.7521393335346823, 0.7528776133427296, 0.7522735662270545, 0.7504949830531226, 0.7522735662270545, 0.7511661465149837, 0.7501258431490989]
# Test Accs: [0.7390901796185421, 0.7398720243606361, 0.7394605271279551, 0.7384523589078863, 0.7388638561405675, 0.7397280003291978, 0.7414151389831903, 0.7376499393041582, 0.7399748986688065, 0.7400366232537087, 0.7392547785116145, 0.7388844310022015, 0.7374853404110857, 0.7384317840462523, 0.7418677859391396, 0.737937987367035, 0.7381643108450096, 0.7399543238071724, 0.7377322387506944, 0.7385758080776906]
# Average val accuracy: 0.7515738783180644 ± 0.0007617982474634186
# Average test accuracy: 0.7391416167726272 ± 0.0011522198067958794
# Number of params: 1441580