"""Training GCN model on citation graphs."""

import argparse
import time

import dgl
import mxnet as mx
import numpy as np
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from gcn import GCN
from mxnet import gluon

# from gcn_mp import GCN
# from gcn_spmv import GCN


def evaluate(model, features, labels, mask):
    """Return accuracy of *model* on the nodes selected by *mask*.

    The prediction for each node is the argmax over class logits; only
    nodes where ``mask`` is nonzero contribute to the score.
    """
    predictions = model(features).argmax(axis=1)
    correct = (predictions == labels) * mask
    accuracy = correct.sum() / mask.sum().asscalar()
    return accuracy.asscalar()

def main(args):
    """Train and evaluate a GCN on one of the citation datasets.

    Args:
        args: argparse.Namespace with fields ``dataset``, ``gpu``,
            ``self_loop``, ``dropout``, ``lr``, ``weight_decay``,
            ``n_epochs``, ``n_hidden``, ``n_layers``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the known datasets.
    """
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        cuda = False
        ctx = mx.cpu(0)
    else:
        cuda = True
        ctx = mx.gpu(args.gpu)
        # cast node/edge ids to int32 and move graph (with its ndata) to GPU
        g = g.int().to(ctx)

    features = g.ndata["feat"]
    # gluon's SoftmaxCELoss takes float labels, hence the cast
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    in_feats = features.shape[1]
    n_classes = data.num_labels
    # NOTE: edge count is taken before the optional self-loop rewrite below,
    # so the throughput figure reflects the original graph.
    n_edges = data.graph.number_of_edges()

    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.sum().asscalar(),
            val_mask.sum().asscalar(),
            test_mask.sum().asscalar(),
        )
    )

    # add self loop so every node also aggregates its own features
    if args.self_loop:
        g = dgl.remove_self_loop(g)
        g = dgl.add_self_loop(g)

    # symmetric normalization factor D^{-1/2}
    degs = g.in_degrees().astype("float32")
    norm = mx.nd.power(degs, -0.5)
    # BUGFIX: zero-in-degree nodes give 0^-0.5 == inf, which propagates NaNs
    # through the model when --self-loop is not set; zero those entries out.
    norm = mx.nd.where(degs > 0, norm, mx.nd.zeros_like(norm))
    if cuda:
        norm = norm.as_in_context(ctx)
    g.ndata["norm"] = mx.nd.expand_dims(norm, 1)

    model = GCN(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        mx.nd.relu,
        args.dropout,
    )
    model.initialize(ctx=ctx)
    n_train_samples = train_mask.sum().asscalar()
    loss_fcn = gluon.loss.SoftmaxCELoss()

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    # initialize graph; the first 3 epochs are excluded from timing (warm-up)
    dur = []
    for epoch in range(args.n_epochs):
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            pred = model(features)
            # the mask restricts the loss to training nodes only
            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
            loss = loss.sum() / n_train_samples

        loss.backward()
        trainer.step(batch_size=1)

        if epoch >= 3:
            loss.asscalar()  # block until compute finishes so timing is accurate
            dur.append(time.time() - t0)
            acc = evaluate(model, features, labels, val_mask)
            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    np.mean(dur),
                    loss.asscalar(),
                    acc,
                    n_edges / np.mean(dur) / 1000,
                )
            )

    # test set accuracy
    acc = evaluate(model, features, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))

if __name__ == "__main__":
    # Command-line entry point: parse hyperparameters, then train.
    cli = argparse.ArgumentParser(description="GCN")
    cli.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    cli.add_argument(
        "--dropout", type=float, default=0.5, help="dropout probability"
    )
    cli.add_argument("--gpu", type=int, default=-1, help="gpu")
    cli.add_argument("--lr", type=float, default=3e-2, help="learning rate")
    cli.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    cli.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden gcn units"
    )
    cli.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
    )
    cli.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )
    cli.add_argument(
        "--self-loop",
        action="store_true",
        help="graph self-loop (default=False)",
    )
    cli.set_defaults(self_loop=False)

    args = cli.parse_args()
    print(args)
    main(args)