"""
Learning Steady-States of Iterative Algorithms over Graphs
Paper: http://proceedings.mlr.press/v80/dai18a.html
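
A usage sketch (the --dataset flag comes from dgl.data.register_data_args;
the remaining flags are defined at the bottom of this file):
    python sse_batch.py --dataset pubmed --dgl --use-spmv --gpu 0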

"""
import argparse
import random
import numpy as np
import time
import math
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

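# Message and reduce UDFs for the non-SpMV path: every node receives the
# concatenation [in || h] from each in-neighbor and averages the messages
# into 'accum'.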
def gcn_msg(edges):
    # TODO should we use concat?
    return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

def gcn_reduce(nodes):
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1) / nodes.mailbox['m'].shape[1]}

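# The SSE steady-state operator: blend the old embedding with a two-layer
# transformation of the node's input features concatenated with the averaged
# neighbor representation, i.e. h <- (1 - alpha) * h + alpha * T([in || accum]).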
class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        # TODO what is the dimension here?
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

    def forward(self, in_data, hidden_data, accum):
        tmp = mx.nd.concat(in_data, accum, dim=1)
        hidden = self.linear2(self.linear1(tmp))
        return hidden_data * (1 - self.alpha) + self.alpha * hidden

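# Adapter that exposes NodeUpdate as a DGL apply-node function: it reads the
# 'in', 'h' and 'accum' node fields and writes the updated embedding to 'h1'.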
class DGLNodeUpdate(gluon.Block):
    def __init__(self, update):
        super(DGLNodeUpdate, self).__init__()
        self.update = update

    def forward(self, node):
        return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}

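# Hidden-state update without DGL message passing: the neighbor average is
# computed as a dense dot of the adjacency matrix with [in || h], normalized
# by in-degree; the full-graph pass applies the node update in large batches
# to bound peak memory.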
class SSEUpdateHidden(gluon.Block):
    def __init__(self,
                 n_hidden,
                 dropout,
                 activation,
                 **kwargs):
        super(SSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = NodeUpdate(n_hidden, activation)
        self.dropout = dropout
        self.n_hidden = n_hidden

    def forward(self, g, vertices):
        if vertices is None:
            deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            ret = mx.nd.empty(shape=(feat.shape[0], self.n_hidden), ctx=feat.context)
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                ret[vs] = self.layer(mx.nd.take(feat, vs),
                                     mx.nd.take(g.ndata['h'], vs),
                                     mx.nd.take(accum, vs))
            return ret
        else:
            deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
            # We don't need dropout for inference.
            if self.dropout:
                # TODO here we apply dropout on all vertex representation.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            slices = mx.nd.take(g.adjacency_matrix(), vertices).as_in_context(cat.context)
            accum = mx.nd.dot(slices, cat) / deg.as_in_context(cat.context)
            vertices = vertices.as_in_context(g.ndata['in'].context)
            return self.layer(mx.nd.take(feat, vertices),
                              mx.nd.take(g.ndata['h'], vertices), accum)

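# Hidden-state update via DGL message passing. With use_spmv, the
# concatenated features are materialized once in 'cat' so that the built-in
# copy_src/sum functions can be executed as a sparse matrix-vector product;
# otherwise the Python UDFs above are used.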
class DGLSSEUpdateHidden(gluon.Block):
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 inference,
                 **kwargs):
        super(DGLSSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv
        self.inference = inference

    def forward(self, g, vertices):
        if self.use_spmv:
            feat = g.ndata['in']
            g.ndata['cat'] = mx.nd.concat(feat, g.ndata['h'], dim=1)

            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
        if vertices is None:
            g.update_all(msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
                g.ndata['accum'] = g.ndata['accum'] / deg
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                g.apply_nodes(self.layer, vs, inplace=self.inference)
            g.ndata.pop('accum')
            return g.get_n_repr()['h1']
        else:
            # We don't need dropout for inference.
            if self.dropout:
                # TODO here we apply dropout on all vertex representation.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            g.update_all(msg_func, reduce_func, None)
            ctx = g.ndata['accum'].context
            if self.use_spmv:
                g.ndata.pop('cat')
                deg = deg.as_in_context(ctx)
                g.ndata['accum'] = g.ndata['accum'] / deg
            g.apply_nodes(self.layer, vertices, inplace=self.inference)
            g.ndata.pop('accum')
            return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))

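# Prediction head: refresh the hidden states with update_hidden, then apply
# a two-layer MLP (with optional dropout) to produce logits.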
class SSEPredict(gluon.Block):
    def __init__(self, update_hidden, out_feats, dropout, **kwargs):
        super(SSEPredict, self).__init__(**kwargs)
        with self.name_scope():
            self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
            self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

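# Move all node frames of a (sub)graph to the given GPU context.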
def copy_to_gpu(subg, ctx):
    frame = subg.ndata
    for key in frame:
        subg.ndata[key] = frame[key].as_in_context(ctx)

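# Neighbor sampling is expensive, so the subgraph structures sampled in the
# first epoch can be cached (without node/edge frames) and replayed in
# shuffled order in later epochs.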
class CachedSubgraph(object):
    def __init__(self, subg, seeds):
        # We can't cache the input subgraph because it contains node frames
        # and data frames.
        self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
                                subg._graph)
        self.seeds = seeds

class CachedSubgraphLoader(object):
    def __init__(self, loader, shuffle):
        self._loader = loader
        self._cached = []
        self._shuffle = shuffle

    def restart(self):
        self._subgraphs = self._cached
        self._gen_subgraph = len(self._subgraphs) == 0
        random.shuffle(self._subgraphs)
        self._cached = []

    def __iter__(self):
        return self

    def __next__(self):
        if len(self._subgraphs) > 0:
            s = self._subgraphs.pop(0)
            subg, seeds = s.subg, s.seeds
        elif self._gen_subgraph:
            subg, seeds = self._loader.__next__()
        else:
            raise StopIteration
        self._cached.append(CachedSubgraph(subg, seeds))
        return subg, seeds

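# End-to-end training: build the graph, run a full-graph inference pass to
# initialize the steady-state embeddings, then alternate between mini-batch
# training on sampled subgraphs and full-graph passes that refresh the
# embeddings and evaluate the model.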
def main(args, data):
    if isinstance(data.features, mx.nd.NDArray):
        features = data.features
    else:
        features = mx.nd.array(data.features)
    if isinstance(data.labels, mx.nd.NDArray):
        labels = data.labels
    else:
        labels = mx.nd.array(data.labels)
    train_size = int(len(labels) * args.train_percent)
    train_vs = mx.nd.arange(0, train_size, dtype='int64')
    eval_vs = mx.nd.arange(train_size, len(labels), dtype='int64')
    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    eval_labels = mx.nd.take(labels, eval_vs)
    in_feats = features.shape[1]
    n_edges = data.graph.number_of_edges()

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
    g = DGLGraph(graph, readonly=True)
    g.ndata['in'] = features
    g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
            ctx=mx.cpu(0))

    update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             inference=True, prefix='sse')
    update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             inference=False, prefix='sse')
    if not args.dgl:
        update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')
        update_hidden_train = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')

    model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer.initialize(ctx=mx.cpu(0))
    if args.gpu <= 0:
        model_train.initialize(ctx=mx.cpu(0))
    else:
        train_ctxs = []
        for i in range(args.gpu):
            train_ctxs.append(mx.gpu(i))
        model_train.initialize(ctx=train_ctxs)

    # Cosine decay of the learning rate from 10x the base rate down to 0
    # over all training steps (no warm-up steps are configured).
    num_batches = int(g.number_of_nodes() / args.batch_size)
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
            args.lr * 10, 0, 0, args.lr / 5)
    trainer = gluon.Trainer(model_train.collect_params(), 'adam', {'learning_rate': args.lr,
        'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))

    # compute vertex embedding.
    all_hidden = update_hidden_infer(g, None)
    g.ndata['h'] = all_hidden
    rets = [all_hidden]

    if args.neigh_expand <= 0:
        neigh_expand = g.number_of_nodes()
    else:
        neigh_expand = args.neigh_expand
    dur = []
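    # Sample one subgraph per batch of training seeds, expanding up to
    # neigh_expand in-neighbors per node.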
    sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
            neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
            shuffle=True, return_seed_id=True)
    if args.cache_subgraph:
        sampler = CachedSubgraphLoader(sampler, shuffle=True)
    for epoch in range(args.n_epochs):
        t0 = time.time()
        train_loss = 0
        i = 0
        num_batches = len(train_vs) / args.batch_size
        start1 = time.time()
        losses = []  # losses accumulated between trainer steps (multi-GPU path)
        for subg, aux_infos in sampler:
            seeds = aux_infos['seeds']
            subg_seeds = subg.map_to_subgraph_nid(seeds)
            subg.copy_from_parent()

            if args.gpu > 0:
                ctx = mx.gpu(i % args.gpu)
                copy_to_gpu(subg, ctx)

            with mx.autograd.record():
                logits = model_train(subg, subg_seeds)
                batch_labels = mx.nd.take(labels, seeds).as_in_context(logits.context)
                loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
            loss.backward()
            losses.append(loss)
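            # On CPU, step the trainer after every subgraph; with GPUs,
            # subgraphs are assigned to devices round-robin and the trainer
            # steps once every args.gpu batches over the accumulated losses.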
            i += 1
            if args.gpu <= 0:
                trainer.step(seeds.shape[0])
                train_loss += loss.asnumpy()[0]
                losses = []
            elif i % args.gpu == 0:
                trainer.step(len(seeds) * len(losses))
                for loss in losses:
                    train_loss += loss.asnumpy()[0]
                losses = []

            if i % args.num_parallel_subgraphs == 0:
                end1 = time.time()
                print("process " + str(args.num_parallel_subgraphs)
                        + " subgraphs takes " + str(end1 - start1))
                start1 = end1

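            # Stop the epoch early: only about a third of the sampled
            # batches are consumed.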
            if i > num_batches / 3:
                break

        if args.cache_subgraph:
            sampler.restart()
        else:
            sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
                                                           neighbor_type='in',
                                                           num_workers=args.num_parallel_subgraphs,
                                                           seed_nodes=train_vs, shuffle=True,
                                                           return_seed_id=True)

        # prediction.
        logits = model_infer(g, eval_vs)
        eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
        eval_loss = eval_loss.asnumpy()[0]

        # Sync the trained parameters into the inference model by pulling
        # their latest values from the trainer's KVStore (this reaches into
        # Trainer internals: _param2idx and _kvstore).
        infer_params = model_infer.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        # Update node embeddings.
        all_hidden = update_hidden_infer(g, None)
        g.ndata['h'] = all_hidden
        rets.append(all_hidden)

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))

    return rets

class MXNetGraph(object):
    """A simple graph object that uses scipy matrix."""
    def __init__(self, mat):
        self._mat = mat

    def get_graph(self):
        return self._mat

    def number_of_nodes(self):
        return self._mat.shape[0]

    def number_of_edges(self):
        return mx.nd.contrib.getnnz(self._mat).asnumpy()[0]

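# Synthetic dataset: wraps a stored CSR adjacency with random node features
# and random integer labels in [0, 10).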
class GraphData:
    def __init__(self, csr, num_feats):
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
        self.graph = MXNetGraph(csr)
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0],)))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=10,
            help="the number of features")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with large learning rate")
    parser.add_argument("--update-dropout", type=float, default=0,
            help="the dropout rate for updating vertex embedding")
    parser.add_argument("--predict-dropout", type=float, default=0,
            help="the dropout rate for prediction")
    parser.add_argument("--train_percent", type=float, default=0.5,
            help="the percentage of data used for training")
    parser.add_argument("--use-spmv", action="store_true",
            help="use SpMV for faster speed.")
    parser.add_argument("--dgl", action="store_true")
    parser.add_argument("--cache-subgraph", default=False, action="store_false")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    parser.add_argument("--neigh-expand", type=int, default=16,
            help="the number of neighbors to sample.")
    args = parser.parse_args()
    print("cache: " + str(args.cache_subgraph))

    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        data = GraphData(csr, args.num_feats)
        csr = None
    else:
        data = load_data(args)
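    # Run training twice on the same data and print how much the learned
    # embeddings differ between the two runs.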
    rets1 = main(args, data)
    rets2 = main(args, data)
    for hidden1, hidden2 in zip(rets1, rets2):
        print("hidden: " + str(mx.nd.sum(mx.nd.abs(hidden1 - hidden2)).asnumpy()))