import argparse
import time

import dgl

import mxnet as mx
import networkx as nx
import numpy as np
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)
from dgl.nn.mxnet.conv import GMMConv
from mxnet import gluon, nd
from mxnet.gluon import nn


class MoNet(nn.Block):
    """MoNet node model: a stack of ``GMMConv`` layers over a fixed graph.

    Every ``GMMConv`` gets its own learned projection of the raw edge
    pseudo-coordinates (``nn.Dense(dim, in_units=2, activation="tanh")``), so
    the ``pseudo`` array passed to :meth:`forward` must have 2 columns.

    Parameters
    ----------
    g : DGLGraph
        Graph the convolutions run on; stored and reused on every forward call.
    in_feats : int
        Size of the input node features.
    n_hidden : int
        Size of the hidden node representations.
    out_feats : int
        Size of the output (e.g. number of classes).
    n_layers : int
        Depth control: the model has one input layer, ``max(n_layers - 1, 0)``
        hidden-to-hidden layers, and one output layer.
    dim : int
        Dimension of the projected pseudo-coordinates fed to each ``GMMConv``.
    n_kernels : int
        Number of Gaussian kernels in each ``GMMConv``.
    dropout : float
        Dropout probability applied to the input of every layer but the first.
    """

    def __init__(
        self,
        g,
        in_feats,
        n_hidden,
        out_feats,
        n_layers,
        dim,
        n_kernels,
        dropout,
    ):
        super(MoNet, self).__init__()
        self.g = g
        with self.name_scope():
            # layers[i] is always paired with pseudo_proj[i] in forward().
            self.layers = nn.Sequential()
            self.pseudo_proj = nn.Sequential()

            # Input layer
            self.layers.add(GMMConv(in_feats, n_hidden, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))

            # Hidden layer(s)
            for _ in range(n_layers - 1):
                self.layers.add(GMMConv(n_hidden, n_hidden, dim, n_kernels))
                self.pseudo_proj.add(
                    nn.Dense(dim, in_units=2, activation="tanh")
                )

            # Output layer
            self.layers.add(GMMConv(n_hidden, out_feats, dim, n_kernels))
            self.pseudo_proj.add(nn.Dense(dim, in_units=2, activation="tanh"))

            self.dropout = nn.Dropout(dropout)

    def forward(self, feat, pseudo):
        """Run every ``GMMConv`` layer in order and return the last output.

        Parameters
        ----------
        feat : NDArray
            Node features; presumably ``(num_nodes, in_feats)`` — TODO confirm.
        pseudo : NDArray
            Raw per-edge pseudo-coordinates with 2 columns (the projections
            are built with ``in_units=2``).

        Returns
        -------
        NDArray
            Raw output of the final ``GMMConv``; no softmax is applied here.
        """
        h = feat
        # Index-based access keeps each GMMConv paired with its own projection.
        for i in range(len(self.layers)):
            if i > 0:
                # Dropout before every layer except the first.
                h = self.dropout(h)
            h = self.layers[i](self.g, h, self.pseudo_proj[i](pseudo))
        return h


def evaluate(model, features, pseudo, labels, mask):
    """Accuracy of ``model`` on the nodes selected by ``mask``.

    The model is run on ``(features, pseudo)``. The arg-max over axis 1 of its
    output is taken as the predicted class. The function returns the number of
    correct predictions among masked nodes divided by ``mask.sum()``, as a
    Python scalar.
    """
    predicted = model(features, pseudo).argmax(axis=1)
    # Non-masked nodes are multiplied by 0 and so never count as hits.
    hits = (predicted == labels) * mask
    n_selected = mask.sum().asscalar()
    ratio = hits.sum() / n_selected
    return ratio.asscalar()


def _load_dataset(name):
    """Instantiate the DGL citation dataset selected by ``name``.

    Parameters
    ----------
    name : str
        One of ``"cora"``, ``"citeseer"`` or ``"pubmed"``.

    Raises
    ------
    ValueError
        If ``name`` is not a supported dataset.
    """
    dataset_classes = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    try:
        dataset_cls = dataset_classes[name]
    except KeyError:
        raise ValueError("Unknown dataset: {}".format(name)) from None
    return dataset_cls()


def _edge_pseudo_coords(g, ctx):
    """Compute MoNet's 2-d pseudo-coordinates for every edge of ``g``.

    The edge from ``u`` to ``v`` gets
    ``[1 / sqrt(in_degree(u)), 1 / sqrt(in_degree(v))]``. Rows follow the
    order of ``g.edges()``. The in-degree array is fetched once and indexed
    with NumPy, rather than issuing two ``g.in_degrees`` calls per edge.

    Every node is expected to have in-degree >= 1 (``main`` adds self-loops
    before calling this); a zero in-degree would yield ``inf`` entries.

    Returns
    -------
    NDArray
        float32 array of shape ``(num_edges, 2)`` on ``ctx``.
    """
    src, dst = g.edges()
    inv_sqrt_deg = 1.0 / np.sqrt(g.in_degrees().asnumpy().astype(np.float64))
    coords = np.stack(
        [inv_sqrt_deg[src.asnumpy()], inv_sqrt_deg[dst.asnumpy()]], axis=1
    )
    return nd.array(coords, ctx=ctx, dtype="float32")


def main(args):
    """Train MoNet on a citation graph and report validation/test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Command-line options: ``dataset``, ``gpu``, ``lr``, ``n_epochs``,
        ``n_hidden``, ``n_layers``, ``pseudo_dim``, ``n_kernels``,
        ``dropout`` and ``weight_decay``.
    """
    # load and preprocess dataset
    data = _load_dataset(args.dataset)
    g = data[0]
    if args.gpu < 0:
        ctx = mx.cpu(0)
    else:
        ctx = mx.gpu(args.gpu)
        g = g.to(ctx)

    features = g.ndata["feat"]
    labels = mx.nd.array(g.ndata["label"], dtype="float32", ctx=ctx)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    in_feats = features.shape[1]
    # `num_classes` and `g` replace the deprecated `num_labels`/`data.graph`.
    n_classes = data.num_classes
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.sum().asscalar(),
            val_mask.sum().asscalar(),
            test_mask.sum().asscalar(),
        )
    )

    # add self loop (existing ones are removed first so each node gets exactly
    # one); this also keeps every in-degree >= 1 for the pseudo-coordinates
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    n_edges = g.number_of_edges()
    pseudo = _edge_pseudo_coords(g, ctx)

    # create MoNet model
    model = MoNet(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        args.pseudo_dim,
        args.n_kernels,
        args.dropout,
    )

    model.initialize(ctx=ctx)
    n_train_samples = train_mask.sum().asscalar()
    loss_fcn = gluon.loss.SoftmaxCELoss()

    print(model.collect_params())
    trainer = gluon.Trainer(
        model.collect_params(),
        "adam",
        {"learning_rate": args.lr, "wd": args.weight_decay},
    )

    # training loop
    dur = []
    for epoch in range(args.n_epochs):
        # Epochs 0-2 are treated as warm-up and are not timed.
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            pred = model(features, pseudo)
            # train_mask acts as a per-node sample weight so that only
            # training nodes contribute to the loss (assumes a 0/1 mask).
            loss = loss_fcn(pred, labels, mx.nd.expand_dims(train_mask, 1))
            loss = loss.sum() / n_train_samples

        loss.backward()
        # The loss is already averaged over training nodes, hence batch_size=1.
        trainer.step(batch_size=1)

        if epoch >= 3:
            # asscalar() copies the loss to the host, which forces the pending
            # asynchronous MXNet work to finish before the timer is read.
            loss_value = loss.asscalar()
            dur.append(time.time() - t0)
            acc = evaluate(model, features, pseudo, labels, val_mask)
            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    np.mean(dur),
                    loss_value,
                    acc,
                    n_edges / np.mean(dur) / 1000,
                )
            )

    # test set accuracy
    acc = evaluate(model, features, pseudo, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))


if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and run training.
    parser = argparse.ArgumentParser(description="MoNet on citation network")
    # Adds the DGL dataset selection options (provides args.dataset).
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden gcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
    )
    parser.add_argument(
        "--pseudo-dim",
        type=int,
        default=2,
        help="Pseudo coordinate dimensions in GMMConv, 2 for cora and 3 for pubmed",
    )
    parser.add_argument(
        "--n-kernels",
        type=int,
        default=3,
        help="Number of kernels in GMMConv layer",
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-5, help="Weight for L2 loss"
    )
    args = parser.parse_args()
    print(args)

    main(args)