"""
Learning Steady-States of Iterative Algorithms over Graphs
Paper: http://proceedings.mlr.press/v80/dai18a.html

"""
import argparse
import random
import numpy as np
import time
import math
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(edges):
    # TODO should we use concat?
    return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

def gcn_reduce(nodes):
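    # Mean aggregation: average the incoming messages over each node's in-neighbors.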
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1) / nodes.mailbox['m'].shape[1]}

class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        # TODO what is the dimension here?
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

    def forward(self, in_data, hidden_data, accum):
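        # Damped fixed-point step: mix the old hidden state with the newly
        # computed one, h <- (1 - alpha) * h + alpha * f([in, accum]).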
        tmp = mx.nd.concat(in_data, accum, dim=1)
        hidden = self.linear2(self.linear1(tmp))
        return hidden_data * (1 - self.alpha) + self.alpha * hidden

class DGLNodeUpdate(gluon.Block):
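    """Adapter exposing NodeUpdate through DGL's apply_nodes node-UDF interface."""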
    def __init__(self, update):
        super(DGLNodeUpdate, self).__init__()
        self.update = update

    def forward(self, node):
        return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}

class SSEUpdateHidden(gluon.Block):
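    """SSE hidden-state update implemented with plain NDArray ops
    (a dense adjacency-matrix product) instead of DGL message passing."""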
    def __init__(self,
                 n_hidden,
                 dropout,
                 activation,
                 **kwargs):
        super(SSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = NodeUpdate(n_hidden, activation)
        self.dropout = dropout
        self.n_hidden = n_hidden

    def forward(self, g, vertices):
        if vertices is None:
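            # Full-graph pass: aggregate all neighborhoods with one dense
            # adjacency-matrix product, then apply the node update in
            # fixed-size batches to bound peak memory.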
            deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            ret = mx.nd.empty(shape=(feat.shape[0], self.n_hidden), ctx=feat.context)
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                ret[vs] = self.layer(mx.nd.take(feat, vs),
                                     mx.nd.take(g.ndata['h'], vs),
                                     mx.nd.take(accum, vs))
            return ret
        else:
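            # Mini-batch pass: aggregate and update only the requested
            # vertices, taking just their rows of the adjacency matrix.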
            deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
            # We don't need dropout for inference.
            if self.dropout:
                # TODO here we apply dropout to all vertex representations.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            slices = mx.nd.take(g.adjacency_matrix(), vertices).as_in_context(cat.context)
            accum = mx.nd.dot(slices, cat) / deg.as_in_context(cat.context)
            vertices = vertices.as_in_context(g.ndata['in'].context)
            return self.layer(mx.nd.take(feat, vertices),
                              mx.nd.take(g.ndata['h'], vertices), accum)

class DGLSSEUpdateHidden(gluon.Block):
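    """The same SSE hidden-state update, expressed with DGL message passing."""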
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 inference,
                 **kwargs):
        super(DGLSSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv
        self.inference = inference

    def forward(self, g, vertices):
        if self.use_spmv:
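            # Concatenate the features once so the message function becomes a
            # plain copy and the reduce a sum, which DGL can execute as SpMV.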
            feat = g.ndata['in']
            g.ndata['cat'] = mx.nd.concat(feat, g.ndata['h'], dim=1)

            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
        if vertices is None:
            g.update_all(msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
                g.ndata['accum'] = g.ndata['accum'] / deg
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                g.apply_nodes(self.layer, vs, inplace=self.inference)
            g.ndata.pop('accum')
            return g.get_n_repr()['h1']
        else:
            # We don't need dropout for inference.
            if self.dropout:
                # TODO here we apply dropout to all vertex representations.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            g.update_all(msg_func, reduce_func, None)
            ctx = g.ndata['accum'].context
            if self.use_spmv:
                g.ndata.pop('cat')
                deg = deg.as_in_context(ctx)
                g.ndata['accum'] = g.ndata['accum'] / deg
            g.apply_nodes(self.layer, vertices, inplace=self.inference)
            g.ndata.pop('accum')
            return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))

class SSEPredict(gluon.Block):
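    """Prediction head: runs the hidden-state update, then two dense layers."""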
    def __init__(self, update_hidden, out_feats, dropout, **kwargs):
        super(SSEPredict, self).__init__(**kwargs)
        with self.name_scope():
            self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
            self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

def copy_to_gpu(subg, ctx):
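    # Move all node-data arrays of the subgraph onto the given device.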
    frame = subg.ndata
    for key in frame:
        subg.ndata[key] = frame[key].as_in_context(ctx)

class CachedSubgraph(object):
    def __init__(self, subg, seeds):
        # We can't cache the input subgraph because it contains node frames
        # and data frames.
        self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
                                subg._graph)
        self.seeds = seeds

class CachedSubgraphLoader(object):
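    """Wraps a subgraph sampler and caches whatever it yields, so later
    epochs can replay the cached subgraphs instead of resampling."""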
    def __init__(self, loader, shuffle):
        self._loader = loader
        self._cached = []
        self._shuffle = shuffle

    def restart(self):
        self._subgraphs = self._cached
        self._gen_subgraph = len(self._subgraphs) == 0
        if self._shuffle:
            random.shuffle(self._subgraphs)
        self._cached = []

    def __iter__(self):
        return self

    def __next__(self):
        if len(self._subgraphs) > 0:
            s = self._subgraphs.pop(0)
            subg, seeds = s.subg, s.seeds
        elif self._gen_subgraph:
            subg, seeds = self._loader.__next__()
        else:
            raise StopIteration
        self._cached.append(CachedSubgraph(subg, seeds))
        return subg, seeds

def main(args, data):
    if isinstance(data.features, mx.nd.NDArray):
        features = data.features
    else:
        features = mx.nd.array(data.features)
    if isinstance(data.labels, mx.nd.NDArray):
        labels = data.labels
    else:
        labels = mx.nd.array(data.labels)
    if data.train_mask is not None:
        train_vs = mx.nd.array(np.nonzero(data.train_mask)[0], dtype='int64')
        eval_vs = mx.nd.array(np.nonzero(data.train_mask == 0)[0], dtype='int64')
    else:
        train_size = int(len(labels) * args.train_percent)
        train_vs = mx.nd.arange(0, train_size, dtype='int64')
        eval_vs = mx.nd.arange(train_size, len(labels), dtype='int64')

    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    eval_labels = mx.nd.take(labels, eval_vs)
    in_feats = features.shape[1]
    n_edges = data.graph.number_of_edges()

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
    g = DGLGraph(graph, readonly=True)
    g.ndata['in'] = features
    g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
            ctx=mx.cpu(0))

    update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             inference=True, prefix='sse')
    update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             inference=False, prefix='sse')
    if not args.dgl:
        update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')
        update_hidden_train = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')

    model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer.initialize(ctx=mx.cpu(0))
    if args.gpu <= 0:
        model_train.initialize(ctx=mx.cpu(0))
    else:
        train_ctxs = []
        for i in range(args.gpu):
            train_ctxs.append(mx.gpu(i))
        model_train.initialize(ctx=train_ctxs)

    # optimizer with a cosine learning-rate schedule over all training steps
    num_batches = int(g.number_of_nodes() / args.batch_size)
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
            args.lr * 10, 0, 0, args.lr/5)
    trainer = gluon.Trainer(model_train.collect_params(), 'adam', {'learning_rate': args.lr,
        'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))

    # compute vertex embedding.
    all_hidden = update_hidden_infer(g, None)
    g.ndata['h'] = all_hidden
    rets = []
    rets.append(all_hidden)

    if args.neigh_expand <= 0:
        neigh_expand = g.number_of_nodes()
    else:
        neigh_expand = args.neigh_expand
    # initialize graph
    dur = []
    sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
            neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
            shuffle=True, return_seed_id=True)
    if args.cache_subgraph:
        sampler = CachedSubgraphLoader(sampler, shuffle=True)
    for epoch in range(args.n_epochs):
        t0 = time.time()
        train_loss = 0
        i = 0
        num_batches = len(train_vs) // args.batch_size
        start1 = time.time()
        for subg, aux_infos in sampler:
            seeds = aux_infos['seeds']
            subg_seeds = subg.map_to_subgraph_nid(seeds)
            subg.copy_from_parent()

            losses = []
            if args.gpu > 0:
                ctx = mx.gpu(i % args.gpu)
                copy_to_gpu(subg, ctx)

            with mx.autograd.record():
                logits = model_train(subg, subg_seeds)
                batch_labels = mx.nd.take(labels, seeds).as_in_context(logits.context)
                loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
            loss.backward()
            losses.append(loss)
            i += 1
            if args.gpu <= 0:
                trainer.step(seeds.shape[0])
                train_loss += loss.asnumpy()[0]
                losses = []
            elif i % args.gpu == 0:
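                # With multiple GPUs, step once per round-robin pass over all
                # devices, normalizing by the total number of seeds processed.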
                trainer.step(len(seeds) * len(losses))
                for loss in losses:
                    train_loss += loss.asnumpy()[0]
                losses = []

            if i % args.num_parallel_subgraphs == 0:
                end1 = time.time()
                print("process " + str(args.num_parallel_subgraphs)
                        + " subgraphs takes " + str(end1 - start1))
                start1 = end1

        if args.cache_subgraph:
            sampler.restart()
        else:
            sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
                                                           neighbor_type='in',
                                                           num_workers=args.num_parallel_subgraphs,
                                                           seed_nodes=train_vs, shuffle=True,
                                                           return_seed_id=True)

        # test set accuracy
        logits = model_infer(g, eval_vs)
        y_bar = mx.nd.argmax(logits, axis=1)
        y = eval_labels
        accuracy = mx.nd.sum(y_bar == y) / len(y)
        accuracy = accuracy.asnumpy()[0]

        # update the inference model.
        infer_params = model_infer.collect_params()
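        # Pull the freshly trained weights from the kvstore into the inference
        # model; this reaches into private Trainer fields, so it is tied to
        # the MXNet version.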
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        # Update node embeddings.
        all_hidden = update_hidden_infer(g, None)
        g.ndata['h'] = all_hidden
        rets.append(all_hidden)

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Test Accuracy {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, accuracy, np.mean(dur), n_edges / np.mean(dur) / 1000))

    return rets

class MXNetGraph(object):
    """A simple graph object that uses scipy matrix."""
    def __init__(self, mat):
        self._mat = mat

    def get_graph(self):
        return self._mat

    def number_of_nodes(self):
        return self._mat.shape[0]

    def number_of_edges(self):
        return mx.nd.contrib.getnnz(self._mat).asnumpy()[0]

class GraphData:
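    """Synthetic dataset over a user-supplied CSR adjacency matrix: random
    node features, random 10-class labels, and no predefined train mask."""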
    def __init__(self, csr, num_feats):
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
        self.graph = MXNetGraph(csr)
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0])))
        self.train_mask = None

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=10,
            help="the number of features")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with large learning rate")
    parser.add_argument("--update-dropout", type=float, default=0,
            help="the dropout rate for updating vertex embedding")
    parser.add_argument("--predict-dropout", type=float, default=0,
            help="the dropout rate for prediction")
    parser.add_argument("--train_percent", type=float, default=0.5,
            help="the percentage of data used for training")
    parser.add_argument("--use-spmv", action="store_true",
            help="use SpMV for faster speed.")
    parser.add_argument("--dgl", action="store_true")
    parser.add_argument("--cache-subgraph", default=False, action="store_true",
            help="cache sampled subgraphs and reuse them across epochs")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    parser.add_argument("--neigh-expand", type=int, default=16,
            help="the number of neighbors to sample.")
    args = parser.parse_args()
    print("cache: " + str(args.cache_subgraph))

    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        data = GraphData(csr, args.num_feats)
        csr = None
    else:
        data = load_data(args)
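
    # Sanity check: run training twice on the same data and report how much
    # the learned embeddings diverge between the two runs.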
    rets1 = main(args, data)
    rets2 = main(args, data)
    for hidden1, hidden2 in zip(rets1, rets2):
        print("hidden: " + str(mx.nd.sum(mx.nd.abs(hidden1 - hidden2)).asnumpy()))