Unverified commit bd0e4fa0, authored by Da Zheng, committed by GitHub

[MXNet][API] move to the new API (#123)

* move gat to the new api.

* fix gcn.

* update sse.

* fix dgl core.

* update sse.

* fix small bugs in dgl core.

* fix mxnet tests.

* retrigger

* address comments and fix more bugs.

* fix

* fix tests.
parent 1eb17bb0
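Note for readers migrating their own models: the heart of this change is the new UDF calling convention. Message functions now take a single batched `edges` argument (with `edges.src`, `edges.dst`, `edges.data`), reduce functions take a batched `nodes` argument (with `nodes.data` and `nodes.mailbox`), and per-node features move from `set_n_repr`/`get_n_repr` to the dict-like `g.ndata`. A minimal sketch on a toy graph (the graph and feature names here are illustrative, not from this diff):

```python
import mxnet as mx
import dgl

# Old API (before this commit):
#   def message(src, edge):  return {'m': src['h']}
#   def reduce(node, msgs):  return {'h': mx.nd.sum(msgs['m'], 1)}
# New API: UDFs receive batched edge/node views instead.

def message(edges):
    # one row per edge; edges.src holds the source nodes' features
    return {'m': edges.src['h']}

def reduce(nodes):
    # nodes.mailbox['m'] has shape (num_nodes, deg, D) per degree bucket
    return {'h': mx.nd.sum(nodes.mailbox['m'], 1)}

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [2, 2])          # two edges into node 2
g.ndata['h'] = mx.nd.ones((3, 4))    # g.ndata replaces set_n_repr/get_n_repr
g.update_all(message, reduce)
print(g.ndata['h'][2])               # node 2 now holds the sum of its neighbors
```

Zero in-degree nodes keep their old representation, which is what the new `test_reduce_0deg` check below exercises.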
......@@ -18,18 +18,18 @@ from dgl.data import register_data_args, load_data
def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu')
def gat_message(src, edge):
return {'ft' : src['ft'], 'a2' : src['a2']}
def gat_message(edges):
return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
class GATReduce(gluon.Block):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop
def forward(self, node, msgs):
a1 = mx.nd.expand_dims(node['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D)
def forward(self, nodes):
a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a))
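The attention arithmetic above relies on broadcasting: `a1` is (B, 1, 1) and `a2` is (B, deg, 1), so `a1 + a2` expands to (B, deg, 1) and the softmax normalizes over the neighbor (deg) axis. A standalone shape check, with random placeholder tensors and an explicit `axis=1`:

```python
import mxnet as mx

B, deg, D = 4, 3, 8                           # nodes in bucket, in-degree, feat dim
a1 = mx.nd.random.normal(shape=(B, 1, 1))     # destination term, from nodes.data['a1']
a2 = mx.nd.random.normal(shape=(B, deg, 1))   # one source term per neighbor
ft = mx.nd.random.normal(shape=(B, deg, D))   # neighbor features

a = a1 + a2                                   # broadcasts to (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a), axis=1) # normalize over the neighbor axis
accum = mx.nd.sum(e * ft, axis=1)             # attention-weighted sum -> (B, D)
assert accum.shape == (B, D)
```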
......@@ -48,13 +48,13 @@ class GATFinalize(gluon.Block):
if indim != hiddendim:
self.residual_fc = gluon.nn.Dense(hiddendim)
def forward(self, node):
ret = node['accum']
def forward(self, nodes):
ret = nodes.data['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
ret = self.residual_fc(nodes.data['h']) + ret
else:
ret = node['h'] + ret
ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(gluon.Block):
......
......@@ -14,11 +14,11 @@ import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
return src
def gcn_msg(edge):
return {'m': edge.src['h']}
def gcn_reduce(node, msgs):
return mx.nd.sum(msgs, 1)
def gcn_reduce(node):
return {'accum': mx.nd.sum(node.mailbox['m'], 1)}
class NodeUpdateModule(gluon.Block):
def __init__(self, out_feats, activation=None):
......@@ -26,7 +26,7 @@ class NodeUpdateModule(gluon.Block):
self.linear = gluon.nn.Dense(out_feats, activation=activation)
def forward(self, node):
return self.linear(node)
return {'h': self.linear(node.data['accum'])}
class GCN(gluon.Block):
def __init__(self,
......@@ -50,14 +50,14 @@ class GCN(gluon.Block):
self.layers.add(NodeUpdateModule(n_classes))
def forward(self, features):
self.g.set_n_repr(features)
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
val = F.dropout(self.g.ndata['h'], p=self.dropout)
self.g.ndata['h'] = val
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr()
return self.g.ndata.pop('h')
def main(args):
# load and preprocess dataset
......
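With the new API, the whole GCN forward pass is a sequence of `ndata` writes and `update_all` calls. A minimal end-to-end sketch of the same pattern (the toy cycle graph and layer sizes are illustrative):

```python
import mxnet as mx
from mxnet import gluon
import dgl

def gcn_msg(edges):
    return {'m': edges.src['h']}

def gcn_reduce(nodes):
    return {'accum': mx.nd.sum(nodes.mailbox['m'], 1)}

class NodeUpdate(gluon.Block):
    def __init__(self, out_feats):
        super(NodeUpdate, self).__init__()
        self.linear = gluon.nn.Dense(out_feats, activation='relu')
    def forward(self, nodes):
        return {'h': self.linear(nodes.data['accum'])}

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])   # a cycle: every node has an in-edge
layer = NodeUpdate(8)
layer.initialize()

g.ndata['h'] = mx.nd.random.normal(shape=(4, 16))
g.update_all(gcn_msg, gcn_reduce, layer)  # message -> reduce -> node update
out = g.ndata.pop('h')                    # ndata.pop replaces pop_n_repr
print(out.shape)                          # (4, 8)
```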
......@@ -6,49 +6,52 @@ Paper: http://proceedings.mlr.press/v80/dai18a.html
import argparse
import numpy as np
import time
import math
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph, utils
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
def gcn_msg(edges):
# TODO should we use concat?
return {'m': mx.nd.concat(src['in'], src['h'], dim=1)}
return {'m': mx.nd.concat(edges.src['in'], edges.src['h'], dim=1)}
def gcn_reduce(node, msgs):
return {'accum': mx.nd.sum(msgs['m'], 1)}
def gcn_reduce(nodes):
return {'accum': mx.nd.sum(nodes.mailbox['m'], 1)}
class NodeUpdate(gluon.Block):
def __init__(self, out_feats, activation=None, alpha=0.9):
super(NodeUpdate, self).__init__()
def __init__(self, out_feats, activation=None, alpha=0.9, **kwargs):
super(NodeUpdate, self).__init__(**kwargs)
self.linear1 = gluon.nn.Dense(out_feats, activation=activation)
# TODO what is the dimension here?
self.linear2 = gluon.nn.Dense(out_feats)
self.alpha = alpha
def forward(self, node):
tmp = mx.nd.concat(node['in'], node['accum'], dim=1)
hidden = self.linear2(self.linear1(tmp))
return {'h': node['h'] * (1 - self.alpha) + self.alpha * hidden}
def forward(self, nodes):
hidden = mx.nd.concat(nodes.data['in'], nodes.data['accum'], dim=1)
hidden = self.linear2(self.linear1(hidden))
return {'h': nodes.data['h'] * (1 - self.alpha) + self.alpha * hidden}
class SSEUpdateHidden(gluon.Block):
def __init__(self,
n_hidden,
activation,
dropout,
use_spmv):
super(SSEUpdateHidden, self).__init__()
self.layer = NodeUpdate(n_hidden, activation)
use_spmv,
**kwargs):
super(SSEUpdateHidden, self).__init__(**kwargs)
with self.name_scope():
self.layer = NodeUpdate(n_hidden, activation)
self.dropout = dropout
self.use_spmv = use_spmv
def forward(self, g, vertices):
if self.use_spmv:
feat = g.get_n_repr()['in']
h = g.get_n_repr()['h']
g.set_n_repr({'cat': mx.nd.concat(feat, h, dim=1)})
feat = g.ndata['in']
h = g.ndata['h']
g.ndata['cat'] = mx.nd.concat(feat, h, dim=1)
msg_func = fn.copy_src(src='cat', out='tmp')
reduce_func = fn.sum(msg='tmp', out='accum')
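The `use_spmv` branch swaps the handwritten UDFs for DGL's builtin `fn.copy_src`/`fn.sum` pair, which the runtime can lower to a sparse matrix-vector multiply instead of per-edge Python calls. A standalone sketch (toy graph; the 'cat'/'tmp'/'accum' field names follow the code above):

```python
import mxnet as mx
import dgl
import dgl.function as fn

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
g.ndata['cat'] = mx.nd.ones((3, 4))

# Builtins state the intent declaratively, so DGL can fuse this
# copy-then-sum into a single sparse matmul rather than calling
# Python UDFs per edge.
g.update_all(fn.copy_src(src='cat', out='tmp'),
             fn.sum(msg='tmp', out='accum'))
print(g.ndata['accum'])
```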
......@@ -56,24 +59,36 @@ class SSEUpdateHidden(gluon.Block):
msg_func = gcn_msg
reduce_func = gcn_reduce
if vertices is None:
g.update_all(msg_func, reduce_func, self.layer)
ret = g.get_n_repr()['h']
g.update_all(msg_func, reduce_func, None)
if self.use_spmv:
g.ndata.pop('cat')
batch_size = 100000
num_batches = int(math.ceil(g.number_of_nodes() / batch_size))
for i in range(num_batches):
vs = mx.nd.arange(i * batch_size, min((i + 1) * batch_size, g.number_of_nodes()), dtype=np.int64)
g.apply_nodes(self.layer, vs, inplace=True)
g.ndata.pop('accum')
ret = g.ndata['h']
else:
# We don't need dropout for inference.
if self.dropout:
# TODO here we apply dropout on all vertex representation.
val = mx.nd.Dropout(g.get_n_repr()['h'], p=self.dropout)
g.set_n_repr({'h': val})
val = mx.nd.Dropout(g.ndata['h'], p=self.dropout)
g.ndata['h'] = val
g.pull(vertices, msg_func, reduce_func, self.layer)
ctx = g.get_n_repr()['h'].context
ret = mx.nd.take(g.get_n_repr()['h'], vertices.tousertensor().as_in_context(ctx))
ctx = g.ndata['h'].context
ret = mx.nd.take(g.ndata['h'], vertices.tousertensor().as_in_context(ctx))
if self.use_spmv:
g.ndata.pop('cat')
g.ndata.pop('accum')
return ret
class SSEPredict(gluon.Block):
def __init__(self, update_hidden, out_feats, dropout):
super(SSEPredict, self).__init__()
self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
self.linear2 = gluon.nn.Dense(out_feats)
def __init__(self, update_hidden, out_feats, dropout, **kwargs):
super(SSEPredict, self).__init__(**kwargs)
with self.name_scope():
self.linear1 = gluon.nn.Dense(out_feats, activation='relu')
self.linear2 = gluon.nn.Dense(out_feats)
self.update_hidden = update_hidden
self.dropout = dropout
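In the minibatch branch above, `g.pull(vertices, ...)` runs the message/reduce/apply pipeline only for the requested destination nodes, which is what keeps per-batch updates cheap. A toy sketch under the same UDF conventions:

```python
import numpy as np
import mxnet as mx
import dgl

def _message(edges):
    return {'m': edges.src['h']}

def _reduce(nodes):
    return {'h': mx.nd.sum(nodes.mailbox['m'], 1)}

g = dgl.DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
g.ndata['h'] = mx.nd.ones((4, 2))

v = mx.nd.array([1, 3], dtype=np.int64)
g.pull(v, _message, _reduce)    # only nodes 1 and 3 are recomputed
print(g.ndata['h'])
```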
......@@ -83,10 +98,11 @@ class SSEPredict(gluon.Block):
hidden = mx.nd.Dropout(hidden, p=self.dropout)
return self.linear2(self.linear1(hidden))
def subgraph_gen(g, seed_vertices):
def subgraph_gen(g, seed_vertices, ctxs):
assert len(seed_vertices) % len(ctxs) == 0
vertices = []
for seed in seed_vertices:
src, _, _ = g.in_edges(seed)
src, _ = g.in_edges(seed)
vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0)
vs = mx.nd.array(np.unique(vs), dtype=np.int64)
vertices.append(vs)
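Note that `in_edges` now returns a (src, dst) pair rather than a (src, dst, eid) triple, hence the two-element unpack above. A sketch of expanding one seed into its in-neighborhood, as `subgraph_gen` does (toy graph; the seed id is illustrative):

```python
import numpy as np
import mxnet as mx
import dgl

g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2], [3, 3, 4])

seed = mx.nd.array([3], dtype=np.int64)
src, _ = g.in_edges(seed)                  # predecessors of the seed
vs = np.concatenate((src.asnumpy(), seed.asnumpy()), axis=0)
vs = mx.nd.array(np.unique(vs), dtype=np.int64)
print(vs.asnumpy())                        # [0 1 3]
```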
......@@ -94,12 +110,23 @@ def subgraph_gen(g, seed_vertices):
nids = []
for i, subg in enumerate(subgs):
subg.copy_from_parent()
nids.append(subg.map_to_subgraph_nid(utils.toindex(seed_vertices[i])))
nids.append(subg.map_to_subgraph_nid(seed_vertices[i]))
return subgs, nids
def copy_to_gpu(subg, ctx):
frame = subg.ndata
for key in frame:
subg.ndata[key] = frame[key].as_in_context(ctx)
def main(args, data):
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
if isinstance(data.features, mx.nd.NDArray):
features = data.features
else:
features = mx.nd.array(data.features)
if isinstance(data.labels, mx.nd.NDArray):
labels = data.labels
else:
labels = mx.nd.array(data.labels)
train_size = len(labels) * args.train_percent
train_vs = np.arange(train_size, dtype='int64')
eval_vs = np.arange(train_size, len(labels), dtype='int64')
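`copy_to_gpu` simply walks the subgraph's node frame and moves each column to the target device; `as_in_context` only copies when the array lives elsewhere. The same idea on a plain dict of NDArrays (contexts illustrative):

```python
import mxnet as mx

def copy_frame(frame, ctx):
    # same walk as copy_to_gpu, on a plain dict of NDArrays
    for key in frame:
        frame[key] = frame[key].as_in_context(ctx)

feats = {'h': mx.nd.ones((3, 4)), 'in': mx.nd.zeros((3, 2))}
copy_frame(feats, mx.cpu(0))    # use mx.gpu(0) when a GPU is available
print({k: v.context for k, v in feats.items()})
```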
......@@ -111,42 +138,45 @@ def main(args, data):
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu <= 0:
cuda = False
ctx = mx.cpu(0)
else:
cuda = True
features = features.as_in_context(mx.gpu(0))
train_labels = train_labels.as_in_context(mx.gpu(0))
eval_labels = eval_labels.as_in_context(mx.gpu(0))
ctx = mx.gpu(0)
# create the SSE model
try:
graph = data.graph.get_graph()
except AttributeError:
graph = data.graph
g = DGLGraph(graph, readonly=True)
g.set_n_repr({'in': features, 'h': mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
ctx=ctx)})
g.ndata['in'] = features
g.ndata['h'] = mx.nd.random.normal(shape=(g.number_of_nodes(), args.n_hidden),
ctx=mx.cpu(0))
update_hidden = SSEUpdateHidden(args.n_hidden, 'relu', args.update_dropout, args.use_spmv)
model = SSEPredict(update_hidden, args.n_hidden, args.predict_dropout)
model.initialize(ctx=ctx)
update_hidden_infer = SSEUpdateHidden(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv, prefix='sse')
update_hidden_infer.initialize(ctx=mx.cpu(0))
train_ctxs = []
update_hidden_train = SSEUpdateHidden(args.n_hidden, 'relu',
args.update_dropout, args.use_spmv, prefix='sse')
model = SSEPredict(update_hidden_train, args.n_hidden, args.predict_dropout, prefix='app')
if args.gpu <= 0:
model.initialize(ctx=mx.cpu(0))
train_ctxs.append(mx.cpu(0))
else:
for i in range(args.gpu):
train_ctxs.append(mx.gpu(i))
model.initialize(ctx=train_ctxs)
# use optimizer
num_batches = int(g.number_of_nodes() / args.batch_size)
scheduler = mx.lr_scheduler.CosineScheduler(args.n_epochs * num_batches,
args.lr * 10, 0, 0, args.lr/5)
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr,
'lr_scheduler': scheduler})
'lr_scheduler': scheduler}, kvstore=mx.kv.create('device'))
# compute vertex embedding.
update_hidden_infer(g, None)
# initialize graph
dur = []
for epoch in range(args.n_epochs):
# compute vertex embedding.
update_hidden(g, None)
t0 = time.time()
permute = np.random.permutation(len(train_vs))
randv = train_vs[permute]
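The model is now initialized with one parameter copy per training context, and the trainer uses a `'device'` kvstore so gradient aggregation happens on the devices. A minimal Gluon sketch of this setup (the Dense layer and hyperparameters are placeholders; the scheduler arguments mirror the code above):

```python
import mxnet as mx
from mxnet import gluon

n_epochs, num_batches, lr = 10, 50, 1e-3
ctxs = [mx.cpu(0)]      # [mx.gpu(i) for i in range(args.gpu)] on a GPU box

net = gluon.nn.Dense(16)
net.initialize(ctx=ctxs)                   # one parameter copy per context

scheduler = mx.lr_scheduler.CosineScheduler(n_epochs * num_batches,
                                            lr * 10, 0, 0, lr / 5)
trainer = gluon.Trainer(net.collect_params(), 'adam',
                        {'learning_rate': lr, 'lr_scheduler': scheduler},
                        kvstore=mx.kv.create('device'))
```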
......@@ -162,14 +192,27 @@ def main(args, data):
if len(data) < args.num_parallel_subgraphs:
continue
subgs, seed_ids = subgraph_gen(g, data)
subgs, seed_ids = subgraph_gen(g, data, train_ctxs)
losses = []
i = 0
for subg, seed_id, label, d in zip(subgs, seed_ids, labels, data):
if args.gpu > 0:
ctx = mx.gpu(i % args.gpu)
copy_to_gpu(subg, ctx)
with mx.autograd.record():
logits = model(subg, seed_id)
if label.context != logits.context:
label = label.as_in_context(logits.context)
loss = mx.nd.softmax_cross_entropy(logits, label)
loss.backward()
trainer.step(d.shape[0])
train_loss += loss.asnumpy()[0]
losses.append(loss)
i = i + 1
if i % args.gpu == 0:
trainer.step(d.shape[0] * len(subgs))
for loss in losses:
train_loss += loss.asnumpy()[0]
losses = []
data = []
labels = []
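Each subgraph now records its own backward pass, and `trainer.step` fires once per round of devices, consuming the gradients gathered across them. A single-context sketch of the accumulate-then-step shape of that loop (on one device the losses must be combined before backward; across real GPUs each device owns its own gradient buffer, so backward can run per subgraph as above):

```python
import mxnet as mx
from mxnet import gluon

net = gluon.nn.Dense(1)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
loss_fn = gluon.loss.L2Loss()

chunks = [mx.nd.ones((4, 8)) for _ in range(2)]   # one chunk per "device"
targets = [mx.nd.zeros((4,)) for _ in range(2)]

with mx.autograd.record():
    losses = [loss_fn(net(x), y) for x, y in zip(chunks, targets)]
mx.nd.add_n(*[l.sum() for l in losses]).backward()
trainer.step(sum(x.shape[0] for x in chunks))     # one step for all chunks
```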
......@@ -178,13 +221,48 @@ def main(args, data):
#eval_loss = eval_loss.asnumpy()[0]
eval_loss = 0
# compute vertex embedding.
infer_params = update_hidden_infer.collect_params()
for key in infer_params:
idx = trainer._param2idx[key]
trainer._kvstore.pull(idx, out=infer_params[key].data())
update_hidden_infer(g, None)
dur.append(time.time() - t0)
print("Epoch {:05d} | Train Loss {:.4f} | Eval Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, train_loss, eval_loss, np.mean(dur), n_edges / np.mean(dur) / 1000))
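The parameter copy into `update_hidden_infer` above reaches into `trainer._param2idx`/`trainer._kvstore`, which are private Trainer internals. A public-API sketch of the same idea, relying on the shared `prefix` ('sse' in the code above) producing identically named parameters (the `sse_` prefix and Dense layers here are placeholders, not the file's modules):

```python
import mxnet as mx
from mxnet import gluon

# Two blocks built with the same prefix get identically named parameters.
train_net = gluon.nn.Dense(4, prefix='sse_')
infer_net = gluon.nn.Dense(4, prefix='sse_')
train_net.initialize()
infer_net.initialize()
train_net(mx.nd.ones((1, 3)))   # run once so deferred shapes materialize
infer_net(mx.nd.ones((1, 3)))

train_params = train_net.collect_params()
infer_params = infer_net.collect_params()
for name in infer_params:
    infer_params[name].set_data(train_params[name].data())
```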
class MXNetGraph(object):
"""A simple graph object that uses scipy matrix."""
def __init__(self, mat):
self._mat = mat
def get_graph(self):
return self._mat
def number_of_nodes(self):
return self._mat.shape[0]
def number_of_edges(self):
return mx.nd.contrib.getnnz(self._mat)
class GraphData:
def __init__(self, csr, num_feats):
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = mx.nd.arange(0, num_edges, step=1, repeat=1, dtype=np.int64)
csr = mx.nd.sparse.csr_matrix((edge_ids, csr.indices, csr.indptr), shape=csr.shape, dtype=np.int64)
self.graph = MXNetGraph(csr)
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.labels = mx.nd.floor(mx.nd.random.normal(loc=0, scale=10, shape=(csr.shape[0])))
self.num_labels = 10
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--graph-file", type=str, default="",
help="graph file")
parser.add_argument("--num-feats", type=int, default=10,
help="the number of features")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
......@@ -210,5 +288,10 @@ if __name__ == '__main__':
args = parser.parse_args()
# load and preprocess dataset
data = load_data(args)
if args.graph_file != '':
csr = mx.nd.load(args.graph_file)[0]
data = GraphData(csr, args.num_feats)
csr = None
else:
data = load_data(args)
main(args, data)
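A `--graph-file` for this loader is just an `.nd` file whose first entry is a CSR adjacency. A sketch of generating one (the file name and sizes are illustrative):

```python
import numpy as np
import scipy.sparse as spsp
import mxnet as mx

n, nnz = 100, 500                               # illustrative sizes
row = np.random.randint(0, n, size=nnz)
col = np.random.randint(0, n, size=nnz)
csr = spsp.csr_matrix((np.ones(nnz), (row, col)), shape=(n, n))

# Save as the single-entry .nd list that the --graph-file branch loads.
arr = mx.nd.sparse.csr_matrix((csr.data, csr.indices, csr.indptr),
                              shape=csr.shape)
mx.nd.save('rand_graph.nd', [arr])
```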
......@@ -81,9 +81,13 @@ class ImmutableGraphIndex(object):
NDArray
The edge id array.
"""
if len(u) == 0 or len(v) == 0:
return [], [], []
ids = mx.nd.contrib.edge_id(self._in_csr, v, u)
ids = ids.asnumpy()
return ids[ids >= 0]
v = v.asnumpy()
u = u.asnumpy()
return u[ids >= 0], v[ids >= 0], ids[ids >= 0]
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
......
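`mx.nd.contrib.edge_id` returns -1 for (u, v) pairs with no edge, so the rewritten `edge_ids` masks those out and now also returns the surviving endpoints. The masking in plain numpy:

```python
import numpy as np

# -1 marks (u, v) pairs with no matching edge in the graph.
u = np.array([0, 0, 1])
v = np.array([1, 2, 2])
ids = np.array([5, -1, 7])

mask = ids >= 0
print(u[mask], v[mask], ids[mask])   # [0 1] [1 2] [5 7]
```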
......@@ -27,11 +27,11 @@ def sparse_matrix(data, index, shape, force_format=False):
raise TypeError('MXNet backend only supports CSR format,'
' but COO format is forced.')
coord = index[1]
return nd.sparse.csr_matrix((data, (coord[0], coord[1])), shape)
return nd.sparse.csr_matrix((data, (coord[0], coord[1])), tuple(shape))
elif fmt == 'csr':
indices = index[1]
indptr = index[2]
return nd.sparse.csr_matrix((data, indices, indptr), shape)
return nd.sparse.csr_matrix((data, indices, indptr), tuple(shape))
else:
raise TypeError('Invalid format: %s.' % fmt)
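The `tuple(shape)` wrapping guards against callers handing the shape over as a list, which `mx.nd.sparse.csr_matrix` may reject. A one-liner demonstrating the safe form:

```python
import mxnet as mx

shape = [2, 2]   # shape may arrive as a list from the caller
arr = mx.nd.sparse.csr_matrix(([1.0, 2.0], [0, 1], [0, 1, 2]), tuple(shape))
print(arr.asnumpy())
```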
......@@ -65,7 +65,7 @@ def sum(input, dim):
return nd.sum(input, axis=dim)
def max(input, dim):
return nd.max(input, axis=dim)
return nd.max(input, axis=dim).asnumpy()[0]
def cat(seq, dim):
return nd.concat(*seq, dim=dim)
......@@ -131,7 +131,7 @@ def nonzero_1d(input):
def sort_1d(input):
# TODO: this isn't an ideal implementation.
val = nd.sort(input, is_ascend=True)
val = nd.sort(input, axis=None, is_ascend=True)
idx = nd.argsort(input, is_ascend=True)
idx = nd.cast(idx, dtype='int64')
return val, idx
......
......@@ -853,7 +853,7 @@ class DGLGraph(object):
"""
self._apply_edge_func = func
def apply_nodes(self, func="default", v=ALL):
def apply_nodes(self, func="default", v=ALL, inplace=False):
"""Apply the function on the node features.
A None function is ignored.
......@@ -865,7 +865,7 @@ class DGLGraph(object):
v : int, iterable of int, tensor, optional
The node id(s).
"""
self._internal_apply_nodes(v, func)
self._internal_apply_nodes(v, func, inplace=inplace)
def apply_edges(self, func="default", edges=ALL):
"""Apply the function on the edge features.
......@@ -1464,7 +1464,8 @@ class DGLGraph(object):
edges = F.tensor(edges)
return edges[e_mask]
def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None,
inplace=False):
"""Internal apply nodes
Parameters
......@@ -1478,7 +1479,7 @@ class DGLGraph(object):
# Skip none function call.
if reduce_accum is not None:
# write reduce result back
self.set_n_repr(reduce_accum, v)
self.set_n_repr(reduce_accum, v, inplace=inplace)
return
# take out current node repr
curr_repr = self.get_n_repr(v)
......@@ -1491,4 +1492,4 @@ class DGLGraph(object):
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v)
self.set_n_repr(new_repr, v, inplace=inplace)
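The new `inplace` flag threads through to `set_n_repr`, letting batched updates rewrite rows of the node frame without reallocating it; the SSE example above uses this on its read-only graph. A minimal usage sketch (read-only toy graph, assuming the DGL version in this commit):

```python
import numpy as np
import scipy.sparse as spsp
import mxnet as mx
from dgl.graph import DGLGraph

def double(nodes):
    return {'h': nodes.data['h'] * 2}

# Read-only graph, mirroring how the SSE model uses the new flag.
csr = spsp.csr_matrix((np.ones(2), ([0, 1], [1, 0])), shape=(2, 2))
g = DGLGraph(csr, readonly=True)
g.ndata['h'] = mx.nd.ones((2, 3))

v = mx.nd.array([1], dtype=np.int64)
g.apply_nodes(double, v, inplace=True)   # row 1 rewritten in place
print(g.ndata['h'])
```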
......@@ -62,6 +62,17 @@ class ImmutableGraphIndex(object):
"""Clear the graph."""
raise Exception('Immutable graph doesn\'t support clearing up')
def is_multigraph(self):
"""Return whether the graph is a multigraph
Returns
-------
bool
True if it is a multigraph, False otherwise.
"""
# Immutable graph doesn't support multi-edge.
return False
def number_of_nodes(self):
"""Return the number of nodes.
......@@ -207,7 +218,7 @@ class ImmutableGraphIndex(object):
"""
u = F.tensor([u], dtype=F.int64)
v = F.tensor([v], dtype=F.int64)
id = self._sparse.edge_ids(u, v)
_, _, id = self._sparse.edge_ids(u, v)
return utils.toindex(id)
def edge_ids(self, u, v):
......@@ -223,12 +234,16 @@ class ImmutableGraphIndex(object):
Returns
-------
utils.Index
The edge id array.
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
u = u.tousertensor()
v = v.tousertensor()
ids = self._sparse.edge_ids(u, v)
return utils.toindex(ids)
u, v, ids = self._sparse.edge_ids(u, v)
return utils.toindex(u), utils.toindex(v), utils.toindex(ids)
def in_edges(self, v):
"""Return the in edges of the node(s).
......
......@@ -119,4 +119,4 @@ class DGLSubGraph(DGLGraph):
self._parent._edge_frame[self._parent_eid]))
def map_to_subgraph_nid(self, parent_vids):
return map_to_subgraph_nid(self._graph, parent_vids)
return map_to_subgraph_nid(self._graph, utils.toindex(parent_vids))
......@@ -3,6 +3,7 @@ os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
from dgl.graph import DGLGraph
import scipy.sparse as spsp
D = 5
reduce_msg_shapes = set()
......@@ -26,20 +27,39 @@ def reduce_func(nodes):
def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['m']}
def generate_graph(grad=False):
g = DGLGraph()
g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.ndata['h'] = ncol
return g
def generate_graph(grad=False, readonly=False):
if readonly:
row_idx = []
col_idx = []
for i in range(1, 9):
row_idx.append(0)
col_idx.append(i)
row_idx.append(i)
col_idx.append(9)
row_idx.append(9)
col_idx.append(0)
ones = np.ones(shape=(len(row_idx)))
csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(10, 10))
g = DGLGraph(csr, readonly=True)
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.ndata['h'] = ncol
return g
else:
g = DGLGraph()
g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.ndata['h'] = ncol
return g
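Read-only graphs must be handed their full adjacency at construction time (they reject `add_nodes`/`add_edges` afterwards), which is why the readonly branch assembles the star-plus-backedge topology as index arrays first. The same construction, condensed:

```python
import numpy as np
import scipy.sparse as spsp
from dgl.graph import DGLGraph

# 0 -> {1..8} -> 9 -> 0, the topology generate_graph builds edge by edge.
row = [0] * 8 + list(range(1, 9)) + [9]
col = list(range(1, 9)) + [9] * 8 + [0]
csr = spsp.csr_matrix((np.ones(len(row)), (row, col)), shape=(10, 10))
g = DGLGraph(csr, readonly=True)
print(g.number_of_nodes(), g.number_of_edges())   # 10 17
```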
def test_batch_setter_getter():
def _pfc(x):
......@@ -121,7 +141,7 @@ def test_batch_setter_getter():
def test_batch_setter_autograd():
with mx.autograd.record():
g = generate_graph(grad=True)
g = generate_graph(grad=True, readonly=True)
h1 = g.ndata['h']
h1.attach_grad()
# partial set
......@@ -153,9 +173,9 @@ def test_batch_send():
v = mx.nd.array([9], dtype='int64')
g.send((u, v))
def test_batch_recv():
def check_batch_recv(readonly):
# basic recv test
g = generate_graph()
g = generate_graph(readonly=readonly)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
......@@ -167,8 +187,12 @@ def test_batch_recv():
#assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
#reduce_msg_shapes.clear()
def test_update_routines():
g = generate_graph()
def test_batch_recv():
check_batch_recv(True)
check_batch_recv(False)
def check_update_routines(readonly):
g = generate_graph(readonly=readonly)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
......@@ -201,13 +225,27 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
def test_reduce_0deg():
g = DGLGraph()
g.add_nodes(5)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def test_update_routines():
check_update_routines(True)
check_update_routines(False)
def check_reduce_0deg(readonly):
if readonly:
row_idx = []
col_idx = []
for i in range(1, 5):
row_idx.append(i)
col_idx.append(0)
ones = np.ones(shape=(len(row_idx)))
csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(5, 5))
g = DGLGraph(csr, readonly=True)
else:
g = DGLGraph()
g.add_nodes(5)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
......@@ -220,10 +258,23 @@ def test_reduce_0deg():
assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
def test_pull_0deg():
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def test_reduce_0deg():
check_reduce_0deg(True)
check_reduce_0deg(False)
def check_pull_0deg(readonly):
if readonly:
row_idx = []
col_idx = []
row_idx.append(0)
col_idx.append(1)
ones = np.ones(shape=(len(row_idx)))
csr = spsp.csr_matrix((ones, (row_idx, col_idx)), shape=(2, 2))
g = DGLGraph(csr, readonly=True)
else:
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
......@@ -246,6 +297,10 @@ def test_pull_0deg():
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
def test_pull_0deg():
check_pull_0deg(True)
check_pull_0deg(False)
if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
......
......@@ -67,7 +67,7 @@ def check_basics(g, ig):
assert g.has_edge_between(u, v) == ig.has_edge_between(u, v)
randv = utils.toindex(randv)
ids = g.edge_ids(randv, randv)[2].tolist()
assert sum(ig.edge_ids(randv, randv).tolist() == ids) == len(ids)
assert sum(ig.edge_ids(randv, randv)[2].tolist() == ids) == len(ids)
assert sum(g.has_edges_between(randv, randv).tolist() == ig.has_edges_between(randv, randv).tolist()) == len(randv)
......