"""
Learning Steady-States of Iterative Algorithms over Graphs
Paper: http://proceedings.mlr.press/v80/dai18a.html

"""
import argparse
import random
import numpy as np
import time
import math
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

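# Message and reduce functions for the non-SpMV path. Each vertex receives
# the concatenation of its in-neighbors' input features and hidden states,
# then averages over the mailbox (axis 1 is the neighbor axis).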
def gcn_msg(edges):
    return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

def gcn_reduce(nodes):
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1) / nodes.mailbox['m'].shape[1]}

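# The trainable steady-state update operator. Following the SSE paper, the
# new hidden state is a damped mixture of the old state and a two-layer
# transformation of the vertex input concatenated with the neighbor aggregate.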
class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

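    # h_new = (1 - alpha) * h_old + alpha * W2(sigma(W1([x || accum])))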
    def forward(self, in_data, hidden_data, accum):
        tmp = mx.nd.concat(in_data, accum, dim=1)
        hidden = self.linear2(self.linear1(tmp))
        return hidden_data * (1 - self.alpha) + self.alpha * hidden

class DGLNodeUpdate(gluon.Block):
    def __init__(self, update):
        super(DGLNodeUpdate, self).__init__()
        self.update = update

    def forward(self, node):
        return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}

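# Full-graph hidden-state update used for inference. With use_spmv, message
# passing runs on DGL's built-in copy_src/sum kernels over pre-concatenated
# features and the sum is divided by the in-degree afterwards; otherwise the
# Python UDFs above compute the mean directly.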
class DGLSSEUpdateHiddenInfer(gluon.Block):
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 inference,
                 **kwargs):
        super(DGLSSEUpdateHiddenInfer, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv
        self.inference = inference

    def forward(self, g, vertices):
        if self.use_spmv:
            feat = g.ndata['in']
            g.ndata['cat'] = mx.nd.concat(feat, g.ndata['h'], dim=1)

            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
        if vertices is None:
            g.update_all(msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
                g.ndata['accum'] = g.ndata['accum'] / deg
            # Update all nodes in large batches to bound peak memory usage.
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                g.apply_nodes(self.layer, vs, inplace=self.inference)
            g.ndata.pop('accum')
            return g.get_n_repr()['h1']
        else:
            # Dropout is only meaningful during training; pass dropout=0 when
            # this block is used for pure inference.
            if self.dropout:
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            g.update_all(msg_func, reduce_func, None)
            ctx = g.ndata['accum'].context
            if self.use_spmv:
                g.ndata.pop('cat')
                deg = deg.as_in_context(ctx)
                g.ndata['accum'] = g.ndata['accum'] / deg
            g.apply_nodes(self.layer, vertices, inplace=self.inference)
            g.ndata.pop('accum')
            return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))

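# Training-time variant of the hidden-state update. It operates on a sampled
# NodeFlow: messages flow from layer 0 to layer 1, so only the sampled seed
# vertices have their hidden states updated.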
class DGLSSEUpdateHiddenTrain(gluon.Block):
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 inference,
                 **kwargs):
        super(DGLSSEUpdateHiddenTrain, self).__init__(**kwargs)
        with self.name_scope():
            self.update = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv
        self.inference = inference

    def forward(self, subg, vertices):
        assert vertices is not None
        if self.use_spmv:
            feat = subg.layers[0].data['in']
            subg.layers[0].data['cat'] = mx.nd.concat(feat, subg.layers[0].data['h'],
                                                      dim=1)

            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        deg = mx.nd.expand_dims(subg.layer_in_degree(1), 1).astype(np.float32)
        # Dropout on the input hidden states is active only during training.
        if self.dropout:
            subg.layers[0].data['h'] = mx.nd.Dropout(subg.layers[0].data['h'], p=self.dropout)
        subg.block_compute(0, msg_func, reduce_func, None)
        ctx = subg.layers[1].data['accum'].context
        if self.use_spmv:
            subg.layers[0].data.pop('cat')
            deg = deg.as_in_context(ctx)
            subg.layers[1].data['accum'] = subg.layers[1].data['accum'] / deg
        subg.apply_layer(1, self.update, inplace=self.inference)
        subg.layers[1].data.pop('accum')
        return subg.layers[1].data['h1']

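# Prediction head: a two-layer MLP applied to the hidden states produced by
# the wrapped update_hidden block.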
class SSEPredict(gluon.Block):
    def __init__(self, update_hidden, out_feats, dropout, **kwargs):
        super(SSEPredict, self).__init__(**kwargs)
        with self.name_scope():
            self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
            self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

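# Move all data frames of a sampled NodeFlow onto the given device.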
def copy_to_gpu(subg, ctx):
    for i in range(subg.num_layers):
        frame = subg.layers[i].data
        for key in frame:
            subg.layers[i].data[key] = frame[key].as_in_context(ctx)

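# Wraps a subgraph sampler so NodeFlows sampled in the first pass can be
# replayed (optionally reshuffled) in later epochs instead of re-sampled.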
class CachedSubgraphLoader(object):
    def __init__(self, loader, shuffle):
        self._loader = loader
        self._cached = []
        self._shuffle = shuffle
        # Start by generating subgraphs from the wrapped loader; restart()
        # switches to replaying the cache once it has been populated.
        self._subgraphs = []
        self._gen_subgraph = True

    def restart(self):
        self._subgraphs = self._cached
        self._gen_subgraph = len(self._subgraphs) == 0
        if self._shuffle:
            random.shuffle(self._subgraphs)
        self._cached = []

    def __iter__(self):
        return self

    def __next__(self):
        if len(self._subgraphs) > 0:
            subg = self._subgraphs.pop(0)
        elif self._gen_subgraph:
            subg = self._loader.__next__()
        else:
            raise StopIteration

        # We can't cache the input NodeFlow directly because it still carries
        # the copied node data frames; cache a bare NodeFlow over the same
        # structure instead.
        subg = dgl.NodeFlow(subg._parent, subg._graph)
        self._cached.append(subg)
        return subg

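# Training driver. Following the SSE procedure, it alternates between
# stochastic parameter updates on sampled subgraphs and full-graph
# recomputation of the steady-state embeddings.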
def main(args, data):
    if isinstance(data.features, mx.nd.NDArray):
        features = data.features
    else:
        features = mx.nd.array(data.features)
    if isinstance(data.labels, mx.nd.NDArray):
        labels = data.labels
    else:
        labels = mx.nd.array(data.labels)
    if data.train_mask is not None:
        train_vs = mx.nd.array(np.nonzero(data.train_mask)[0], dtype='int64')
        eval_vs = mx.nd.array(np.nonzero(data.train_mask == 0)[0], dtype='int64')
    else:
        train_size = int(len(labels) * args.train_percent)
        train_vs = mx.nd.arange(0, train_size, dtype='int64')
        eval_vs = mx.nd.arange(train_size, len(labels), dtype='int64')

    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    eval_labels = mx.nd.take(labels, eval_vs)
    in_feats = features.shape[1]
    n_edges = data.graph.number_of_edges()

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
    g = DGLGraph(graph, readonly=True)
    g.ndata['in'] = features
    g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
            ctx=mx.cpu(0))

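    # The training and inference blocks share parameter prefixes ('sse' and
    # 'app'), so their parameters match by name and can be synced each epoch.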
    update_hidden_infer = DGLSSEUpdateHiddenInfer(args.n_hidden, 'relu',
                                                  args.update_dropout, args.use_spmv,
                                                  inference=True, prefix='sse')
    update_hidden_train = DGLSSEUpdateHiddenTrain(args.n_hidden, 'relu',
                                                  args.update_dropout, args.use_spmv,
                                                  inference=False, prefix='sse')

    model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer.initialize(ctx=mx.cpu(0))
    if args.gpu <= 0:
        model_train.initialize(ctx=mx.cpu(0))
    else:
        train_ctxs = [mx.gpu(i) for i in range(args.gpu)]
        model_train.initialize(ctx=train_ctxs)

    # Cosine learning-rate schedule: decay from 10x the base rate towards 0
    # over all training iterations (warmup is effectively disabled because
    # warmup_steps=0).
    num_batches = int(g.number_of_nodes() / args.batch_size)
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
            base_lr=args.lr * 10, final_lr=0, warmup_steps=0,
            warmup_begin_lr=args.lr / 5)
    trainer = gluon.Trainer(model_train.collect_params(), 'adam', {'learning_rate': args.lr,
        'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))

    # Compute the initial steady-state embeddings over the full graph.
    all_hidden = update_hidden_infer(g, None)
    g.ndata['h'] = all_hidden
    rets = []
    rets.append(all_hidden)

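    # A non-positive --neigh-expand means every in-neighbor is used.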
    if args.neigh_expand <= 0:
        neigh_expand = g.number_of_nodes()
    else:
        neigh_expand = args.neigh_expand
    # per-epoch wall-clock durations
    dur = []
    sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
            neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
            shuffle=True)
    if args.cache_subgraph:
        sampler = CachedSubgraphLoader(sampler, shuffle=True)
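    # Each epoch: sample NodeFlows over the training seeds, update the model
    # parameters on them, sync the inference model, and recompute the
    # full-graph steady-state embeddings.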
    for epoch in range(args.n_epochs):
        t0 = time.time()
        train_loss = 0
        i = 0
        num_batches = int(math.ceil(len(train_vs) / args.batch_size))
        start1 = time.time()
        for subg in sampler:
            seeds = subg.layer_parent_nid(-1)
            subg_seeds = subg.layer_nid(-1)
            subg.copy_from_parent()

            losses = []
            if args.gpu > 0:
                ctx = mx.gpu(i % args.gpu)
                copy_to_gpu(subg, ctx)

            with mx.autograd.record():
                logits = model_train(subg, subg_seeds)
                batch_labels = mx.nd.take(labels, seeds).as_in_context(logits.context)
                loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
            loss.backward()
            losses.append(loss)
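            # On CPU, step after every batch; with multiple GPUs, accumulate
            # one batch per device and step once each device has a gradient.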
            i += 1
            if args.gpu <= 0:
                trainer.step(seeds.shape[0])
                train_loss += loss.asnumpy()[0]
                losses = []
            elif i % args.gpu == 0:
                trainer.step(len(seeds) * len(losses))
                for loss in losses:
                    train_loss += loss.asnumpy()[0]
                losses = []

            if i % args.num_parallel_subgraphs == 0:
                end1 = time.time()
                print("process " + str(args.num_parallel_subgraphs)
                        + " subgraphs takes " + str(end1 - start1))
                start1 = end1

        if args.cache_subgraph:
            sampler.restart()
        else:
            sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
                                                           neighbor_type='in',
                                                           num_workers=args.num_parallel_subgraphs,
                                                           seed_nodes=train_vs, shuffle=True)

        # test set accuracy
        logits = model_infer(g, eval_vs)
        y_bar = mx.nd.argmax(logits, axis=1)
        y = eval_labels
        accuracy = mx.nd.sum(y_bar == y) / len(y)
        accuracy = accuracy.asnumpy()[0]

        # Sync trained parameters into the inference model by pulling them
        # from the trainer's kvstore. Note this relies on private Trainer
        # attributes (_param2idx, _kvstore).
        infer_params = model_infer.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        # Update node embeddings.
        all_hidden = update_hidden_infer(g, None)
        g.ndata['h'] = all_hidden
        rets.append(all_hidden)

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Test Accuracy {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, accuracy, np.mean(dur), n_edges / np.mean(dur) / 1000))

    return rets

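# Synthetic-data path for benchmarking without a real dataset: wraps a CSR
# adjacency matrix loaded from a file, with randomly generated features
# and labels.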
class MXNetGraph(object):
    """A simple graph object that uses scipy matrix."""
    def __init__(self, mat):
        self._mat = mat

    def get_graph(self):
        return self._mat

    def number_of_nodes(self):
        return self._mat.shape[0]

    def number_of_edges(self):
        return mx.nd.contrib.getnnz(self._mat).asnumpy()[0]

class GraphData:
    def __init__(self, csr, num_feats):
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
        self.graph = MXNetGraph(csr)
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0],)))
        self.train_mask = None

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=10,
            help="the number of features")
    parser.add_argument("--gpu", type=int, default=-1,
            help="number of GPUs to use; <= 0 runs on CPU")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with a large learning rate (currently unused)")
    parser.add_argument("--update-dropout", type=float, default=0,
            help="the dropout rate for updating vertex embedding")
    parser.add_argument("--predict-dropout", type=float, default=0,
            help="the dropout rate for prediction")
    parser.add_argument("--train-percent", type=float, default=0.5,
            help="the fraction of vertices used for training")
    parser.add_argument("--use-spmv", action="store_true",
            help="use SpMV for faster speed.")
    parser.add_argument("--cache-subgraph", default=False, action="store_true",
            help="cache sampled subgraphs and reuse them across epochs")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    parser.add_argument("--neigh-expand", type=int, default=16,
            help="the number of neighbors to sample.")
    args = parser.parse_args()
    print("cache: " + str(args.cache_subgraph))

    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        data = GraphData(csr, args.num_feats)
        csr = None
    else:
        data = load_data(args)
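    # Run training twice on the same data and compare the learned embeddings;
    # small differences suggest the iterations converge to a similar
    # steady state.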
    rets1 = main(args, data)
    rets2 = main(args, data)
    for hidden1, hidden2 in zip(rets1, rets2):
        print("hidden: " + str(mx.nd.sum(mx.nd.abs(hidden1 - hidden2)).asnumpy()))