"""
Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
PyTorch implementation: https://github.com/Diego999/pyGAT
"""

import argparse
import time

import dgl
import mxnet as mx
import networkx as nx
import numpy as np
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)
from mxnet import gluon

from gat import GAT
from utils import EarlyStopping


def elu(data):
    """Apply the ELU (exponential linear unit) activation to ``data``.

    Thin wrapper around ``mx.nd.LeakyReLU`` with ``act_type="elu"`` so the
    activation can be handed to the model constructor as a plain callable.
    """
    activated = mx.nd.LeakyReLU(data=data, act_type="elu")
    return activated


def evaluate(model, features, labels, mask):
    """Return the classification accuracy of ``model`` on the nodes in ``mask``.

    Args:
        model: callable mapping ``features`` to per-node logits; the result
            must support indexing with ``mask`` and ``.asnumpy()``.
        features: node feature input passed straight to ``model``.
        labels: per-node class labels, indexed with ``mask``.
        mask: indices of the nodes to evaluate on.

    Returns:
        Fraction of selected nodes whose arg-max logit equals the label, or
        ``nan`` when ``mask`` selects no nodes.
    """
    logits = model(features)[mask].asnumpy()
    gold = labels[mask].asnumpy().reshape(-1)
    num_nodes = logits.shape[0]
    if num_nodes == 0:
        # Accuracy is undefined for an empty selection; report nan rather
        # than dividing by zero.
        return float("nan")
    # Collapse any singleton axes after the node axis, but never the node
    # axis itself: a blanket ``squeeze()`` would turn a single-node selection
    # into a 1-D array and break ``argmax(axis=1)``.
    scores = logits.reshape(num_nodes, -1)
    predictions = np.argmax(scores, axis=1)
    return float(np.mean(predictions == gold))


def _load_dataset(name):
    """Instantiate the DGL citation dataset selected by ``name``.

    Raises:
        ValueError: if ``name`` is not "cora", "citeseer" or "pubmed".
    """
    dataset_classes = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    if name not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(name))
    return dataset_classes[name]()


def _mask_to_index(mask, ctx):
    """Return the positions of the non-zero entries of ``mask`` as an NDArray on ``ctx``."""
    return mx.nd.array(np.nonzero(mask.asnumpy())[0], ctx=ctx)


def main(args):
    """Train a GAT node classifier and report validation and test accuracy.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``gpu``,
            ``epochs``, ``num_heads``, ``num_out_heads``, ``num_layers``,
            ``num_hidden``, ``residual``, ``in_drop``, ``attn_drop``, ``lr``,
            ``weight_decay``, ``alpha`` and ``early_stop``.

    Raises:
        ValueError: if ``args.dataset`` names an unsupported dataset.
    """
    # load and preprocess dataset
    data = _load_dataset(args.dataset)
    g = data[0]
    if args.gpu < 0:
        ctx = mx.cpu(0)
    else:
        ctx = mx.gpu(args.gpu)
        g = g.to(ctx)

    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    # Convert the node masks into index arrays so they can select rows of
    # the logits/labels directly.
    train_idx = _mask_to_index(g.ndata["train_mask"], ctx)
    val_idx = _mask_to_index(g.ndata["val_mask"], ctx)
    test_idx = _mask_to_index(g.ndata["test_mask"], ctx)

    in_feats = features.shape[1]
    n_classes = data.num_labels
    # Edge count of the loaded graph, taken before self-loops are normalised
    # below; used only for the throughput report. Read from ``g`` rather than
    # ``data.graph``, which newer DGL releases deprecate.
    n_edges = g.number_of_edges()

    # Drop any existing self-loops, then add one per node.
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)

    # create model: one head count per hidden layer, plus the output layer
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GAT(
        g,
        args.num_layers,
        in_feats,
        args.num_hidden,
        n_classes,
        heads,
        elu,
        args.in_drop,
        args.attn_drop,
        args.alpha,
        args.residual,
    )

    if args.early_stop:
        stopper = EarlyStopping(patience=100)
    model.initialize(ctx=ctx)

    # use optimizer; ``--weight-decay`` is applied via the optimizer's ``wd``.
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    dur = []
    for epoch in range(args.epochs):
        # Epochs 0-2 are excluded from timing as warm-up.
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            logits = model(features)
            loss = mx.nd.softmax_cross_entropy(
                logits[train_idx].squeeze(), labels[train_idx].squeeze()
            )
        loss.backward()
        # Normalise the accumulated gradient by the number of training nodes.
        trainer.step(train_idx.shape[0])

        if epoch >= 3:
            dur.append(time.time() - t0)

        # No timings exist during warm-up; report nan without numpy's
        # "mean of empty slice" warning.
        mean_dur = np.mean(dur) if dur else float("nan")
        print(
            "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
                epoch,
                loss.asnumpy()[0],
                mean_dur,
                n_edges / mean_dur / 1000,
            )
        )

        val_accuracy = evaluate(model, features, labels, val_idx)
        print("Validation Accuracy {:.4f}".format(val_accuracy))
        if args.early_stop and stopper.step(val_accuracy, model):
            break
    print()

    if args.early_stop:
        # NOTE(review): assumes EarlyStopping.step checkpoints the best model
        # to "model.param" -- confirm against utils.EarlyStopping.
        model.load_parameters("model.param")
    test_accuracy = evaluate(model, features, labels, test_idx)
    print("Test Accuracy {:.4f}".format(test_accuracy))


if __name__ == "__main__":
    # Command-line interface. register_data_args adds the dataset options
    # (presumably including --dataset, which main() reads) -- confirm in dgl.data.
    parser = argparse.ArgumentParser(description="GAT")
    register_data_args(parser)
    parser.add_argument(
        "--gpu",
        type=int,
        default=-1,
        help="which GPU to use. Set -1 to use CPU.",
    )
    parser.add_argument(
        "--epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        default=8,
        help="number of hidden attention heads",
    )
    parser.add_argument(
        "--num-out-heads",
        type=int,
        default=1,
        help="number of output attention heads",
    )
    parser.add_argument(
        "--num-layers", type=int, default=1, help="number of hidden layers"
    )
    parser.add_argument(
        "--num-hidden", type=int, default=8, help="number of hidden units"
    )
    parser.add_argument(
        "--residual",
        action="store_true",
        default=False,
        help="use residual connection",
    )
    parser.add_argument(
        "--in-drop", type=float, default=0.6, help="input feature dropout"
    )
    parser.add_argument(
        "--attn-drop", type=float, default=0.6, help="attention dropout"
    )
    parser.add_argument("--lr", type=float, default=0.005, help="learning rate")
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="weight decay"
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.2,
        help="the negative slope of leaky relu",
    )
    parser.add_argument(
        "--early-stop",
        action="store_true",
        default=False,
        help="indicates whether to use early stop or not",
    )
    args = parser.parse_args()
    print(args)

    main(args)