"""
Learning Steady-States of Iterative Algorithms over Graphs
Paper: http://proceedings.mlr.press/v80/dai18a.html

"""
import argparse
import math
import time

import numpy as np
import mxnet as mx
from mxnet import gluon

import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

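# Message and reduce functions for the non-SpMV path: every vertex averages
# the concatenated [input, hidden] features of its in-neighbors into 'accum'.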
def gcn_msg(edges):
    # TODO should we use concat?
    return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

def gcn_reduce(nodes):
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1) / nodes.mailbox['m'].shape[1]}

class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, alpha=0.1, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        # TODO what is the dimension here?
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

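    # Damped fixed-point update from the SSE paper:
    #   h_v <- (1 - alpha) * h_v + alpha * W2 * sigma(W1 * [x_v, accum_v])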
    def forward(self, in_data, hidden_data, accum):
        tmp = mx.nd.concat(in_data, accum, dim=1)
        hidden = self.linear2(self.linear1(tmp))
        return hidden_data * (1 - self.alpha) + self.alpha * hidden

class DGLNodeUpdate(gluon.Block):
    """Adapt NodeUpdate to DGL's apply-node interface."""
    def __init__(self, update):
        super(DGLNodeUpdate, self).__init__()
        self.update = update

    def forward(self, node):
        return {'h1': self.update(node.data['in'], node.data['h'], node.data['accum'])}

class SSEUpdateHidden(gluon.Block):
    def __init__(self,
                 n_hidden,
                 dropout,
                 activation,
                 **kwargs):
        super(SSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = NodeUpdate(n_hidden, activation)
        self.dropout = dropout

    def forward(self, g, vertices):
        # The neighborhood average is computed manually as a dense dot product
        # of the adjacency matrix with [input, hidden], normalized by the
        # in-degree of each vertex.
        if vertices is None:
            # Full-graph inference path: update every vertex, no dropout.
            deg = mx.nd.expand_dims(g.in_degrees(np.arange(0, g.number_of_nodes())), 1).astype(np.float32)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            accum = mx.nd.dot(g.adjacency_matrix(), cat) / deg
            return self.layer(feat, g.ndata['h'], accum)
        else:
            # Mini-batch training path: update only the given vertices.
            deg = mx.nd.expand_dims(g.in_degrees(vertices), 1).astype(np.float32)
            if self.dropout:
                # TODO: dropout is applied to all vertex representations here,
                # not only to the mini-batch vertices.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            feat = g.get_n_repr()['in']
            cat = mx.nd.concat(feat, g.ndata['h'], dim=1)
            slices = mx.nd.take(g.adjacency_matrix(), vertices).as_in_context(cat.context)
            accum = mx.nd.dot(slices, cat) / deg.as_in_context(cat.context)
            vertices = vertices.as_in_context(g.ndata['in'].context)
            return self.layer(mx.nd.take(feat, vertices),
                              mx.nd.take(g.ndata['h'], vertices), accum)

# The same steady-state update expressed through DGL's message-passing API.
# With --use-spmv, the builtin copy_src/sum functions let DGL fuse message
# passing into sparse matrix-vector multiplication.
class DGLSSEUpdateHidden(gluon.Block):
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 **kwargs):
        super(DGLSSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = DGLNodeUpdate(NodeUpdate(n_hidden, activation))
        self.dropout = dropout
        self.use_spmv = use_spmv

    def forward(self, g, vertices):
        if self.use_spmv:
            feat = g.ndata['in']
            g.ndata['cat'] = mx.nd.concat(feat, g.ndata['h'], dim=1)
            msg_func = fn.copy_src(src='cat', out='m')
            reduce_func = fn.sum(msg='m', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        deg = mx.nd.expand_dims(g.in_degrees(np.arange(0, g.number_of_nodes())), 1).astype(np.float32)
        if vertices is None:
            # Full-graph inference: update all vertices.
            g.update_all(msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
                # The builtin sum reducer is normalized by in-degree to match
                # the mean aggregation of gcn_reduce.
                g.ndata['accum'] = g.ndata['accum'] / deg
            # Apply the node update in batches to bound peak memory usage.
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                g.apply_nodes(self.layer, vs, inplace=True)
            g.ndata.pop('accum')
            return g.get_n_repr()['h1']
        else:
            # Mini-batch training: pull messages only to the given vertices.
            if self.dropout:
                # TODO: dropout is applied to all vertex representations here,
                # not only to the mini-batch vertices.
                g.ndata['h'] = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
            g.pull(vertices, msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
                deg = deg.as_in_context(g.ndata['accum'].context)
                g.ndata['accum'] = g.ndata['accum'] / deg
            g.apply_nodes(self.layer, vertices)
            g.ndata.pop('accum')
            return g.ndata['h1'][vertices.as_in_context(g.ndata['h1'].context)]

class SSEPredict(gluon.Block):
    """Prediction head: two dense layers on top of the steady-state embeddings."""
    def __init__(self, update_hidden, out_feats, dropout, **kwargs):
        super(SSEPredict, self).__init__(**kwargs)
        with self.name_scope():
            self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
            self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

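# Copy all node frames of a sampled subgraph to the given device context
# before the forward and backward passes.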
def copy_to_gpu(subg, ctx):
    frame = subg.ndata
    for key in frame:
        subg.ndata[key] = frame[key].as_in_context(ctx)

def main(args, data):
    if isinstance(data.features, mx.nd.NDArray):
        features = data.features
    else:
        features = mx.nd.array(data.features)
    if isinstance(data.labels, mx.nd.NDArray):
        labels = data.labels
    else:
        labels = mx.nd.array(data.labels)
    train_size = int(len(labels) * args.train_percent)
    train_vs = np.arange(train_size, dtype='int64')
    eval_vs = np.arange(train_size, len(labels), dtype='int64')
    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    eval_labels = mx.nd.array(data.labels[eval_vs])
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
    g = DGLGraph(graph, readonly=True)
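    # Store the input features on the graph; the hidden state 'h' starts from
    # random values and is iteratively refined towards the steady state.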
    g.ndata['in'] = features
    g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
            ctx=mx.cpu(0))

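    # Two copies of the model share parameter prefixes ('sse'/'app'):
    # model_train is updated by the trainer, while model_infer is refreshed
    # from the kvstore after each epoch for full-graph inference.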
    update_hidden_infer = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             prefix='sse')
    update_hidden_train = DGLSSEUpdateHidden(args.n_hidden, 'relu',
                                             args.update_dropout, args.use_spmv,
                                             prefix='sse')
    if not args.dgl:
        update_hidden_infer = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')
        update_hidden_train = SSEUpdateHidden(args.n_hidden, args.update_dropout, 'relu',
                                              prefix='sse')

    model_train = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer = SSEPredict(update_hidden_infer, args.n_hidden, args.predict_dropout, prefix='app')
    model_infer.initialize(ctx=mx.cpu(0))
    if args.gpu <= 0:
        model_train.initialize(ctx=mx.cpu(0))
    else:
        train_ctxs = []
        for i in range(args.gpu):
            train_ctxs.append(mx.gpu(i))
        model_train.initialize(ctx=train_ctxs)

    # use optimizer
    num_batches = int(g.number_of_nodes() / args.batch_size)
    # Cosine learning-rate schedule over the expected total number of training
    # iterations, starting from 10x the base learning rate.
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
            args.lr * 10, 0, 0, args.lr / 5)
    trainer = gluon.Trainer(model_train.collect_params(), 'adam', {'learning_rate': args.lr,
        'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))

    # Compute the initial vertex embeddings by running the inference model
    # over the full graph.
    all_hidden = update_hidden_infer(g, None)
    g.ndata['h'] = all_hidden
    rets = []
    rets.append(all_hidden)

    # Training loop: alternate between updating the prediction model on
    # sampled subgraphs and refreshing the steady-state embeddings.
    dur = []
    for epoch in range(args.n_epochs):
        t0 = time.time()
        train_loss = 0
        i = 0
        num_batches = len(train_vs) // args.batch_size
        losses = []
        start1 = time.time()
        # Sample one subgraph per batch of seed vertices; the expand factor of
        # g.number_of_nodes() includes every in-neighbor of a seed.
        for subg, seeds in dgl.contrib.sampling.NeighborSampler(g, args.batch_size, g.number_of_nodes(),
                neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
                shuffle=True):
            subg.copy_from_parent()

            if args.gpu > 0:
                # Distribute the subgraphs over the available GPUs round-robin.
                ctx = mx.gpu(i % args.gpu)
                copy_to_gpu(subg, ctx)

            subg_seeds = subg.map_to_subgraph_nid(seeds)
            with mx.autograd.record():
                logits = model_train(subg, subg_seeds.tousertensor())
                batch_labels = mx.nd.array(labels[seeds.asnumpy()], ctx=logits.context)
                loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
            loss.backward()
            losses.append(loss)

            i += 1
            if args.gpu <= 0:
                trainer.step(seeds.shape[0])
                train_loss += loss.asnumpy()[0]
                losses = []
            elif i % args.gpu == 0:
                # With multiple GPUs, step once after a full round of subgraphs.
                trainer.step(len(seeds) * len(losses))
                for loss in losses:
                    train_loss += loss.asnumpy()[0]
                losses = []

            if i % args.num_parallel_subgraphs == 0:
                end1 = time.time()
                print("processing " + str(args.num_parallel_subgraphs)
                        + " subgraphs takes " + str(end1 - start1))
                start1 = end1

            # Only about a third of the batches is processed per epoch before
            # the hidden states are refreshed by the full-graph update below.
            if i > num_batches / 3:
                break

        # Evaluate with full-graph inference on the held-out vertices.
        logits = model_infer(g, mx.nd.array(eval_vs, dtype=np.int64))
        eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
        eval_loss = eval_loss.asnumpy()[0]

        # Update the inference model: pull the freshly trained parameters out
        # of the trainer's kvstore into model_infer.
        infer_params = model_infer.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        # Update node embeddings.
        all_hidden = update_hidden_infer(g, None)
        g.ndata['h'] = all_hidden
        rets.append(all_hidden)

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))

    return rets

class MXNetGraph(object):
    """A simple graph object that uses scipy matrix."""
    def __init__(self, mat):
        self._mat = mat

    def get_graph(self):
        return self._mat

    def number_of_nodes(self):
        return self._mat.shape[0]

    def number_of_edges(self):
        return mx.nd.contrib.getnnz(self._mat).asnumpy()[0]

class GraphData:
    """Synthetic dataset for benchmarking: random features and labels over a
    user-provided CSR graph."""
    def __init__(self, csr, num_feats):
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
        self.graph = MXNetGraph(csr)
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        # Random integer labels in [0, num_labels).
        self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=10, shape=(csr.shape[0],)))
        self.num_labels = 10

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SSE')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=10,
            help="the number of features")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with large learning rate")
    parser.add_argument("--update-dropout", type=float, default=0,
            help="the dropout rate for updating vertex embeddings")
    parser.add_argument("--predict-dropout", type=float, default=0,
            help="the dropout rate for prediction")
    parser.add_argument("--train-percent", type=float, default=0.5,
            help="the percentage of vertices used for training")
    parser.add_argument("--use-spmv", action="store_true",
            help="use SpMV for faster speed.")
    parser.add_argument("--dgl", action="store_true",
            help="use DGL message passing instead of manual propagation.")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    args = parser.parse_args()
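
    # The file passed via --graph-file is expected to contain a saved MXNet
    # sparse CSR adjacency matrix, e.g. created with something like (a sketch):
    #   mx.nd.save('graph.mx', [mx.nd.sparse.csr_matrix((data, indices, indptr), shape=shape)])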

    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        data = GraphData(csr, args.num_feats)
        csr = None
    else:
        data = load_data(args)
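
    # Run training twice on the same data and report how far the two sets of
    # embeddings diverge at each stage.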
    rets1 = main(args, data)
    rets2 = main(args, data)
    for hidden1, hidden2 in zip(rets1, rets2):
        print("hidden: " + str(mx.nd.sum(mx.nd.abs(hidden1 - hidden2)).asnumpy()))