# gcn_ns_sc.py -- GCN with neighbor sampling (DGL + MXNet Gluon example)
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


class NodeUpdate(gluon.Block):
    """Per-node update applied after message aggregation.

    Scales the summed messages by a precomputed normalizer, applies a
    dense layer, then an optional activation and/or skip-style
    concatenation.

    Parameters
    ----------
    in_feats : int
        Input feature size (``in_units`` of the dense layer).
    out_feats : int
        Output feature size of the dense layer.
    activation : callable or None
        Elementwise nonlinearity applied to the dense output.
    concat : bool
        If True, concatenate the linear output with its activation —
        this doubles the feature width, so the consumer layer must be
        sized as ``2 * out_feats`` (see GCNSampling / GCNInfer).
    """

    def __init__(self, in_feats, out_feats, activation=None, concat=False):
        super(NodeUpdate, self).__init__()
        self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
        self.activation = activation
        self.concat = concat

    def forward(self, node):
        # 'h' holds the summed incoming messages; 'norm' is 1/in_degree,
        # so the product realizes mean aggregation over neighbors.
        h = node.data['h'] * node.data['norm']
        h = self.dense(h)
        if self.concat:
            # skip connection: [linear, activation(linear)]
            h = mx.nd.concat(h, self.activation(h))
        elif self.activation:
            h = self.activation(h)
        return {'activation': h}


class GCNSampling(gluon.Block):
    """GCN trained with neighbor sampling, executed on a DGL NodeFlow.

    Parameters
    ----------
    in_feats : int
        Size of the input node features.
    n_hidden : int
        Hidden feature size.
    n_classes : int
        Number of output classes.
    n_layers : int
        Number of hidden message-passing layers (total layers is
        ``n_layers + 1`` including the classifier; the NodeFlow must be
        sampled with the same number of hops).
    activation : callable
        Nonlinearity for hidden layers (e.g. ``mx.nd.relu``).
    dropout : float
        Dropout rate applied to layer inputs during training; 0 disables.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GCNSampling, self).__init__(**kwargs)
        self.dropout = dropout
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer; the layer that feeds the classifier uses
            # concat=True, doubling its output width
            skip_start = (0 == n_layers - 1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation,
                                       concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers - 1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation,
                                           concat=skip_start))
            # output layer consumes the concatenated (2*n_hidden) features
            self.layers.add(NodeUpdate(2 * n_hidden, n_classes))

    def forward(self, nf):
        """Run the sampled GCN over NodeFlow ``nf``.

        Returns the logits for the nodes in the last NodeFlow layer.
        """
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)
            nf.layers[i].data['h'] = h
            # normalize by the *sampled* in-degree of the destination layer,
            # since only the sampled neighbors contribute messages
            degs = nf.layer_in_degree(i + 1).astype('float32').as_in_context(h.context)
            nf.layers[i + 1].data['norm'] = mx.nd.expand_dims(1. / degs, 1)
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        return nf.layers[-1].data.pop('activation')


class GCNInfer(gluon.Block):
    """Inference-time GCN: same layer stack as GCNSampling, without dropout.

    Intended to run on NodeFlows sampled with full neighborhoods. Unlike
    GCNSampling it does not recompute per-NodeFlow degree normalizers:
    NodeUpdate reads ``node.data['norm']``, which is expected to be copied
    from the parent graph's ``ndata['norm']`` via ``copy_from_parent``
    (set up in ``gcn_ns_train``).

    Parameters mirror GCNSampling minus ``dropout``.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 **kwargs):
        super(GCNInfer, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer; concat=True when it directly feeds the classifier
            skip_start = (0 == n_layers - 1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation,
                                       concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers - 1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation,
                                           concat=skip_start))
            # output layer consumes the concatenated (2*n_hidden) features
            self.layers.add(NodeUpdate(2 * n_hidden, n_classes))

    def forward(self, nf):
        """Propagate features through NodeFlow ``nf``.

        Returns the logits for the nodes in the last NodeFlow layer.
        """
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        return nf.layers[-1].data.pop('activation')
def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
    """Train a neighbor-sampling GCN on ``g`` and print test accuracy per epoch.

    Parameters
    ----------
    g : DGLGraph
        Graph with ``ndata['features']`` and ``ndata['labels']``.
    ctx : mx.Context
        Device to run the models on.
    args : argparse.Namespace
        Expects n_hidden, n_layers, dropout, lr, weight_decay, n_epochs,
        batch_size, num_neighbors, test_batch_size.
    n_classes : int
        Number of label classes.
    train_nid, test_nid : array-like
        Seed node ids for training / evaluation.
    n_test_samples : int
        Total number of test nodes (kept for interface compatibility;
        accuracy is normalized by the actual number of evaluated nodes).
    """
    in_feats = g.ndata['features'].shape[1]
    labels = g.ndata['labels']
    g_ctx = labels.context

    # precompute 1/in_degree on the parent graph; NodeFlows copy this as
    # 'norm' for the inference model via copy_from_parent
    degs = g.in_degrees().astype('float32').as_in_context(g_ctx)
    g.ndata['norm'] = mx.nd.expand_dims(1. / degs, 1)

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        args.n_layers,
                        mx.nd.relu,
                        args.dropout,
                        prefix='GCN')
    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()

    # same prefix as the training model so parameters line up by name
    # when we copy weights through the kvstore below
    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           args.n_layers,
                           mx.nd.relu,
                           prefix='GCN')
    infer_model.initialize(ctx=ctx)

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay},
                            kvstore=mx.kv.create('local'))

    for epoch in range(args.n_epochs):
        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                       args.num_neighbors,
                                                       neighbor_type='in',
                                                       shuffle=True,
                                                       num_workers=32,
                                                       num_hops=args.n_layers + 1,
                                                       seed_nodes=train_nid):
            nf.copy_from_parent(ctx=ctx)
            # forward
            with mx.autograd.record():
                pred = model(nf)
                batch_nids = nf.layer_parent_nid(-1)
                batch_labels = labels[batch_nids].as_in_context(ctx)
                loss = loss_fcn(pred, batch_labels)
                loss = loss.sum() / len(batch_nids)

            loss.backward()
            # loss is already averaged over the batch, so step size is 1
            trainer.step(batch_size=1)

        # pull the freshly trained weights from the kvstore into the
        # inference model (matched by parameter name via the shared prefix)
        infer_params = infer_model.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        num_acc = 0.
        num_tests = 0

        # evaluation uses full neighborhoods: num_neighbors = |V|
        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                       g.number_of_nodes(),
                                                       neighbor_type='in',
                                                       num_hops=args.n_layers + 1,
                                                       seed_nodes=test_nid):
            nf.copy_from_parent(ctx=ctx)
            pred = infer_model(nf)
            batch_nids = nf.layer_parent_nid(-1)
            batch_labels = labels[batch_nids].as_in_context(ctx)
            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
            num_tests += nf.layer_size(-1)

        print("Test Accuracy {:.4f}".format(num_acc / num_tests))