Unverified Commit bd0e4fa0 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[MXNet][API] move to the new API (#123)

* move gat to the new api.

* fix gcn.

* update sse.

* fix dgl core.

* update sse.

* fix small bugs in dgl core.

* fix mxnet tests.

* retrigger

* address comments and fix more bugs.

* fix

* fix tests.
parent 1eb17bb0
...@@ -18,18 +18,18 @@ from dgl.data import register_data_args, load_data ...@@ -18,18 +18,18 @@ from dgl.data import register_data_args, load_data
def elu(data): def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu') return mx.nd.LeakyReLU(data, act_type='elu')
def gat_message(src, edge): def gat_message(edges):
return {'ft' : src['ft'], 'a2' : src['a2']} return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
class GATReduce(gluon.Block): class GATReduce(gluon.Block):
def __init__(self, attn_drop): def __init__(self, attn_drop):
super(GATReduce, self).__init__() super(GATReduce, self).__init__()
self.attn_drop = attn_drop self.attn_drop = attn_drop
def forward(self, node, msgs): def forward(self, nodes):
a1 = mx.nd.expand_dims(node['a1'], 1) # shape (B, 1, 1) a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1) a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D) ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention # attention
a = a1 + a2 # shape (B, deg, 1) a = a1 + a2 # shape (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a)) e = mx.nd.softmax(mx.nd.LeakyReLU(a))
...@@ -48,13 +48,13 @@ class GATFinalize(gluon.Block): ...@@ -48,13 +48,13 @@ class GATFinalize(gluon.Block):
if indim != hiddendim: if indim != hiddendim:
self.residual_fc = gluon.nn.Dense(hiddendim) self.residual_fc = gluon.nn.Dense(hiddendim)
def forward(self, node): def forward(self, nodes):
ret = node['accum'] ret = nodes.data['accum']
if self.residual: if self.residual:
if self.residual_fc is not None: if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret ret = self.residual_fc(nodes.data['h']) + ret
else: else:
ret = node['h'] + ret ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)} return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(gluon.Block): class GATPrepare(gluon.Block):
......
...@@ -14,11 +14,11 @@ import dgl ...@@ -14,11 +14,11 @@ import dgl
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gcn_msg(src, edge): def gcn_msg(edge):
return src return {'m': edge.src['h']}
def gcn_reduce(node, msgs): def gcn_reduce(node):
return mx.nd.sum(msgs, 1) return {'accum': mx.nd.sum(node.mailbox['m'], 1)}
class NodeUpdateModule(gluon.Block): class NodeUpdateModule(gluon.Block):
def __init__(self, out_feats, activation=None): def __init__(self, out_feats, activation=None):
...@@ -26,7 +26,7 @@ class NodeUpdateModule(gluon.Block): ...@@ -26,7 +26,7 @@ class NodeUpdateModule(gluon.Block):
self.linear = gluon.nn.Dense(out_feats, activation=activation) self.linear = gluon.nn.Dense(out_feats, activation=activation)
def forward(self, node): def forward(self, node):
return self.linear(node) return {'h': self.linear(node.data['accum'])}
class GCN(gluon.Block): class GCN(gluon.Block):
def __init__(self, def __init__(self,
...@@ -50,14 +50,14 @@ class GCN(gluon.Block): ...@@ -50,14 +50,14 @@ class GCN(gluon.Block):
self.layers.add(NodeUpdateModule(n_classes)) self.layers.add(NodeUpdateModule(n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr(features) self.g.ndata['h'] = features
for layer in self.layers: for layer in self.layers:
# apply dropout # apply dropout
if self.dropout: if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout) val = F.dropout(self.g.ndata['h'], p=self.dropout)
self.g.set_n_repr(val) self.g.ndata['h'] = val
self.g.update_all(gcn_msg, gcn_reduce, layer) self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr() return self.g.ndata.pop('h')
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
......
...@@ -6,49 +6,52 @@ Paper: http://proceedings.mlr.press/v80/dai18a.html ...@@ -6,49 +6,52 @@ Paper: http://proceedings.mlr.press/v80/dai18a.html
import argparse import argparse
import numpy as np import numpy as np
import time import time
import math
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon
import dgl import dgl
import dgl.function as fn import dgl.function as fn
from dgl import DGLGraph, utils from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gcn_msg(src, edge): def gcn_msg(edges):
# TODO should we use concat? # TODO should we use concat?
return {'m': mx.nd.concat(src['in'], src['h'], dim=1)} return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}
def gcn_reduce(node, msgs): def gcn_reduce(nodes):
return {'accum': mx.nd.sum(msgs['m'], 1)} return {'accum': mx.nd.sum(nodes.mailbox['m'], 1)}
class NodeUpdate(gluon.Block): class NodeUpdate(gluon.Block):
def __init__(self, out_feats, activation=None, alpha=0.9): def __init__(self, out_feats, activation=None, alpha=0.9, **kwargs):
super(NodeUpdate, self).__init__() super(NodeUpdate, self).__init__(**kwargs)
self.linear1 = gluon.nn.Dense(out_feats, activation=activation) self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
# TODO what is the dimension here? # TODO what is the dimension here?
self.linear2 = gluon.nn.Dense(out_feats) self.linear2 = gluon.nn.Dense(out_feats)
self.alpha = alpha self.alpha = alpha
def forward(self, node): def forward(self, nodes):
tmp = mx.nd.concat(node['in'], node['accum'], dim=1) hidden = mx.nd.concat(nodes.data['in'], nodes.data['accum'], dim=1)
hidden = self.linear2(self.linear1(tmp)) hidden = self.linear2(self.linear1(hidden))
return {'h': node['h'] * (1 - self.alpha) + self.alpha * hidden} return {'h': nodes.data['h'] * (1 - self.alpha) + self.alpha * hidden}
class SSEUpdateHidden(gluon.Block): class SSEUpdateHidden(gluon.Block):
def __init__(self, def __init__(self,
n_hidden, n_hidden,
activation, activation,
dropout, dropout,
use_spmv): use_spmv,
super(SSEUpdateHidden, self).__init__() **kwargs):
super(SSEUpdateHidden, self).__init__(**kwargs)
with self.name_scope():
self.layer = NodeUpdate(n_hidden, activation) self.layer = NodeUpdate(n_hidden, activation)
self.dropout = dropout self.dropout = dropout
self.use_spmv = use_spmv self.use_spmv = use_spmv
def forward(self, g, vertices): def forward(self, g, vertices):
if self.use_spmv: if self.use_spmv:
feat = g.get_n_repr()['in'] feat = g.ndata['in']
h = g.get_n_repr()['h'] h = g.ndata['h']
g.set_n_repr({'cat': mx.nd.concat(feat, h, dim=1)}) g.ndata['cat'] = mx.nd.concat(feat, h, dim=1)
msg_func = fn.copy_src(src='cat', out='tmp') msg_func = fn.copy_src(src='cat', out='tmp')
reduce_func = fn.sum(msg='tmp', out='accum') reduce_func = fn.sum(msg='tmp', out='accum')
...@@ -56,22 +59,34 @@ class SSEUpdateHidden(gluon.Block): ...@@ -56,22 +59,34 @@ class SSEUpdateHidden(gluon.Block):
msg_func = gcn_msg msg_func = gcn_msg
reduce_func = gcn_reduce reduce_func = gcn_reduce
if vertices is None: if vertices is None:
g.update_all(msg_func, reduce_func, self.layer) g.update_all(msg_func, reduce_func, None)
ret = g.get_n_repr()['h'] if self.use_spmv:
g.ndata.pop('cat')
batch_size = 100000
num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
for i in range(num_batches):
vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
g.apply_nodes(self.layer, vs, inplace=True)
g.ndata.pop('accum')
ret = g.ndata['h']
else: else:
# We don't need dropout for inference. # We don't need dropout for inference.
if self.dropout: if self.dropout:
# TODO here we apply dropout on all vertex representation. # TODO here we apply dropout on all vertex representation.
val = mx.nd.Dropout(g.get_n_repr()['h'], p=self.dropout) val = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
g.set_n_repr({'h': val}) g.ndata['h'] = val
g.pull(vertices, msg_func, reduce_func, self.layer) g.pull(vertices, msg_func, reduce_func, self.layer)
ctx = g.get_n_repr()['h'].context ctx = g.ndata['h'].context
ret = mx.nd.take(g.get_n_repr()['h'], vertices.tousertensor().as_in_context(ctx)) ret = mx.nd.take(g.ndata['h'], vertices.tousertensor().as_in_context(ctx))
if self.use_spmv:
g.ndata.pop('cat')
g.ndata.pop('accum')
return ret return ret
class SSEPredict(gluon.Block): class SSEPredict(gluon.Block):
def __init__(self, update_hidden, out_feats, dropout): def __init__(self, update_hidden, out_feats, dropout, **kwargs):
super(SSEPredict, self).__init__() super(SSEPredict, self).__init__(**kwargs)
with self.name_scope():
self.linear1 = gluon.nn.Dense(out_feats, activation='relu') self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
self.linear2 = gluon.nn.Dense(out_feats) self.linear2 = gluon.nn.Dense(out_feats)
self.update_hidden = update_hidden self.update_hidden = update_hidden
...@@ -83,10 +98,11 @@ class SSEPredict(gluon.Block): ...@@ -83,10 +98,11 @@ class SSEPredict(gluon.Block):
hidden = mx.nd.Dropout(hidden, p=self.dropout) hidden = mx.nd.Dropout(hidden, p=self.dropout)
return self.linear2(self.linear1(hidden)) return self.linear2(self.linear1(hidden))
def subgraph_gen(g, seed_vertices): def subgraph_gen(g, seed_vertices, ctxs):
assert len(seed_vertices) % len(ctxs) == 0
vertices = [] vertices = []
for seed in seed_vertices: for seed in seed_vertices:
src, _, _ = g.in_edges(seed) src, _ = g.in_edges(seed)
vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0) vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0)
vs = mx.nd.array(np.unique(vs), dtype=np.int64) vs = mx.nd.array(np.unique(vs), dtype=np.int64)
vertices.append(vs) vertices.append(vs)
...@@ -94,11 +110,22 @@ def subgraph_gen(g, seed_vertices): ...@@ -94,11 +110,22 @@ def subgraph_gen(g, seed_vertices):
nids = [] nids = []
for i, subg in enumerate(subgs): for i, subg in enumerate(subgs):
subg.copy_from_parent() subg.copy_from_parent()
nids.append(subg.map_to_subgraph_nid(utils.toindex(seed_vertices[i]))) nids.append(subg.map_to_subgraph_nid(seed_vertices[i]))
return subgs, nids return subgs, nids
def copy_to_gpu(subg, ctx):
frame = subg.ndata
for key in frame:
subg.ndata[key] = frame[key].as_in_context(ctx)
def main(args, data): def main(args, data):
if isinstance(data.features, mx.nd.NDArray):
features = data.features
else:
features = mx.nd.array(data.features) features = mx.nd.array(data.features)
if isinstance(data.labels, mx.nd.NDArray):
labels = data.labels
else:
labels = mx.nd.array(data.labels) labels = mx.nd.array(data.labels)
train_size = len(labels) * args.train_percent train_size = len(labels) * args.train_percent
train_vs = np.arange(train_size, dtype='int64') train_vs = np.arange(train_size, dtype='int64')
...@@ -111,42 +138,45 @@ def main(args, data): ...@@ -111,42 +138,45 @@ def main(args, data):
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
if args.gpu <= 0:
cuda = False
ctx = mx.cpu(0)
else:
cuda = True
features = features.as_in_context(mx.gpu(0))
train_labels = train_labels.as_in_context(mx.gpu(0))
eval_labels = eval_labels.as_in_context(mx.gpu(0))
ctx = mx.gpu(0)
# create the SSE model # create the SSE model
try: try:
graph = data.graph.get_graph() graph = data.graph.get_graph()
except AttributeError: except AttributeError:
graph = data.graph graph = data.graph
g = DGLGraph(graph, readonly=True) g = DGLGraph(graph, readonly=True)
g.set_n_repr({'in': features, 'h': mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden), g.ndata['in'] = features
ctx=ctx)}) g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
ctx=mx.cpu(0))
update_hidden_infer = SSEUpdateHidden(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv, prefix='sse')
update_hidden_infer.initialize(ctx=mx.cpu(0))
update_hidden = SSEUpdateHidden(args.n_hidden, 'relu', args.update_dropout, args.use_spmv) train_ctxs = []
model = SSEPredict(update_hidden, args.n_hidden, args.predict_dropout) update_hidden_train = SSEUpdateHidden(args.n_hidden, 'relu',
model.initialize(ctx=ctx) args.update_dropout, args.use_spmv, prefix='sse')
model = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
if args.gpu <= 0:
model.initialize(ctx=mx.cpu(0))
train_ctxs.append(mx.cpu(0))
else:
for i in range(args.gpu):
train_ctxs.append(mx.gpu(i))
model.initialize(ctx=train_ctxs)
# use optimizer # use optimizer
num_batches = int(g.number_of_nodes() / args.batch_size) num_batches = int(g.number_of_nodes() / args.batch_size)
scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches, scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
args.lr * 10, 0, 0, args.lr/5) args.lr * 10, 0, 0, args.lr/5)
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr, trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr,
'lr_scheduler': scheduler}) 'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))
# compute vertex embedding.
update_hidden_infer(g, None)
# initialize graph # initialize graph
dur = [] dur = []
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
# compute vertex embedding.
update_hidden(g, None)
t0 = time.time() t0 = time.time()
permute = np.random.permutation(len(train_vs)) permute = np.random.permutation(len(train_vs))
randv = train_vs[permute] randv = train_vs[permute]
...@@ -162,14 +192,27 @@ def main(args, data): ...@@ -162,14 +192,27 @@ def main(args, data):
if len(data) < args.num_parallel_subgraphs: if len(data) < args.num_parallel_subgraphs:
continue continue
subgs, seed_ids = subgraph_gen(g, data) subgs, seed_ids = subgraph_gen(g, data, train_ctxs)
losses = []
i = 0
for subg, seed_id, label, d in zip(subgs, seed_ids, labels, data): for subg, seed_id, label, d in zip(subgs, seed_ids, labels, data):
if args.gpu > 0:
ctx = mx.gpu(i % args.gpu)
copy_to_gpu(subg, ctx)
with mx.autograd.record(): with mx.autograd.record():
logits = model(subg, seed_id) logits = model(subg, seed_id)
if label.context != logits.context:
label = label.as_in_context(logits.context)
loss = mx.nd.softmax_cross_entropy(logits, label) loss = mx.nd.softmax_cross_entropy(logits, label)
loss.backward() loss.backward()
trainer.step(d.shape[0]) losses.append(loss)
i = i + 1
if i % args.gpu == 0:
trainer.step(d.shape[0] * len(subgs))
for loss in losses:
train_loss += loss.asnumpy()[0] train_loss += loss.asnumpy()[0]
losses = []
data = [] data = []
labels = [] labels = []
...@@ -178,13 +221,48 @@ def main(args, data): ...@@ -178,13 +221,48 @@ def main(args, data):
#eval_loss = eval_loss.asnumpy()[0] #eval_loss = eval_loss.asnumpy()[0]
eval_loss = 0 eval_loss = 0
# compute vertex embedding.
infer_params = update_hidden_infer.collect_params()
for key in infer_params:
idx = trainer._param2idx[key]
trainer._kvstore.pull(idx, out=infer_params[key].data())
update_hidden_infer(g, None)
dur.append(time.time() - t0) dur.append(time.time() - t0)
print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format( print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000)) epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))
class MXNetGraph(object):
"""A simple graph object that uses scipy matrix."""
def __init__(self, mat):
self._mat = mat
def get_graph(self):
return self._mat
def number_of_nodes(self):
return self._mat.shape[0]
def number_of_edges(self):
return mx.nd.contrib.getnnz(self._mat)
class GraphData:
def __init__(self, csr, num_feats):
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
self.graph = MXNetGraph(csr)
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.labels = mx.nd.floor(mx.nd.random.normal(loc=0, scale=10, shape=(csr.shape[0])))
self.num_labels = 10
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN') parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser) register_data_args(parser)
parser.add_argument("--graph-file", type=str, default="",
help="graph file")
parser.add_argument("--num-feats", type=int, default=10,
help="the number of features")
parser.add_argument("--gpu", type=int, default=-1, parser.add_argument("--gpu", type=int, default=-1,
help="gpu") help="gpu")
parser.add_argument("--lr", type=float, default=1e-3, parser.add_argument("--lr", type=float, default=1e-3,
...@@ -210,5 +288,10 @@ if __name__ == '__main__': ...@@ -210,5 +288,10 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
# load and preprocess dataset # load and preprocess dataset
if args.graph_file != '':
csr = mx.nd.load(args.graph_file)[0]
data = GraphData(csr, args.num_feats)
csr = None
else:
data = load_data(args) data = load_data(args)
main(args, data) main(args, data)
...@@ -81,9 +81,13 @@ class ImmutableGraphIndex(object): ...@@ -81,9 +81,13 @@ class ImmutableGraphIndex(object):
NDArray NDArray
                Teh edge id array. The edge id array.
""" """
if len(u) == 0 or len(v) == 0:
return [], [], []
ids = mx.nd.contrib.edge_id(self._in_csr, v, u) ids = mx.nd.contrib.edge_id(self._in_csr, v, u)
ids = ids.asnumpy() ids = ids.asnumpy()
return ids[ids >= 0] v = v.asnumpy()
u = u.asnumpy()
return u[ids >= 0], v[ids >= 0], ids[ids >= 0]
def predecessors(self, v, radius=1): def predecessors(self, v, radius=1):
"""Return the predecessors of the node. """Return the predecessors of the node.
......
...@@ -27,11 +27,11 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -27,11 +27,11 @@ def sparse_matrix(data, index, shape, force_format=False):
raise TypeError('MXNet backend only supports CSR format,' raise TypeError('MXNet backend only supports CSR format,'
' but COO format is forced.') ' but COO format is forced.')
coord = index[1] coord = index[1]
return nd.sparse.csr_matrix((data, (coord[0], coord[1])), shape) return nd.sparse.csr_matrix((data, (coord[0], coord[1])), tuple(shape))
elif fmt == 'csr': elif fmt == 'csr':
indices = index[1] indices = index[1]
indptr = index[2] indptr = index[2]
return nd.sparse.csr_matrix((data, indices, indptr), shape) return nd.sparse.csr_matrix((data, indices, indptr), tuple(shape))
else: else:
raise TypeError('Invalid format: %s.' % fmt) raise TypeError('Invalid format: %s.' % fmt)
...@@ -65,7 +65,7 @@ def sum(input, dim): ...@@ -65,7 +65,7 @@ def sum(input, dim):
return nd.sum(input, axis=dim) return nd.sum(input, axis=dim)
def max(input, dim): def max(input, dim):
return nd.max(input, axis=dim) return nd.max(input, axis=dim).asnumpy()[0]
def cat(seq, dim): def cat(seq, dim):
return nd.concat(*seq, dim=dim) return nd.concat(*seq, dim=dim)
...@@ -131,7 +131,7 @@ def nonzero_1d(input): ...@@ -131,7 +131,7 @@ def nonzero_1d(input):
def sort_1d(input): def sort_1d(input):
# TODO: this isn't an ideal implementation. # TODO: this isn't an ideal implementation.
val = nd.sort(input, is_ascend=True) val = nd.sort(input, axis=None, is_ascend=True)
idx = nd.argsort(input, is_ascend=True) idx = nd.argsort(input, is_ascend=True)
idx = nd.cast(idx, dtype='int64') idx = nd.cast(idx, dtype='int64')
return val, idx return val, idx
......
...@@ -853,7 +853,7 @@ class DGLGraph(object): ...@@ -853,7 +853,7 @@ class DGLGraph(object):
""" """
self._apply_edge_func = func self._apply_edge_func = func
def apply_nodes(self, func="default", v=ALL): def apply_nodes(self, func="default", v=ALL, inplace=False):
"""Apply the function on the node features. """Apply the function on the node features.
Applying a None function will be ignored. Applying a None function will be ignored.
...@@ -865,7 +865,7 @@ class DGLGraph(object): ...@@ -865,7 +865,7 @@ class DGLGraph(object):
v : int, iterable of int, tensor, optional v : int, iterable of int, tensor, optional
The node id(s). The node id(s).
""" """
self._internal_apply_nodes(v, func) self._internal_apply_nodes(v, func, inplace=inplace)
def apply_edges(self, func="default", edges=ALL): def apply_edges(self, func="default", edges=ALL):
"""Apply the function on the edge features. """Apply the function on the edge features.
...@@ -1464,7 +1464,8 @@ class DGLGraph(object): ...@@ -1464,7 +1464,8 @@ class DGLGraph(object):
edges = F.tensor(edges) edges = F.tensor(edges)
return edges[e_mask] return edges[e_mask]
def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None): def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None,
inplace=False):
"""Internal apply nodes """Internal apply nodes
Parameters Parameters
...@@ -1478,7 +1479,7 @@ class DGLGraph(object): ...@@ -1478,7 +1479,7 @@ class DGLGraph(object):
# Skip none function call. # Skip none function call.
if reduce_accum is not None: if reduce_accum is not None:
# write reduce result back # write reduce result back
self.set_n_repr(reduce_accum, v) self.set_n_repr(reduce_accum, v, inplace=inplace)
return return
# take out current node repr # take out current node repr
curr_repr = self.get_n_repr(v) curr_repr = self.get_n_repr(v)
...@@ -1491,4 +1492,4 @@ class DGLGraph(object): ...@@ -1491,4 +1492,4 @@ class DGLGraph(object):
# merge new node_repr with reduce output # merge new node_repr with reduce output
reduce_accum.update(new_repr) reduce_accum.update(new_repr)
new_repr = reduce_accum new_repr = reduce_accum
self.set_n_repr(new_repr, v) self.set_n_repr(new_repr, v, inplace=inplace)
...@@ -62,6 +62,17 @@ class ImmutableGraphIndex(object): ...@@ -62,6 +62,17 @@ class ImmutableGraphIndex(object):
"""Clear the graph.""" """Clear the graph."""
raise Exception('Immutable graph doesn\'t support clearing up') raise Exception('Immutable graph doesn\'t support clearing up')
def is_multigraph(self):
"""Return whether the graph is a multigraph
Returns
-------
bool
True if it is a multigraph, False otherwise.
"""
# Immutable graph doesn't support multi-edge.
return False
def number_of_nodes(self): def number_of_nodes(self):
"""Return the number of nodes. """Return the number of nodes.
...@@ -207,7 +218,7 @@ class ImmutableGraphIndex(object): ...@@ -207,7 +218,7 @@ class ImmutableGraphIndex(object):
""" """
u = F.tensor([u], dtype=F.int64) u = F.tensor([u], dtype=F.int64)
v = F.tensor([v], dtype=F.int64) v = F.tensor([v], dtype=F.int64)
id = self._sparse.edge_ids(u, v) _, _, id = self._sparse.edge_ids(u, v)
return utils.toindex(id) return utils.toindex(id)
def edge_ids(self, u, v): def edge_ids(self, u, v):
...@@ -223,12 +234,16 @@ class ImmutableGraphIndex(object): ...@@ -223,12 +234,16 @@ class ImmutableGraphIndex(object):
Returns Returns
------- -------
utils.Index utils.Index
The edge id array. The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
""" """
u = u.tousertensor() u = u.tousertensor()
v = v.tousertensor() v = v.tousertensor()
ids = self._sparse.edge_ids(u, v) u, v, ids = self._sparse.edge_ids(u, v)
return utils.toindex(ids) return utils.toindex(u), utils.toindex(v), utils.toindex(ids)
def in_edges(self, v): def in_edges(self, v):
"""Return the in edges of the node(s). """Return the in edges of the node(s).
......
...@@ -119,4 +119,4 @@ class DGLSubGraph(DGLGraph): ...@@ -119,4 +119,4 @@ class DGLSubGraph(DGLGraph):
self._parent._edge_frame[self._parent_eid])) self._parent._edge_frame[self._parent_eid]))
def map_to_subgraph_nid(self, parent_vids): def map_to_subgraph_nid(self, parent_vids):
return map_to_subgraph_nid(self._graph, parent_vids) return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids))
...@@ -3,6 +3,7 @@ os.environ['DGLBACKEND'] = 'mxnet' ...@@ -3,6 +3,7 @@ os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx import mxnet as mx
import numpy as np import numpy as np
from dgl.graph import DGLGraph from dgl.graph import DGLGraph
import scipy.sparse as spsp
D = 5 D = 5
reduce_msg_shapes = set() reduce_msg_shapes = set()
...@@ -26,7 +27,26 @@ def reduce_func(nodes): ...@@ -26,7 +27,26 @@ def reduce_func(nodes):
def apply_node_func(nodes): def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['m']} return {'h' : nodes.data['h'] + nodes.data['m']}
def generate_graph(grad=False): def generate_graph(grad=False, readonly=False):
if readonly:
row_idx = []
col_idx = []
for i in range(1, 9):
row_idx.append(0)
col_idx.append(i)
row_idx.append(i)
col_idx.append(9)
row_idx.append(9)
col_idx.append(0)
ones = np.ones(shape=(len(row_idx)))
csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(10, 10))
g = DGLGraph(csr, readonly=True)
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.ndata['h'] = ncol
return g
else:
g = DGLGraph() g = DGLGraph()
g.add_nodes(10) # 10 nodes. g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
...@@ -121,7 +141,7 @@ def test_batch_setter_getter(): ...@@ -121,7 +141,7 @@ def test_batch_setter_getter():
def test_batch_setter_autograd(): def test_batch_setter_autograd():
with mx.autograd.record(): with mx.autograd.record():
g = generate_graph(grad=True) g = generate_graph(grad=True, readonly=True)
h1 = g.ndata['h'] h1 = g.ndata['h']
h1.attach_grad() h1.attach_grad()
# partial set # partial set
...@@ -153,9 +173,9 @@ def test_batch_send(): ...@@ -153,9 +173,9 @@ def test_batch_send():
v = mx.nd.array([9], dtype='int64') v = mx.nd.array([9], dtype='int64')
g.send((u, v)) g.send((u, v))
def test_batch_recv(): def check_batch_recv(readonly):
# basic recv test # basic recv test
g = generate_graph() g = generate_graph(readonly=readonly)
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_reduce_func(reduce_func) g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func) g.register_apply_node_func(apply_node_func)
...@@ -167,8 +187,12 @@ def test_batch_recv(): ...@@ -167,8 +187,12 @@ def test_batch_recv():
#assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) #assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
#reduce_msg_shapes.clear() #reduce_msg_shapes.clear()
def test_update_routines(): def test_batch_recv():
g = generate_graph() check_batch_recv(True)
check_batch_recv(False)
def check_update_routines(readonly):
g = generate_graph(readonly=readonly)
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_reduce_func(reduce_func) g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func) g.register_apply_node_func(apply_node_func)
...@@ -201,7 +225,21 @@ def test_update_routines(): ...@@ -201,7 +225,21 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)}) assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
def test_reduce_0deg(): def test_update_routines():
check_update_routines(True)
check_update_routines(False)
def check_reduce_0deg(readonly):
if readonly:
row_idx = []
col_idx = []
for i in range(1, 5):
row_idx.append(i)
col_idx.append(0)
ones = np.ones(shape=(len(row_idx)))
csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(5, 5))
g = DGLGraph(csr, readonly=True)
else:
g = DGLGraph() g = DGLGraph()
g.add_nodes(5) g.add_nodes(5)
g.add_edge(1, 0) g.add_edge(1, 0)
...@@ -220,7 +258,20 @@ def test_reduce_0deg(): ...@@ -220,7 +258,20 @@ def test_reduce_0deg():
assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy()) assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
def test_pull_0deg(): def test_reduce_0deg():
check_reduce_0deg(True)
check_reduce_0deg(False)
def check_pull_0deg(readonly):
if readonly:
row_idx = []
col_idx = []
row_idx.append(0)
col_idx.append(1)
ones = np.ones(shape=(len(row_idx)))
csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2))
g = DGLGraph(csr, readonly=True)
else:
g = DGLGraph() g = DGLGraph()
g.add_nodes(2) g.add_nodes(2)
g.add_edge(0, 1) g.add_edge(0, 1)
...@@ -246,6 +297,10 @@ def test_pull_0deg(): ...@@ -246,6 +297,10 @@ def test_pull_0deg():
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
def test_pull_0deg():
check_pull_0deg(True)
check_pull_0deg(False)
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
test_batch_setter_autograd() test_batch_setter_autograd()
......
...@@ -67,7 +67,7 @@ def check_basics(g, ig): ...@@ -67,7 +67,7 @@ def check_basics(g, ig):
assert g.has_edge_between(u, v) == ig.has_edge_between(u, v) assert g.has_edge_between(u, v) == ig.has_edge_between(u, v)
randv = utils.toindex(randv) randv = utils.toindex(randv)
ids = g.edge_ids(randv, randv)[2].tolist() ids = g.edge_ids(randv, randv)[2].tolist()
assert sum(ig.edge_ids(randv, randv).tolist() == ids) == len(ids) assert sum(ig.edge_ids(randv, randv)[2].tolist() == ids) == len(ids)
assert sum(g.has_edges_between(randv, randv).tolist() == ig.has_edges_between(randv, randv).tolist()) == len(randv) assert sum(g.has_edges_between(randv, randv).tolist() == ig.has_edges_between(randv, randv).tolist()) == len(randv)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment