#!/usr/bin/env python
# -*- coding: utf-8 -*-
# gcn.py: train and evaluate a GCN on the OGBN-Arxiv dataset.

import argparse
import math
import time

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
13
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
14
15
from ogb.nodeproppred import DglNodePropPredDataset

16
from models import GCN
17

18
19
# Module-level runtime state read by the helpers below. All three start unset
# and are assigned in main() once the CLI arguments and the dataset are known.
device = in_feats = n_classes = None
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34


def compute_acc(pred, labels):
    """
    Return the fraction of rows in ``pred`` whose arg-max over dim 1 equals
    the corresponding entry of ``labels`` (as a 0-dim float tensor).
    """
    predicted_class = pred.argmax(dim=1)
    n_correct = predicted_class.eq(labels).float().sum()
    return n_correct / len(pred)


def cross_entropy(x, labels):
    """
    Log-damped cross-entropy, averaged over all samples.

    Each per-sample cross-entropy value ``l`` is mapped to
    ``log(0.5 + l) - log(0.5)`` (which is 0 when ``l`` is 0) before the mean
    is taken.
    """
    offset = 0.5
    per_sample = F.cross_entropy(x, labels, reduction="none")
    damped = th.log(offset + per_sample) - math.log(offset)
    return th.mean(damped)


35
36
37
38
39
40
41
def add_labels(feat, labels, idx):
    """
    Append a one-hot label block to the node features.

    A zero matrix with one row per row of ``feat`` and ``n_classes`` columns
    is created on ``device``; for the rows selected by ``idx`` the column
    given by ``labels`` is set to 1. The result is ``feat`` concatenated with
    that matrix along the last dimension.
    """
    n_rows = feat.shape[0]
    label_block = th.zeros([n_rows, n_classes]).to(device)
    label_block[idx, labels[idx]] = 1
    return th.cat([feat, label_block], dim=-1)


def train(model, graph, labels, train_idx, optimizer, use_labels, mask_rate=0.5):
    """
    Run one optimisation step and return ``(loss, pred)``.

    A random boolean mask is drawn over ``train_idx`` on every call (each
    entry is selected with probability ``mask_rate``):

    * with ``use_labels``, the selected training nodes have their labels
      appended to the input features via ``add_labels`` and the loss is
      computed on the remaining, unselected training nodes;
    * otherwise the loss is computed on the selected training nodes only.

    ``loss`` is the tensor returned by ``cross_entropy`` (it was used for
    ``backward()``), and ``pred`` is the raw output of ``model(graph, feat)``.
    """
    model.train()

    feat = graph.ndata["feat"]

    # The mask is sampled identically in both modes; only its use differs.
    mask = th.rand(train_idx.shape) < mask_rate

    if use_labels:
        # Labels of the selected nodes become input; predict the others.
        feat = add_labels(feat, labels, train_idx[mask])
        train_pred_idx = train_idx[~mask]
    else:
        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
    pred = model(graph, feat)
    loss = cross_entropy(pred[train_pred_idx], labels[train_pred_idx])
    loss.backward()
    optimizer.step()

    return loss, pred


@th.no_grad()
def evaluate(model, graph, labels, train_idx, val_idx, test_idx, use_labels):
    """
    Evaluate the model on the train/validation/test splits without gradients.

    With ``use_labels`` the labels of all training nodes are appended to the
    input features via ``add_labels`` before the forward pass.

    Returns a 6-tuple
    ``(train_acc, val_acc, test_acc, train_loss, val_loss, test_loss)`` where
    accuracies come from ``compute_acc`` and losses from ``cross_entropy``.
    """
    model.eval()

    feat = graph.ndata["feat"]

    if use_labels:
        feat = add_labels(feat, labels, train_idx)

    pred = model(graph, feat)

    splits = (train_idx, val_idx, test_idx)
    accs = tuple(compute_acc(pred[idx], labels[idx]) for idx in splits)
    losses = tuple(cross_entropy(pred[idx], labels[idx]) for idx in splits)

    return (*accs, *losses)


92
def adjust_learning_rate(optimizer, lr, epoch, warmup_epochs=50):
    """
    Linearly warm up the learning rate during the first ``warmup_epochs``.

    While ``epoch <= warmup_epochs`` every parameter group of ``optimizer``
    gets ``lr * epoch / warmup_epochs``; afterwards the optimizer is left
    untouched. A non-positive ``warmup_epochs`` disables the warm-up.
    """
    if warmup_epochs <= 0 or epoch > warmup_epochs:
        return

    warmup_lr = lr * epoch / warmup_epochs
    for param_group in optimizer.param_groups:
        param_group["lr"] = warmup_lr
96
97


98
99
100
101
102
103
104
105
def gen_model(args):
    """
    Build a GCN from the CLI arguments.

    The input width is ``in_feats``, widened by ``n_classes`` when
    ``args.use_labels`` is set (labels are then appended to the features).
    """
    input_dim = in_feats + n_classes if args.use_labels else in_feats
    return GCN(input_dim, args.n_hidden, n_classes, args.n_layers, F.relu, args.dropout, args.use_linear)
106

107
108
109
110

def run(args, graph, labels, train_idx, val_idx, test_idx, n_running):
    """
    Train a fresh model for ``args.n_epochs`` epochs and report its best result.

    Each epoch applies the learning-rate warm-up, performs one ``train`` step,
    evaluates on all three splits and feeds the training loss to a
    ``ReduceLROnPlateau`` scheduler. The reported validation/test accuracies
    are those of the epoch with the lowest validation loss.

    When ``args.plot_curves`` is set, accuracy and loss curves are written to
    ``gcn_acc_{n_running}.png`` and ``gcn_loss_{n_running}.png``.

    Returns ``(best_val_acc, best_test_acc)``.
    """
    # define model and optimizer
    model = gen_model(args)
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=100, verbose=True, min_lr=1e-3
    )

    # training loop
    total_time = 0
    best_val_acc, best_test_acc, best_val_loss = 0, 0, float("inf")

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        adjust_learning_rate(optimizer, args.lr, epoch)

        loss, pred = train(model, graph, labels, train_idx, optimizer, args.use_labels)
        acc = compute_acc(pred[train_idx], labels[train_idx])

        train_acc, val_acc, test_acc, train_loss, val_loss, test_loss = evaluate(
            model, graph, labels, train_idx, val_idx, test_idx, args.use_labels
        )

        lr_scheduler.step(loss)

        toc = time.time()
        total_time += toc - tic

        # Model selection is driven by validation loss, not validation accuracy.
        if val_loss < best_val_loss:
            best_val_loss = val_loss.item()
            best_val_acc = val_acc.item()
            best_test_acc = test_acc.item()

        if epoch % args.log_every == 0:
            print(f"Epoch: {epoch}/{args.n_epochs}")
            print(
                f"Loss: {loss.item():.4f}, Acc: {acc.item():.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Best test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{best_test_acc:.4f}"
            )

        # Store plain Python floats: keeping the tensors themselves would hold
        # on to device memory / autograd state for every epoch and cannot be
        # handed to matplotlib when they live on a GPU or require grad.
        for history, value in zip(
            [accs, train_accs, val_accs, test_accs, losses, train_losses, val_losses, test_losses],
            [acc, train_acc, val_acc, test_acc, loss, train_loss, val_loss, test_loss],
        ):
            history.append(value.item())

    print("*" * 50)
    print(f"Average epoch time: {total_time / args.n_epochs}, Test acc: {best_test_acc}")

    if args.plot_curves:
        _plot_curves(
            [accs, train_accs, val_accs, test_accs],
            ["acc", "train acc", "val acc", "test acc"],
            args.n_epochs,
            y_major=0.01,
            y_minor_divs=2,
            filename=f"gcn_acc_{n_running}.png",
            y_ticks=np.linspace(0, 1.0, 101),
        )
        _plot_curves(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
            args.n_epochs,
            y_major=0.1,
            y_minor_divs=5,
            filename=f"gcn_loss_{n_running}.png",
        )

    return best_val_acc, best_test_acc


def _plot_curves(curves, curve_labels, n_epochs, y_major, y_minor_divs, filename, y_ticks=None):
    """Draw one line per curve over the epochs, save the figure to ``filename`` and close it."""
    fig = plt.figure(figsize=(24, 24))
    ax = fig.gca()
    ax.set_xticks(np.arange(0, n_epochs, 100))
    if y_ticks is not None:
        ax.set_yticks(y_ticks)
    ax.tick_params(labeltop=True, labelright=True)
    for y, label in zip(curves, curve_labels):
        plt.plot(range(n_epochs), y, label=label)
    ax.xaxis.set_major_locator(MultipleLocator(100))
    ax.xaxis.set_minor_locator(AutoMinorLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(y_major))
    ax.yaxis.set_minor_locator(AutoMinorLocator(y_minor_divs))
    plt.grid(which="major", color="red", linestyle="dotted")
    plt.grid(which="minor", color="orange", linestyle="dotted")
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename)
    # Release the figure; otherwise every run leaves two large figures open.
    plt.close(fig)


204
205
206
207
208
def count_parameters(args):
    """Return the total number of trainable parameter entries of a freshly built model."""
    trainable = (p for p in gen_model(args).parameters() if p.requires_grad)
    return sum([np.prod(p.size()) for p in trainable])


209
def main():
    """
    Command-line entry point.

    Parses the arguments, selects the device, loads the ``ogbn-arxiv`` dataset,
    adds reverse edges and self-loops to the graph, moves everything to the
    device, then calls ``run`` ``args.n_runs`` times and prints the collected
    validation/test accuracies (mean and standard deviation) and the model's
    parameter count.

    Sets the module globals ``device``, ``in_feats`` and ``n_classes``.
    """
    global device, in_feats, n_classes

    argparser = argparse.ArgumentParser("GCN on OGBN-Arxiv", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argparser.add_argument("--cpu", action="store_true", help="CPU mode. This option overrides --gpu.")
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument("--n-runs", type=int, default=10)
    argparser.add_argument("--n-epochs", type=int, default=1000)
    argparser.add_argument(
        "--use-labels", action="store_true", help="Use labels in the training set as input features."
    )
    argparser.add_argument("--use-linear", action="store_true", help="Use linear layer.")
    argparser.add_argument("--lr", type=float, default=0.005)
    argparser.add_argument("--n-layers", type=int, default=3)
    argparser.add_argument("--n-hidden", type=int, default=256)
    argparser.add_argument("--dropout", type=float, default=0.5)
    argparser.add_argument("--wd", type=float, default=0)
    argparser.add_argument("--log-every", type=int, default=20)
    argparser.add_argument("--plot-curves", action="store_true")
    args = argparser.parse_args()

    if args.cpu:
        device = th.device("cpu")
    else:
        device = th.device(f"cuda:{args.gpu}")

    # load data
    data = DglNodePropPredDataset(name="ogbn-arxiv")
    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = splitted_idx["train"], splitted_idx["valid"], splitted_idx["test"]
    graph, labels = data[0]
    # Keep only the first label column so labels are indexed per node.
    labels = labels[:, 0]

    # add reverse edges
    srcs, dsts = graph.all_edges()
    graph.add_edges(dsts, srcs)

    # add self-loop
    print(f"Total edges before adding self-loop {graph.number_of_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.number_of_edges()}")

    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    graph.create_format_()

    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)
    labels = labels.to(device)
    graph = graph.to(device)

    # run
    val_accs = []
    test_accs = []

    for i in range(args.n_runs):
        val_acc, test_acc = run(args, graph, labels, train_idx, val_idx, test_idx, i)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    print(f"Runned {args.n_runs} times")
    print("Val Accs:", val_accs)
    print("Test Accs:", test_accs)
    print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
    print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")
    print(f"Number of params: {count_parameters(args)}")
276
277
278
279


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()