"""
Learning Steady-States of Iterative Algorithms over Graphs
Paper: http://proceedings.mlr.press/v80/dai18a.html

"""
import argparse
import numpy as np
import time
import math
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(edges):
    # TODO should we use concat?
    return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}

def gcn_reduce(nodes):
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1)}

class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, alpha=0.9, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
        # TODO what is the dimension here?
        self.linear2 = gluon.nn.Dense(out_feats)
        self.alpha = alpha

    def forward(self, nodes):
        hidden = mx.nd.concat(nodes.data['in'], nodes.data['accum'], dim=1)
        hidden = self.linear2(self.linear1(hidden))
        return {'h': nodes.data['h'] * (1 - self.alpha) + self.alpha * hidden}

class SSEUpdateHidden(gluon.Block):
    def __init__(self,
                 n_hidden,
                 activation,
                 dropout,
                 use_spmv,
                 **kwargs):
        super(SSEUpdateHidden, self).__init__(**kwargs)
        with self.name_scope():
            self.layer = NodeUpdate(n_hidden, activation)
        self.dropout = dropout
        self.use_spmv = use_spmv

    def forward(self, g, vertices):
        if self.use_spmv:
            feat = g.ndata['in']
            h = g.ndata['h']
            g.ndata['cat'] = mx.nd.concat(feat, h, dim=1)

            msg_func = fn.copy_src(src='cat', out='tmp')
            reduce_func = fn.sum(msg='tmp', out='accum')
        else:
            msg_func = gcn_msg
            reduce_func = gcn_reduce
        if vertices is None:
            g.update_all(msg_func, reduce_func, None)
            if self.use_spmv:
                g.ndata.pop('cat')
            # apply the node UDF in chunks to bound peak memory usage
            batch_size = 100000
            num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
            for i in range(num_batches):
                vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
                g.apply_nodes(self.layer, vs, inplace=True)
            g.ndata.pop('accum')
            ret = g.ndata['h']
        else:
            # We don't need dropout for inference.
            if self.dropout:
                # TODO here we apply dropout on all vertex representation.
                val = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
                g.ndata['h'] = val
            g.pull(vertices, msg_func, reduce_func, self.layer)
            ctx = g.ndata['h'].context
            ret = mx.nd.take(g.ndata['h'], vertices.tousertensor().as_in_context(ctx))
            if self.use_spmv:
                g.ndata.pop('cat')
            g.ndata.pop('accum')
        return ret

class SSEPredict(gluon.Block):
    def __init__(self, update_hidden, out_feats, dropout, **kwargs):
        super(SSEPredict, self).__init__(**kwargs)
        with self.name_scope():
            self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
            self.linear2 = gluon.nn.Dense(out_feats)
        self.update_hidden = update_hidden
        self.dropout = dropout

    def forward(self, g, vertices):
        hidden = self.update_hidden(g, vertices)
        if self.dropout:
            hidden = mx.nd.Dropout(hidden, p=self.dropout)
        return self.linear2(self.linear1(hidden))

def copy_to_gpu(subg, ctx):
    frame = subg.ndata
    for key in frame:
        subg.ndata[key] = frame[key].as_in_context(ctx)

def main(args, data):
    if isinstance(data.features, mx.nd.NDArray):
        features = data.features
    else:
        features = mx.nd.array(data.features)
    if isinstance(data.labels, mx.nd.NDArray):
        labels = data.labels
    else:
        labels = mx.nd.array(data.labels)
    train_size = int(len(labels) * args.train_percent)
    train_vs = np.arange(train_size, dtype='int64')
    eval_vs = np.arange(train_size, len(labels), dtype='int64')
    print("train size: " + str(len(train_vs)))
    print("eval size: " + str(len(eval_vs)))
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    # create the SSE model
    try:
        graph = data.graph.get_graph()
    except AttributeError:
        graph = data.graph
    g = DGLGraph(graph, readonly=True)
    g.ndata['in'] = features
    g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
            ctx=mx.cpu(0))

    update_hidden_infer = SSEUpdateHidden(args.n_hidden, 'relu',
            args.update_dropout, args.use_spmv, prefix='sse')
    update_hidden_infer.initialize(ctx=mx.cpu(0))

    train_ctxs = []
    update_hidden_train = SSEUpdateHidden(args.n_hidden, 'relu',
            args.update_dropout, args.use_spmv, prefix='sse')
    model = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
    if args.gpu <= 0:
        model.initialize(ctx=mx.cpu(0))
        train_ctxs.append(mx.cpu(0))
    else:
        for i in range(args.gpu):
            train_ctxs.append(mx.gpu(i))
        model.initialize(ctx=train_ctxs)

    # optimizer with a cosine learning-rate schedule
    num_batches = int(g.number_of_nodes() / args.batch_size)
    scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
            args.lr * 10, 0, 0, args.lr/5)
    trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr,
        'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))

    # compute vertex embedding.
    update_hidden_infer(g, None)

    # training loop
    dur = []
    for epoch in range(args.n_epochs):
        t0 = time.time()
        train_loss = 0
        i = 0
        for subg, seeds in dgl.sampling.NeighborSampler(g, args.batch_size, g.number_of_nodes(),
                neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
                shuffle=True):
            subg.copy_from_parent()

            losses = []
            if args.gpu > 0:
                ctx = mx.gpu(i % args.gpu)
                copy_to_gpu(subg, ctx)

            subg_seeds = subg.map_to_subgraph_nid(seeds)
            with mx.autograd.record():
                logits = model(subg, subg_seeds)
                batch_labels = mx.nd.array(labels[seeds.asnumpy()], ctx=logits.context)
                loss = mx.nd.softmax_cross_entropy(logits, batch_labels)
            loss.backward()
            losses.append(loss)
            i = i + 1
            # step once per round-robin pass over the devices
            # (every iteration when running on CPU, where args.gpu <= 0)
            if i % max(args.gpu, 1) == 0:
                trainer.step(len(seeds) * len(losses))
                for loss in losses:
                    train_loss += loss.asnumpy()[0]
                losses = []

        #logits = model(eval_vs)
        #eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
        #eval_loss = eval_loss.asnumpy()[0]
        eval_loss = 0

        # Pull the freshly trained parameters from the kvstore into the
        # inference model (which shares the 'sse' parameter prefix with the
        # training model), then recompute all vertex embeddings.
        infer_params = update_hidden_infer.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())
        update_hidden_infer(g, None)

        dur.append(time.time() - t0)
        print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))

class MXNetGraph(object):
    """A simple graph object that uses scipy matrix."""
    def __init__(self, mat):
        self._mat = mat

    def get_graph(self):
        return self._mat

    def number_of_nodes(self):
        return self._mat.shape[0]

    def number_of_edges(self):
        return int(mx.nd.contrib.getnnz(self._mat).asnumpy()[0])

class GraphData:
    """A synthetic dataset that exposes the same interface as the datasets
    returned by dgl.data.load_data: graph, features, labels and num_labels."""
    def __init__(self, csr, num_feats):
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
        csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
        self.graph = MXNetGraph(csr)
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        # random integer labels in [0, num_labels)
        self.labels = mx.nd.floor(mx.nd.random.uniform(0, 10, shape=(csr.shape[0],)))
        self.num_labels = 10

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=10,
            help="the number of features")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-3,
            help="learning rate")
    parser.add_argument("--batch-size", type=int, default=128,
            help="number of vertices in a batch")
    parser.add_argument("--n-epochs", type=int, default=20,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--warmup", type=int, default=10,
            help="number of iterations to warm up with large learning rate")
    parser.add_argument("--update-dropout", type=float, default=0.5,
            help="the dropout rate for updating vertex embedding")
    parser.add_argument("--predict-dropout", type=float, default=0.5,
            help="the dropout rate for prediction")
    parser.add_argument("--train_percent", type=float, default=0.5,
            help="the percentage of data used for training")
    parser.add_argument("--use-spmv", type=bool, default=False,
            help="use SpMV for faster speed.")
    parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
            help="the number of subgraphs to construct in parallel.")
    args = parser.parse_args()

    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        data = GraphData(csr, args.num_feats)
        csr = None
    else:
        data = load_data(args)
    main(args, data)
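
# Example invocation (a sketch; the dataset flag below comes from
# dgl.data.register_data_args and its exact name may vary across DGL versions):
#   python sse_batch.py --dataset cora --n-epochs 20 --batch-size 128 --use-spmv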