import argparse
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


class NodeUpdate(gluon.Block):
    def __init__(self, in_feats, out_feats, activation=None, concat=False):
        super(NodeUpdate, self).__init__()
        self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
        self.activation = activation
        self.concat = concat

    def forward(self, node):
        # summed neighbor messages, rescaled to a mean by 1/in-degree
        h = node.data['h']
        h = h * node.data['norm']
        h = self.dense(h)
        # skip connection
        if self.concat:
            h = mx.nd.concat(h, self.activation(h))
        elif self.activation:
            h = self.activation(h)
        return {'activation': h}
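

# NodeUpdate above is the per-node half of one GCN step: the summed neighbor
# messages in node.data['h'] are turned into a mean via node.data['norm']
# (1/in-degree), projected by a dense layer, and either activated or, when
# concat=True, expanded into the skip connection [Wh, act(Wh)] -- which is
# why the output layers below take 2*n_hidden input features.  A minimal
# sketch of that arithmetic on dummy tensors (helper name and sizes are
# illustrative, not part of the model):
def _node_update_sketch():
    dense = gluon.nn.Dense(4, in_units=8)
    dense.initialize()
    h = mx.nd.random.uniform(shape=(5, 8))            # summed messages, 5 nodes
    norm = mx.nd.full((5, 1), 1. / 3.)                # every node has in-degree 3
    out = dense(h * norm)                             # normalize, then project
    return mx.nd.concat(out, mx.nd.relu(out), dim=1)  # (5, 8): concat skip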


class GCNSampling(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GCNSampling, self).__init__(**kwargs)
        self.dropout = dropout
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            skip_start = (0 == n_layers-1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers-1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))
            # output layer; in_feats is 2*n_hidden because the previous layer
            # emitted the concatenated skip connection
            self.layers.add(NodeUpdate(2*n_hidden, n_classes))


    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)
            nf.layers[i].data['h'] = h
            # sampling changes each node's effective in-degree, so recompute
            # 1/deg from the sampled block rather than use the parent graph's norm
            degs = nf.layer_in_degree(i + 1).astype('float32').as_in_context(h.context)
            nf.layers[i + 1].data['norm'] = mx.nd.expand_dims(1./degs, 1)

            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        h = nf.layers[-1].data.pop('activation')
        return h
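

# GCNSampling.forward consumes a NodeFlow: a layered subgraph where layer 0
# holds the sampled far-end nodes, layer -1 the seed (batch) nodes, and
# block i connects layer i to layer i+1, one block per NodeUpdate.  A
# self-contained sketch on a toy chain graph showing the layout the sampler
# produces (assuming the legacy dgl.contrib.sampling API used throughout
# this file; the helper is illustrative only):
def _nodeflow_sketch(n_layers=1):
    g = dgl.DGLGraph()
    g.add_nodes(6)
    g.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 5])  # chain 0 -> 1 -> ... -> 5
    g.readonly()  # the contrib sampler requires a read-only graph
    for nf in dgl.contrib.sampling.NeighborSampler(
            g, 1, 2, neighbor_type='in', num_hops=n_layers + 1,
            seed_nodes=mx.nd.array([5], dtype='int64')):
        # num_hops blocks => num_hops + 1 node layers; layer -1 is the seed
        print(nf.num_layers, nf.num_blocks, nf.layer_parent_nid(-1).asnumpy())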


class GCNInfer(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 **kwargs):
        super(GCNInfer, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            skip_start = (0 == n_layers-1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers-1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))
            # output layer; in_feats is 2*n_hidden, matching the skip connection
            self.layers.add(NodeUpdate(2*n_hidden, n_classes))


    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        return nf.layers[-1].data.pop('activation')


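# GCNInfer mirrors GCNSampling but omits dropout, and is fed NodeFlows
# sampled with expand_factor = g.number_of_nodes(), i.e. every in-neighbor
# is kept, so evaluation is the exact deterministic GCN.  Both models are
# built with the same 'GCN' prefix, so their parameter dicts share keys; a
# kvstore-free way to sync the weights would be the sketch below
# (illustrative helper; gcn_ns_train instead pulls from the trainer's
# kvstore to get the freshest values):
def _copy_params(src_net, dst_net):
    src = src_net.collect_params()
    for key, param in dst_net.collect_params().items():
        param.set_data(src[key].data())
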
def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
    in_feats = g.ndata['features'].shape[1]
    labels = g.ndata['labels']

    # full-graph 1/in-degree; copy_from_parent() carries it into each NodeFlow,
    # where GCNInfer reads it directly (GCNSampling overwrites it with the
    # sampled in-degrees in its forward pass)
    degs = g.in_degrees().astype('float32').as_in_context(ctx)
    norm = mx.nd.expand_dims(1./degs, 1)
    g.ndata['norm'] = norm

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        args.n_layers,
                        mx.nd.relu,
                        args.dropout,
                        prefix='GCN')

    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()

    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           args.n_layers,
                           mx.nd.relu,
                           prefix='GCN')

    infer_model.initialize(ctx=ctx)

    # optimizer; the kvstore is created explicitly so the trained weights can
    # be pulled into infer_model after each epoch
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay},
                            kvstore=mx.kv.create('local'))

    # training loop
    for epoch in range(args.n_epochs):
        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                       args.num_neighbors,
                                                       neighbor_type='in',
                                                       shuffle=True,
                                                       num_workers=32,
                                                       num_hops=args.n_layers+1,
                                                       seed_nodes=train_nid):
            nf.copy_from_parent()
            # forward
            with mx.autograd.record():
                pred = model(nf)
                batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
                batch_labels = labels[batch_nids]
                loss = loss_fcn(pred, batch_labels)
                loss = loss.sum() / len(batch_nids)

            loss.backward()
            # the loss is already averaged over the batch, so step with batch_size=1
            trainer.step(batch_size=1)

        # pull the freshest trained weights from the trainer's kvstore into the
        # inference model (relies on Trainer internals: _param2idx / _kvstore)
        infer_params = infer_model.collect_params()

        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        num_acc = 0.
        num_tests = 0

        # expand_factor = number_of_nodes: every in-neighbor is included, so
        # evaluation uses the exact (non-sampled) GCN aggregation
        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                       g.number_of_nodes(),
                                                       neighbor_type='in',
                                                       num_hops=args.n_layers+1,
                                                       seed_nodes=test_nid):
            nf.copy_from_parent()
            pred = infer_model(nf)
            batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
            batch_labels = labels[batch_nids]
            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
            num_tests += nf.layer_size(-1)

        print("Test Accuracy {:.4f}".format(num_acc/num_tests))