"vscode:/vscode.git/clone" did not exist on "fbc3369b266fa39ae47d4553ebe95468dada1de3"
gcn_ns_sc.py 6.91 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
import argparse
import time
import math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


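# One graph-convolution step applied to the nodes of a NodeFlow layer:
# scale the aggregated messages by the precomputed 1/in-degree norm,
# apply a dense projection, then either activate or (for the skip
# connection) concatenate the linear output with its activation.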
class NodeUpdate(gluon.Block):
    def __init__(self, in_feats, out_feats, activation=None, concat=False):
        super(NodeUpdate, self).__init__()
        self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
        self.activation = activation
        self.concat = concat

    def forward(self, node):
        h = node.data['h']
        h = h * node.data['norm']
        h = self.dense(h)
        # skip connection
        if self.concat:
            h = mx.nd.concat(h, self.activation(h))
        elif self.activation:
            h = self.activation(h)
        return {'activation': h}


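# GCN trained on sampled NodeFlows (one NodeFlow layer per GCN layer plus
# the input layer). The layer feeding the output uses concat=True, so its
# output width is 2*n_hidden; hence the output NodeUpdate declares
# 2*n_hidden input units.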
class GCNSampling(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GCNSampling, self).__init__(**kwargs)
        self.dropout = dropout
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            skip_start = (0 == n_layers-1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers-1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))
            # output layer
            self.layers.add(NodeUpdate(2*n_hidden, n_classes))


    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)
            nf.layers[i].data['h'] = h
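            # 1/in-degree of each node in the next NodeFlow layer; the norm
            # applied inside NodeUpdate turns the summed messages from
            # block_compute into a neighborhood mean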
            degs = nf.layer_in_degree(i + 1).astype('float32').as_in_context(h.context)
            nf.layers[i + 1].data['norm'] = mx.nd.expand_dims(1./degs, 1)
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        h = nf.layers[-1].data.pop('activation')
        return h


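# Inference-time twin of GCNSampling: identical layer shapes (so trained
# weights can be copied across) but no dropout; it is meant to be run on
# NodeFlows that expand every in-neighbor instead of a sample.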
class GCNInfer(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 **kwargs):
        super(GCNInfer, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            skip_start = (0 == n_layers-1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers-1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))
            # output layer
            self.layers.add(NodeUpdate(2*n_hidden, n_classes))


    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        return nf.layers[-1].data.pop('activation')


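# Train GCNSampling with neighbor sampling and, after every epoch, copy the
# learned weights into GCNInfer to estimate test accuracy.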
def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
    n0_feats = g.nodes[0].data['features']
    in_feats = n0_feats.shape[1]
    g_ctx = n0_feats.context

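    # precompute 1/in-degree once on the parent graph; copy_from_parent
    # later carries this 'norm' field into the inference NodeFlows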
    degs = g.in_degrees().astype('float32').as_in_context(g_ctx)
    norm = mx.nd.expand_dims(1./degs, 1)
    g.set_n_repr({'norm': norm})

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        args.n_layers,
                        mx.nd.relu,
                        args.dropout,
                        prefix='GCN')

    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()

    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           args.n_layers,
                           mx.nd.relu,
                           prefix='GCN')

    infer_model.initialize(ctx=ctx)

    # optimizer: the kvstore is created explicitly so that its parameters
    # can be pulled into infer_model after each epoch (see below)
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay},
                            kvstore=mx.kv.create('local'))

    # training loop
    dur = []
    for epoch in range(args.n_epochs):
        for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                       args.num_neighbors,
                                                       neighbor_type='in',
                                                       shuffle=True,
                                                       num_workers=32,
                                                       num_hops=args.n_layers+1,
                                                       seed_nodes=train_nid):
            nf.copy_from_parent(ctx=ctx)
            # forward
            with mx.autograd.record():
                pred = model(nf)
                batch_nids = nf.layer_parent_nid(-1)
                batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
                loss = loss_fcn(pred, batch_labels)
                loss = loss.sum() / len(batch_nids)

            loss.backward()
            trainer.step(batch_size=1)

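        # copy the freshly trained weights into the inference model by
        # pulling each parameter from the trainer's kvstore; parameter
        # names match because both models share the 'GCN' prefix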
        infer_params = infer_model.collect_params()

        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        num_acc = 0.
        num_tests = 0

        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                       g.number_of_nodes(),
                                                       neighbor_type='in',
                                                       num_hops=args.n_layers+1,
                                                       seed_nodes=test_nid):
            nf.copy_from_parent(ctx=ctx)
            pred = infer_model(nf)
            batch_nids = nf.layer_parent_nid(-1)
            batch_labels = g.nodes[batch_nids].data['labels'].as_in_context(ctx)
            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
            num_tests += nf.layer_size(-1)
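            # only the first sampled batch is evaluated, so the figure
            # printed below is an estimate over one test batch per epoch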
            break

        print("Test Accuracy {:.4f}". format(num_acc/num_tests))