# train.py
"""
Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""

import argparse
import time

import mxnet as mx
import networkx as nx
import numpy as np
from gat import GAT
from mxnet import gluon
from utils import EarlyStopping

import dgl
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)


26
def elu(data):
    """Apply the ELU activation to ``data``.

    Thin wrapper around ``mx.nd.LeakyReLU`` with ``act_type="elu"``; it is
    handed to the ``GAT`` constructor in ``main`` as the activation argument.

    Parameters
    ----------
    data
        Input forwarded unchanged to ``mx.nd.LeakyReLU`` (presumably an
        ``mx.nd.NDArray``; confirm against the ``GAT`` implementation).

    Returns
    -------
    The result of ``mx.nd.LeakyReLU(data, act_type="elu")``.
    """
    return mx.nd.LeakyReLU(data, act_type="elu")
28
29
30
31
32
33
34
35
36
37
38
39
40


def evaluate(model, features, labels, mask):
    """Return the classification accuracy of ``model`` on the rows in ``mask``.

    Parameters
    ----------
    model
        Callable invoked as ``model(features)``; its output is indexed with
        ``mask`` and converted with ``.asnumpy()``.
    features
        Input passed straight to ``model``.
    labels
        Ground-truth labels; indexed with ``mask`` and converted with
        ``.asnumpy()``.
    mask
        Index used to select the rows to score.

    Returns
    -------
    Fraction of selected rows whose arg-max score (along axis 1) equals the
    corresponding label.
    """
    scores = model(features)[mask].asnumpy().squeeze()
    truth = labels[mask].asnumpy().squeeze()
    predicted = scores.argmax(axis=1)
    # Summing the boolean match array counts the correct predictions.
    n_correct = (predicted == truth).sum()
    return n_correct / len(truth)


def _mask_to_index(mask, ctx):
    """Convert a per-node mask into an NDArray of its non-zero positions on ``ctx``."""
    return mx.nd.array(np.nonzero(mask.asnumpy())[0], ctx=ctx)


def main(args):
    """Train a GAT model on a citation dataset and report test accuracy.

    Parameters
    ----------
    args
        Parsed command-line namespace. Attributes read here: ``dataset``,
        ``gpu``, ``num_heads``, ``num_layers``, ``num_out_heads``,
        ``num_hidden``, ``in_drop``, ``attn_drop``, ``alpha``, ``residual``,
        ``early_stop``, ``lr``, ``weight_decay`` and ``epochs``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of ``"cora"``, ``"citeseer"`` or
        ``"pubmed"``.

    Notes
    -----
    Prints the loss, mean epoch time, throughput and validation accuracy for
    every epoch, then the final test accuracy.
    """
    # load and preprocess dataset
    loaders = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    if args.dataset not in loaders:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    data = loaders[args.dataset]()

    g = data[0]
    if args.gpu < 0:
        ctx = mx.cpu(0)
    else:
        ctx = mx.gpu(args.gpu)
        g = g.to(ctx)

    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    # Boolean-style masks are turned into index arrays for NDArray indexing.
    mask = _mask_to_index(g.ndata["train_mask"], ctx)
    val_mask = _mask_to_index(g.ndata["val_mask"], ctx)
    test_mask = _mask_to_index(g.ndata["test_mask"], ctx)

    in_feats = features.shape[1]
    n_classes = data.num_labels
    # Edge count is taken from the dataset before self-loops are rewritten;
    # it is only used for the throughput figure printed each epoch.
    n_edges = data.graph.number_of_edges()

    # Drop any existing self-loops first so that each node ends up with
    # exactly one after add_self_loop.
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create model
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GAT(
        g,
        args.num_layers,
        in_feats,
        args.num_hidden,
        n_classes,
        heads,
        elu,
        args.in_drop,
        args.attn_drop,
        args.alpha,
        args.residual,
    )

    if args.early_stop:
        stopper = EarlyStopping(patience=100)
    model.initialize(ctx=ctx)

    # use optimizer
    # The --weight-decay option was previously parsed but never applied;
    # pass it to the optimizer as ``wd`` so the flag takes effect.
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    dur = []
    for epoch in range(args.epochs):
        # The first three epochs are excluded from timing.
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            logits = model(features)
            loss = mx.nd.softmax_cross_entropy(
                logits[mask].squeeze(), labels[mask].squeeze()
            )
            loss.backward()
        trainer.step(mask.shape[0])

        if epoch >= 3:
            dur.append(time.time() - t0)

        # No timing samples exist for the untimed epochs; report NaN
        # explicitly instead of calling np.mean on an empty list, which
        # emits a RuntimeWarning.
        mean_dur = np.mean(dur) if dur else float("nan")
        print(
            "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
                epoch,
                loss.asnumpy()[0],
                mean_dur,
                n_edges / mean_dur / 1000,
            )
        )

        val_accuracy = evaluate(model, features, labels, val_mask)
        print("Validation Accuracy {:.4f}".format(val_accuracy))
        if args.early_stop and stopper.step(val_accuracy, model):
            break
    print()

    if args.early_stop:
        # Presumably the checkpoint written by EarlyStopping.step; verify
        # the file name against utils.EarlyStopping.
        model.load_parameters("model.param")
    test_accuracy = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(test_accuracy))


134
if __name__ == "__main__":
    # Command-line entry point: build the argument parser, echo the parsed
    # options, and run training.
    parser = argparse.ArgumentParser(description="GAT")
    # Adds the dataset-selection options provided by dgl.data.
    register_data_args(parser)
    parser.add_argument(
        "--gpu",
        type=int,
        default=-1,
        help="which GPU to use. Set -1 to use CPU.",
    )
    parser.add_argument(
        "--epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        default=8,
        help="number of hidden attention heads",
    )
    parser.add_argument(
        "--num-out-heads",
        type=int,
        default=1,
        help="number of output attention heads",
    )
    parser.add_argument(
        "--num-layers", type=int, default=1, help="number of hidden layers"
    )
    parser.add_argument(
        "--num-hidden", type=int, default=8, help="number of hidden units"
    )
    parser.add_argument(
        "--residual",
        action="store_true",
        default=False,
        help="use residual connection",
    )
    parser.add_argument(
        "--in-drop", type=float, default=0.6, help="input feature dropout"
    )
    parser.add_argument(
        "--attn-drop", type=float, default=0.6, help="attention dropout"
    )
    parser.add_argument("--lr", type=float, default=0.005, help="learning rate")
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="weight decay"
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.2,
        help="the negative slope of leaky relu",
    )
    parser.add_argument(
        "--early-stop",
        action="store_true",
        default=False,
        help="indicates whether to use early stop or not",
    )
    args = parser.parse_args()
    print(args)

    main(args)