"""
Learning Steady-States of Iterative Algorithms over Graphs
Paper: http://proceedings.mlr.press/v80/dai18a.html

"""
import argparse
import random
import numpy as np
import time
import math
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(edges):
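    # Message function: each edge carries the concatenation of the source
    # node's input features and its current hidden state.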
    # TODO should we use concat?
    return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

def gcn_reduce(nodes):
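    # Reduce function: average the incoming messages over a node's
    # in-neighbors.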
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1) / nodes.mailbox['m'].shape[1]}

class NodeUpdate(gluon.Block):
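    """One step of the SSE update:
    h_v <- (1 - alpha) * h_v + alpha * linear2(linear1([x_v, accum_v])),
    i.e. a damped move toward the hidden state recomputed from the node's
    input features and its aggregated neighborhood.
    """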
    def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        # TODO what is the dimension here?
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

    def forward(self, in_data, hidden_data, accum):
        tmp = mx.nd.concat(in_data, accum, dim=1)
        hidden = self.linear2(self.linear1(tmp))
        return hidden_data * (1 - self.alpha) + self.alpha * hidden

class DGLNodeUpdate(gluon.Block):
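    """Adapter that exposes NodeUpdate through DGL's apply_nodes UDF
    interface; the updated hidden state is written to the 'h1' field."""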
    def __init__(self, update):
        super(DGLNodeUpdate, self).__init__()
        self.update = update

    def forward(self, node):
        return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}

class SSEUpdateHidden(gluon.Block):
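    """Hidden-state update that bypasses DGL message passing: neighbor
    aggregation is computed directly as a dot product with the graph's
    adjacency matrix."""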
    def __init__(self,
                 n_hidden,
                 dropout,
                 activation,
                 **kwargs):
        super(SSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = NodeUpdate(n_hidden, activation)
        self.dropout = dropout
        self.n_hidden = n_hidden

    def forward(self, g, vertices):
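        # vertices == None requests a full-graph update; otherwise only the
        # given mini-batch of vertices is updated.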
        if vertices is None:
            deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
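            # Apply the node update in fixed-size chunks to bound peak memory.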
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            ret = mx.nd.empty(shape=(feat.shape[0], self.n_hidden), ctx=feat.context)
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                ret[vs] = self.layer(mx.nd.take(feat, vs),
                                     mx.nd.take(g.ndata['h'], vs),
                                     mx.nd.take(accum, vs))
            return ret
        else:
            deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
            # We don't need dropout for inference.
            if self.dropout:
                # TODO: here we apply dropout to all vertex representations.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            slices = mx.nd.take(g.adjacency_matrix(), vertices).as_in_context(cat.context)
            accum = mx.nd.dot(slices, cat) / deg.as_in_context(cat.context)
            vertices = vertices.as_in_context(g.ndata['in'].context)
            return self.layer(mx.nd.take(feat, vertices),
                              mx.nd.take(g.ndata['h'], vertices), accum)

class DGLSSEUpdateHidden(gluon.Block):
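    """Hidden-state update implemented with DGL message passing. With
    use_spmv, the built-in copy_src/sum functions are used so that DGL can
    run the aggregation as a sparse matrix multiplication."""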
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 inference,
                 **kwargs):
        super(DGLSSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv
        self.inference = inference

    def forward(self, g, vertices):
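        # Pick the built-in message/reduce functions (SpMV path) or the
        # Python UDFs defined at the top of the file.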
        if self.use_spmv:
            feat = g.ndata['in']
            g.ndata['cat'] = mx.nd.concat(feat, g.ndata['h'], dim=1)

            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        deg = mx.nd.expand_dims(g.in_degrees(), 1).astype(np.float32)
        if vertices is None:
            g.update_all(msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
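                # fn.sum accumulates neighbor messages; dividing by the
                # in-degree recovers the mean, matching gcn_reduce.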
                g.ndata['accum'] = g.ndata['accum'] / deg
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                g.apply_nodes(self.layer, vs, inplace=self.inference)
            g.ndata.pop('accum')
            return g.get_n_repr()['h1']
        else:
            # We don't need dropout for inference.
            if self.dropout:
                # TODO: here we apply dropout to all vertex representations.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            g.update_all(msg_func, reduce_func, None)
            ctx = g.ndata['accum'].context
            if self.use_spmv:
                g.ndata.pop('cat')
                deg = deg.as_in_context(ctx)
                g.ndata['accum'] = g.ndata['accum'] / deg
            g.apply_nodes(self.layer, vertices, inplace=self.inference)
            g.ndata.pop('accum')
            return mx.nd.take(g.ndata['h1'], vertices.as_in_context(ctx))

class SSEPredict(gluon.Block):
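    """Prediction head: a two-layer MLP on top of the hidden states
    produced by update_hidden."""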
    def __init__(self, update_hidden, out_feats, dropout, **kwargs):
        super(SSEPredict, self).__init__(**kwargs)
        with self.name_scope():
            self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
            self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

def copy_to_gpu(subg, ctx):
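    # Move all node data of the subgraph onto the given device.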
    frame = subg.ndata
    for key in frame:
        subg.ndata[key] = frame[key].as_in_context(ctx)

class CachedSubgraph(object):
    def __init__(self, subg, seeds, subg_seeds):
        # We can't cache the input subgraph itself because it carries node
        # and edge data frames.
        self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
                                subg._graph)
        self.seeds = seeds
        self.subg_seeds = subg_seeds

class CachedSubgraphLoader(object):
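    """Wraps a subgraph sampler and caches the subgraphs sampled in the
    first pass so that later epochs replay them instead of resampling."""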
    def __init__(self, loader, shuffle):
        self._loader = loader
        self._cached = []
        self._shuffle = shuffle

    def restart(self):
        self._subgraphs = self._cached
        self._gen_subgraph = len(self._subgraphs) == 0
        if self._shuffle:
            random.shuffle(self._subgraphs)
        self._cached = []

    def __iter__(self):
        return self

    def __next__(self):
        if len(self._subgraphs) > 0:
            s = self._subgraphs.pop(0)
            subg, seeds, subg_seeds = s.subg, s.seeds, s.subg_seeds
        elif self._gen_subgraph:
            subg, seeds = self._loader.__next__()
            subg_seeds = subg.map_to_subgraph_nid(seeds)
        else:
            raise StopIteration
        self._cached.append(CachedSubgraph(subg, seeds, subg_seeds))
        return subg, seeds, subg_seeds

def main(args, data):
    if isinstance(data.features, mx.nd.NDArray):
        features = data.features
    else:
        features = mx.nd.array(data.features)
    if isinstance(data.labels, mx.nd.NDArray):
        labels = data.labels
    else:
        labels = mx.nd.array(data.labels)
    train_size = int(len(labels) * args.train_percent)
    train_vs = mx.nd.arange(0, train_size, dtype='int64')
    eval_vs = mx.nd.arange(train_size, len(labels), dtype='int64')
    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    eval_labels = mx.nd.take(labels, eval_vs)
    in_feats = features.shape[1]
    n_edges = data.graph.number_of_edges()

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
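    # Build a read-only DGLGraph; the contrib NeighborSampler used below
    # expects an immutable graph.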
    g = DGLGraph(graph, readonly=True)
    g.ndata['in'] = features
    g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
            ctx=mx.cpu(0))

    update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             inference=True, prefix='sse')
    update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             inference=False, prefix='sse')
    if not args.dgl:
        update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')
        update_hidden_train = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')

    model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer.initialize(ctx=mx.cpu(0))
    if args.gpu <= 0:
        model_train.initialize(ctx=mx.cpu(0))
    else:
        train_ctxs = []
        for i in range(args.gpu):
            train_ctxs.append(mx.gpu(i))
        model_train.initialize(ctx=train_ctxs)

    # set up the optimizer
    num_batches = int(g.number_of_nodes() / args.batch_size)
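    # Cosine decay from 10x the base learning rate over the total number
    # of training iterations.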
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
            args.lr * 10, 0, 0, args.lr/5)
    trainer = gluon.Trainer(model_train.collect_params(), 'adam', {'learning_rate': args.lr,
        'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))

    # compute the initial vertex embeddings.
    all_hidden = update_hidden_infer(g, None)
    g.ndata['h'] = all_hidden
    rets = []
    rets.append(all_hidden)

    if args.neigh_expand <= 0:
        neigh_expand = g.number_of_nodes()
    else:
        neigh_expand = args.neigh_expand
    # track per-epoch training time
    dur = []
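    # Sample mini-batch subgraphs by expanding in-neighbors from the seed
    # vertices; the cached loader replays them in later epochs.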
    sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
            neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
            shuffle=True)
    sampler = CachedSubgraphLoader(sampler, shuffle=True)
    for epoch in range(args.n_epochs):
        t0 = time.time()
        train_loss = 0
        i = 0
        num_batches = len(train_vs) / args.batch_size
        start1 = time.time()
        sampler.restart()
        for subg, seeds, subg_seeds in sampler:
            subg.copy_from_parent()

            losses = []
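            # With multiple GPUs, batches are dispatched round-robin and the
            # optimizer steps once every args.gpu batches (below).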
            if args.gpu > 0:
                ctx = mx.gpu(i % args.gpu)
                copy_to_gpu(subg, ctx)

            with mx.autograd.record():
                logits = model_train(subg, subg_seeds)
                batch_labels = mx.nd.take(labels, seeds).as_in_context(logits.context)
                loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
            loss.backward()
            losses.append(loss)
            i += 1
            if args.gpu <= 0:
                trainer.step(seeds.shape[0])
                train_loss += loss.asnumpy()[0]
                losses = []
            elif i % args.gpu == 0:
                trainer.step(len(seeds) * len(losses))
                for loss in losses:
                    train_loss += loss.asnumpy()[0]
                losses = []

            if i % args.num_parallel_subgraphs == 0:
                end1 = time.time()
                print("process " + str(args.num_parallel_subgraphs)
                        + " subgraphs takes " + str(end1 - start1))
                start1 = end1

            if i > num_batches / 3:
                break

        # evaluate on the held-out vertices.
        logits = model_infer(g, eval_vs)
        eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
        eval_loss = eval_loss.asnumpy()[0]

        # update the inference model.
        infer_params = model_infer.collect_params()
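        # The training and inference models share parameter names (same
        # prefixes), so the trained weights can be pulled from the kvstore
        # straight into the inference model.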
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        # Update node embeddings.
        all_hidden = update_hidden_infer(g, None)
        g.ndata['h'] = all_hidden
        rets.append(all_hidden)

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))

    return rets

class MXNetGraph(object):
    """A simple graph object that uses scipy matrix."""
    def __init__(self, mat):
        self._mat = mat

    def get_graph(self):
        return self._mat

    def number_of_nodes(self):
        return self._mat.shape[0]

    def number_of_edges(self):
        return mx.nd.contrib.getnnz(self._mat).asnumpy()[0]

class GraphData:
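    """Synthetic dataset built from a stored CSR adjacency matrix, with
    random node features and random integer labels in [0, 10)."""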
    def __init__(self, csr, num_feats):
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
        self.graph = MXNetGraph(csr)
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0])))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=10,
            help="the number of features")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with large learning rate")
    parser.add_argument("--update-dropout", type=float, default=0,
            help="the dropout rate for updating vertex embedding")
    parser.add_argument("--predict-dropout", type=float, default=0,
            help="the dropout rate for prediction")
    parser.add_argument("--train_percent", type=float, default=0.5,
            help="the percentage of data used for training")
    parser.add_argument("--use-spmv", action="store_true",
            help="use SpMV for faster speed.")
    parser.add_argument("--dgl", action="store_true")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    parser.add_argument("--neigh-expand", type=int, default=16,
            help="the number of neighbors to sample.")
    args = parser.parse_args()

    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        data = GraphData(csr, args.num_feats)
        csr = None
    else:
        data = load_data(args)
    rets1 = main(args, data)
    rets2 = main(args, data)
    for hidden1, hidden2 in zip(rets1, rets2):
        print("hidden: " + str(mx.nd.sum(mx.nd.abs(hidden1 - hidden2)).asnumpy()))