Commit 2664ed2d authored by Lingfan Yu's avatar Lingfan Yu Committed by Minjie Wang
Browse files

[Bugfix] Fix multiple send recv (#320)

* fix bug: frame append should check itself

* more test case

* fix multi-send-recv bug

* remove msg graph and clean up

* test cases

* more test case

* fix for batchedgraph

* fix bugs: converting from a graph with edges

* fix

* add more operators to utils.Index

* clear frame executor

* change message indicator to a graph level index

* fix test cases

* guard the case that mxnet does not support concat zero shape tensor

* fix bug: avoid convert full slice to numpy

* test multi-send-recv after conversion

* fix as requested (partially)

* add dtype, ctx to full_1d

* add slice data to utils.Index

* fix

* more doc string

* fix as requested
parent 896dc50e
......@@ -618,7 +618,7 @@ def unique(input):
"""
pass
def full_1d(length, fill_value):
def full_1d(length, fill_value, dtype, ctx):
"""Create a 1D tensor full of the fill_value.
Parameters
......@@ -627,6 +627,10 @@ def full_1d(length, fill_value):
The length of the vector.
fill_value : int
The filled value.
dtype : data type
It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns
-------
......
......@@ -171,8 +171,8 @@ def unique(input):
tmp = np.unique(tmp)
return nd.array(tmp, ctx=input.context, dtype=input.dtype)
def full_1d(length, fill_value):
return nd.full((length,), fill_value)
def full_1d(length, fill_value, dtype, ctx):
return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)
def nonzero_1d(input):
# TODO: fallback to numpy is unfortunate
......
......@@ -140,8 +140,8 @@ def unsorted_1d_segment_mean(input, seg_id, n_segs, dim):
def unique(input):
return th.unique(input)
def full_1d(length, fill_value):
return th.full((length,), fill_value)
def full_1d(length, fill_value, dtype, ctx):
return th.full((length,), fill_value, dtype=dtype, device=ctx)
def nonzero_1d(input):
return th.nonzero(input).squeeze()
......
......@@ -366,6 +366,20 @@ class Frame(MutableMapping):
# directly updating columns.
self._columns = {key: Column.create(data) for key, data in other.items()}
else:
# pad columns that are not provided in the other frame with initial values
for key, col in self.items():
if key not in other:
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(other.num_rows,) + scheme.shape, scheme.dtype,
ctx, slice(self._num_rows,
self._num_rows + other.num_rows)
)
other[key] = new_data
# append other to self
for key, col in other.items():
if key not in self._columns:
# the column does not exist; init a new column
......
......@@ -179,7 +179,7 @@ class DGLGraph(object):
# graph
self._readonly=readonly
self._graph = create_graph_index(graph_data, multigraph, readonly)
# frame
# node and edge frame
if node_frame is None:
self._node_frame = FrameRef(Frame(num_rows=self.number_of_nodes()))
else:
......@@ -188,10 +188,13 @@ class DGLGraph(object):
self._edge_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
else:
self._edge_frame = edge_frame
# msg graph & frame
self._msg_graph = create_graph_index(multigraph=multigraph)
self._msg_frame = FrameRef()
self.reset_messages()
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges())
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
self._msg_frame.set_initializer(dgl.init.zero_initializer)
# registered functions
self._message_func = None
self._reduce_func = None
......@@ -243,7 +246,6 @@ class DGLGraph(object):
[1., 1., 1., 1.]])
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
if data is None:
# Initialize feature placeholders if there are features existing
self._node_frame.add_rows(num)
......@@ -303,6 +305,9 @@ class DGLGraph(object):
self._edge_frame.add_rows(1)
else:
self._edge_frame.append(data)
# resize msg_index and msg_frame
self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1)
def add_edges(self, u, v, data=None):
"""Add multiple edges for list of source nodes u and destination nodes
......@@ -353,12 +358,16 @@ class DGLGraph(object):
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
num = max(len(u), len(v))
if data is None:
# Initialize feature placeholders if there are features existing
# NOTE: use max due to edge broadcasting syntax
self._edge_frame.add_rows(max(len(u), len(v)))
self._edge_frame.add_rows(num)
else:
self._edge_frame.append(data)
# initialize feature placeholder for messages
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)
def clear(self):
"""Remove all nodes and edges, as well as their features, from the
......@@ -382,7 +391,7 @@ class DGLGraph(object):
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_index = utils.zero_index(0)
self._msg_frame.clear()
def clear_cache(self):
......@@ -394,12 +403,6 @@ class DGLGraph(object):
"""
self._graph.clear_cache()
def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self):
"""Return the number of nodes in the graph.
......@@ -1168,7 +1171,9 @@ class DGLGraph(object):
self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
# copy attributes
def _batcher(lst):
if F.is_tensor(lst[0]):
......@@ -1225,7 +1230,8 @@ class DGLGraph(object):
self._graph.from_scipy_sparse_matrix(a)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
def node_attr_schemes(self):
"""Return the node feature schemes.
......@@ -1934,7 +1940,7 @@ class DGLGraph(object):
message_func = self._message_func
if is_all(edges):
eid = ALL
eid = utils.toindex(slice(0, self.number_of_edges()))
u, v, _ = self._graph.edges()
elif isinstance(edges, tuple):
u, v = edges
......@@ -1946,14 +1952,15 @@ class DGLGraph(object):
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
if len(eid) == 0:
# no edge to be triggered
return
with ir.prog() as prog:
scheduler.schedule_send(graph=self, u=u, v=v, eid=eid,
message_func=message_func)
Runtime.run(prog)
# update message graph and frame
self._msg_graph.add_edges(u, v)
def recv(self,
v=ALL,
reduce_func="default",
......@@ -2039,10 +2046,6 @@ class DGLGraph(object):
apply_node_func = self._apply_node_func
assert reduce_func is not None
if self._msg_frame.num_rows == 0:
# no message has ever been sent
return
if is_all(v):
v = F.arange(0, self.number_of_nodes())
elif isinstance(v, int):
......@@ -2060,9 +2063,6 @@ class DGLGraph(object):
inplace=inplace)
Runtime.run(prog)
# FIXME(minjie): multi send bug
self.reset_messages()
def send_and_recv(self,
edges,
message_func="default",
......
......@@ -34,6 +34,7 @@ class OpCode(object):
WRITE_DICT_ = 24
APPEND_ROW_ = 25
WRITE_ROW_INPLACE_ = 26
CLEAR_FRAME_ = 27
class Executor(object):
@abstractmethod
......@@ -645,3 +646,32 @@ IR_REGISTRY[OpCode.APPEND_ROW_] = {
def APPEND_ROW_(fd1, fd2):
reg = IR_REGISTRY[OpCode.APPEND_ROW_]
get_current_prog().issue(reg['executor_cls'](fd1, fd2))
class ClearFrame_Executor(Executor):
    """Executor for the CLEAR_FRAME_ instruction.

    When run, it calls ``clear()`` on the frame held by the given variable
    and then re-adds the number of rows the frame had before clearing.

    Parameters
    ----------
    fd : variable
        The variable whose ``data`` attribute is the frame to clear.
    """
    def __init__(self, fd):
        self.fd = fd

    def opcode(self):
        """Return the opcode of this executor."""
        return OpCode.CLEAR_FRAME_

    def arg_vars(self):
        """Return the list of argument variables (only the frame variable)."""
        return [self.fd]

    def ret_var(self):
        """This instruction produces no return variable."""
        return None

    def run(self):
        """Clear the frame, then restore its original row count."""
        target = self.fd.data
        # remember the size first; it is read before clear() is called
        kept_rows = target.num_rows
        target.clear()
        target.add_rows(kept_rows)
# Registry entry for the CLEAR_FRAME_ instruction: it takes a single
# FEAT_DICT argument, has no return value, and is executed by
# ClearFrame_Executor.
IR_REGISTRY[OpCode.CLEAR_FRAME_] = {
    'name': 'CLEAR_FRAME_',
    'args_type': [VarType.FEAT_DICT],
    'ret_type': None,
    'executor_cls': ClearFrame_Executor,
}
def CLEAR_FRAME_(fd):
    """Issue a CLEAR_FRAME_ instruction on the current program.

    Parameters
    ----------
    fd : variable
        The frame variable handed to the registered executor class.
    """
    entry = IR_REGISTRY[OpCode.CLEAR_FRAME_]
    prog = get_current_prog()
    executor = entry['executor_cls'](fd)
    prog.issue(executor)
......@@ -44,15 +44,16 @@ def schedule_send(graph, u, v, eid, message_func):
# TODO(minjie): support builtin message func
message_func = _standardize_func_usage(message_func, 'message')
# vars
nf = var.FEAT_DICT(graph._node_frame)
ef = var.FEAT_DICT(graph._edge_frame)
mf = var.FEAT_DICT(graph._msg_frame)
u = var.IDX(u)
v = var.IDX(v)
eid = var.IDX(eid)
msg = _gen_send(graph, nf, ef, u, v, eid, message_func)
# TODO: handle duplicate messages
ir.APPEND_ROW_(mf, msg)
var_nf = var.FEAT_DICT(graph._node_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_mf = var.FEAT_DICT(graph._msg_frame)
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
msg = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, message_func)
ir.WRITE_ROW_(var_mf, var_eid, msg)
# set message indicator to 1
graph._msg_index = graph._msg_index.set_items(eid, 1)
def schedule_recv(graph,
recv_nodes,
......@@ -74,9 +75,16 @@ def schedule_recv(graph,
inplace: bool
If True, the update will be done in place
"""
src, dst, mid = graph._msg_graph.in_edges(recv_nodes)
if len(mid) == 0:
# All recv nodes are 0-degree nodes; downgrade to apply nodes.
src, dst, eid = graph._graph.in_edges(recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._msg_index.get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx)
if len(eid) == 0:
# Downgrade to apply nodes if
# 1) all recv nodes are 0-degree nodes
# 2) no send has been called
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else:
......@@ -86,13 +94,19 @@ def schedule_recv(graph,
recv_nodes = utils.toindex(recv_nodes)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
# reduce
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes)
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._msg_index = graph._msg_index.set_items(eid, 0)
if not graph._msg_index.has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf'))
def schedule_snr(graph,
edge_tuples,
......@@ -426,13 +440,14 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
recv_nodes : utils.Index
"""
call_type = "recv"
_, dst, mid = edge_tuples
_, dst, eid = edge_tuples
rfunc = _standardize_func_usage(reduce_func, 'reduce')
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
# TODO(minjie): should replace this with an IR call to make the program
# stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
# vars
......@@ -444,8 +459,8 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
# FIXME: refactor this when fixing the multi-recv bug
inc = spmv.build_inc_matrix_eid(graph._msg_frame.num_rows, mid, dst, recv_nodes)
inc = spmv.build_inc_matrix_eid(graph._msg_frame.num_rows, eid, dst,
recv_nodes)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
if len(rfunc) == 0:
......@@ -456,7 +471,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, mid, dst,
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst,
recv_nodes, nf, msg, out)
return out
......
......@@ -15,9 +15,10 @@ class Index(object):
self._initialize_data(data)
def _initialize_data(self, data):
self._pydata = None # a numpy type data or a slice
self._pydata = None # a numpy type data
self._user_tensor_data = dict() # dictionary of user tensors
self._dgl_tensor_data = None # a dgl ndarray
self._slice_data = None # a slice type data
self._dispatch(data)
def __iter__(self):
......@@ -25,12 +26,9 @@ class Index(object):
yield int(i)
def __len__(self):
if self._pydata is not None and isinstance(self._pydata, slice):
slc = self._pydata
if slc.step is None:
return slc.stop - slc.start
else:
return (slc.stop - slc.start) // slc.step
if self._slice_data is not None:
slc = self._slice_data
return slc.stop - slc.start
elif self._pydata is not None:
return len(self._pydata)
elif len(self._user_tensor_data) > 0:
......@@ -60,7 +58,9 @@ class Index(object):
self._dgl_tensor_data = data
elif isinstance(data, slice):
# save it in the _pydata temporarily; materialize it if `tonumpy` is called
self._pydata = data
assert data.step == 1 or data.step is None, \
"step for slice type must be 1"
self._slice_data = slice(data.start, data.stop)
else:
try:
self._pydata = np.array([int(data)]).astype(np.int64)
......@@ -75,18 +75,18 @@ class Index(object):
raise DGLError('Error index data: %s' % str(data))
self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._pydata)
def tonumpy(self):
"""Convert to a numpy ndarray."""
if self._pydata is None:
if self._dgl_tensor_data is not None:
if self._slice_data is not None:
slc = self._slice_data
self._pydata = np.arange(slc.start, slc.stop).astype(np.int64)
elif self._dgl_tensor_data is not None:
self._pydata = self._dgl_tensor_data.asnumpy()
else:
data = self.tousertensor()
self._pydata = F.zerocopy_to_numpy(data)
elif isinstance(self._pydata, slice):
# convert it to numpy array
slc = self._pydata
self._pydata = np.arange(slc.start, slc.stop, slc.step).astype(np.int64)
return self._pydata
def tousertensor(self, ctx=None):
......@@ -116,9 +116,9 @@ class Index(object):
self._dgl_tensor_data = nd.from_dlpack(dl)
return self._dgl_tensor_data
def is_slice(self, start, stop, step=None):
return (isinstance(self._pydata, slice)
and self._pydata == slice(start, stop, step))
def is_slice(self, start, stop):
"""Check if Index wraps a slice data with given start and stop"""
return self._slice_data == slice(start, stop)
def __getstate__(self):
return self.tousertensor()
......@@ -126,9 +126,99 @@ class Index(object):
def __setstate__(self, state):
self._initialize_data(state)
def get_items(self, index):
    """Return values at given positions of an Index

    Parameters
    ----------
    index: utils.Index
        Positions to read from this Index.

    Returns
    -------
    utils.Index
        A new Index holding the selected values. If both this Index and
        ``index`` wrap a slice, the result wraps a slice as well.
    """
    if index._slice_data is None:
        # arbitrary positions: gather from the materialized tensor
        tensor = self.tousertensor()
        index = index.tousertensor()
        return Index(F.gather_row(tensor, index))
    elif self._slice_data is None:
        # contiguous positions on tensor data: narrow instead of gather
        tensor = self.tousertensor()
        index = index._slice_data
        return Index(F.narrow_row(tensor, index.start, index.stop))
    else:
        # both self and index wrap a slice object, then return another
        # Index wrapping a slice (offset by the start of self's slice)
        start = self._slice_data.start
        index = index._slice_data
        return Index(slice(start + index.start, start + index.stop))
def set_items(self, index, value):
    """Write values at the given positions and return the result as a new
    Index; this Index object is not rebound or replaced.

    Parameters
    ----------
    index: utils.Index
        Positions to set values
    value: int or utils.Index
        Values to set. An integer is expanded to a 1D int64 CPU tensor with
        one entry per position, so every position gets the same value.

    Returns
    -------
    utils.Index
    """
    base = self.tousertensor()
    positions = index.tousertensor()
    if not isinstance(value, int):
        fill = value.tousertensor()
    else:
        fill = F.full_1d(len(positions), value, dtype=F.int64, ctx=F.cpu())
    return Index(F.scatter_row(base, positions, fill))
def append_zeros(self, num):
    """Return an Index extended by ``num`` trailing zeros.

    Parameters
    ----------
    num: int
        number of zeros to append

    Returns
    -------
    utils.Index
        ``self`` itself when ``num`` is 0, otherwise a new Index.
    """
    if num == 0:
        return self
    result = F.zeros((num,), dtype=F.int64, ctx=F.cpu())
    if len(self) != 0:
        # only concatenate when there is existing data; an empty Index is
        # replaced by the zeros alone, so no zero-length tensor is concatenated
        result = F.cat((self.tousertensor(), result), dim=0)
    return Index(result)
def nonzero(self):
    """Return an Index of the positions whose value is not zero."""
    return Index(F.nonzero_1d(self.tousertensor() != 0))
def has_nonzero(self):
    """Check whether the values of this Index sum to more than zero.

    The result is whatever the backend comparison ``F.sum(...) > 0`` yields.
    """
    total = F.sum(self.tousertensor(), 0)
    return total > 0
def toindex(x):
    """Return ``x`` unchanged if it already is an Index, else wrap it in one."""
    if isinstance(x, Index):
        return x
    return Index(x)
def zero_index(size):
    """Create an Index of the given size with every entry set to zero.

    Parameters
    ----------
    size: int
        Number of entries.

    Returns
    -------
    utils.Index
        Index wrapping a 1D int64 CPU tensor of zeros.
    """
    zeros = F.zeros((size,), dtype=F.int64, ctx=F.cpu())
    return Index(zeros)
class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys):
......
......@@ -490,34 +490,6 @@ def test_pull_0deg():
# non-0deg check: not touched
assert U.allclose(new[1], old[1])
def _disabled_test_send_twice():
# TODO(minjie): please re-enable this unittest after the send code problem is fixed.
g = DGLGraph()
g.add_nodes(3)
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(edges):
return {'a': edges.src['a']}
def _message_b(edges):
return {'a': edges.src['a'] * 3}
def _reduce(nodes):
return {'a': nodes.mailbox['a'].max(1)[0]}
old_repr = th.randn(3, 5)
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((0, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert U.allclose(new_repr[1], old_repr[0] * 3)
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((2, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_multigraph():
g = DGLGraph(multigraph=True)
g.add_nodes(3)
......@@ -614,6 +586,10 @@ def test_dynamic_addition():
g.edges[4].data['h1'] = th.randn(1, D)
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
# test add edge with part of the features
g.add_edge(2, 1, {'h1': th.randn(1, D)})
assert len(g.edata['h1']) == len(g.edata['h2'])
def test_repr():
G = dgl.DGLGraph()
......
import torch as th
from torch.autograd import Variable
import numpy as np
import dgl
from dgl.graph import DGLGraph
import utils as U
from collections import defaultdict as ddict
import scipy.sparse as sp
D = 5
def message_func(edges):
    """Message UDF: send the source node feature 'h' as message 'm'.

    Asserts that the source feature is 2D with D columns.
    """
    src_h = edges.src['h']
    assert len(src_h.shape) == 2
    assert src_h.shape[1] == D
    return {'m': src_h}
def reduce_func(nodes):
    """Reduce UDF: sum the mailbox messages 'm' over dimension 1 into 'accum'.

    Asserts that the mailbox tensor is 3D with D as its last dimension.
    """
    mailbox = nodes.mailbox['m']
    assert len(mailbox.shape) == 3
    assert mailbox.shape[2] == D
    summed = th.sum(mailbox, 1)
    return {'accum': summed}
def apply_node_func(nodes):
    """Apply UDF: add the reduced value 'accum' onto the node feature 'h'."""
    feats = nodes.data
    return {'h': feats['h'] + feats['accum']}
def generate_graph(grad=False):
    """Build the 10-node, 16-edge fixture graph used by the tests.

    Node 0 is the source and node 9 the sink: for every i in 1..8 the edges
    0->i and i->9 are added. Node feature 'h' and edge feature 'w' are random
    tensors with D columns; `grad` is forwarded as their requires_grad flag.
    """
    graph = DGLGraph()
    graph.add_nodes(10)  # 10 nodes.
    # edges are added in the order (0, i), (i, 9) for i = 1..8 -> 16 edges
    for mid in range(1, 9):
        graph.add_edge(0, mid)
        graph.add_edge(mid, 9)
    node_feat = Variable(th.randn(10, D), requires_grad=grad)
    edge_feat = Variable(th.randn(16, D), requires_grad=grad)
    graph.set_n_initializer(dgl.init.zero_initializer)
    graph.set_e_initializer(dgl.init.zero_initializer)
    graph.ndata['h'] = node_feat
    graph.edata['w'] = edge_feat
    return graph
def test_multi_send():
    """Send three times in a row (no recv) and check the message indicator.

    After the sends, g._msg_index must be 1 exactly on the edges that were
    sent on at least once and 0 everywhere else.
    """
    g = generate_graph()
    def _fmsg(edges):
        # every send below triggers exactly 5 edges
        assert edges.src['h'].shape == (5, D)
        return {'m' : edges.src['h']}
    g.register_message_func(_fmsg)
    # many-many send
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # duplicate send: same edges as above, given as one source and five
    # destinations
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.send((u, v))
    # send more
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.send((u, v))
    # check if message indicator is as expected
    expected = th.zeros((g.number_of_edges(),), dtype=th.int64)
    eid = g.edge_ids([0, 0, 0, 0, 0, 1, 2, 3, 4, 5],
                     [1, 2, 3, 4, 5, 9, 9, 9, 9, 9])
    expected[eid] = 1
    assert th.equal(g._msg_index.tousertensor(), expected)
def test_multi_recv():
    """Check the message indicator across interleaved send/recv calls.

    Two scenarios are run from the same initial node feature 'h':
    two separate send+recv rounds, and one combined send followed by two
    recv calls. The message indicator is checked after every step, and the
    resulting node features of both scenarios must match.
    """
    # basic recv test
    g = generate_graph()
    h = g.ndata['h']
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_node_func)
    # mirror of the expected g._msg_index, updated alongside each send/recv
    expected = th.zeros((g.number_of_edges(),), dtype=th.int64)
    # two separate rounds of send and recv
    u = [4, 5, 6]
    v = [9]
    g.send((u, v))
    eid = g.edge_ids(u, v)
    expected[eid] = 1
    assert th.equal(g._msg_index.tousertensor(), expected)
    g.recv(v)
    expected[eid] = 0
    assert th.equal(g._msg_index.tousertensor(), expected)
    u = [0]
    v = [1, 2, 3]
    g.send((u, v))
    eid = g.edge_ids(u, v)
    expected[eid] = 1
    assert th.equal(g._msg_index.tousertensor(), expected)
    g.recv(v)
    expected[eid] = 0
    assert th.equal(g._msg_index.tousertensor(), expected)
    h1 = g.ndata['h']
    # one send, two recv
    # restore the initial features so both scenarios start from the same state
    g.ndata['h'] = h
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.send((u, v))
    eid = g.edge_ids(u, v)
    expected[eid] = 1
    assert th.equal(g._msg_index.tousertensor(), expected)
    u = [4, 5, 6]
    v = [9]
    g.recv(v)
    eid = g.edge_ids(u, v)
    expected[eid] = 0
    assert th.equal(g._msg_index.tousertensor(), expected)
    u = [0]
    v = [1, 2, 3]
    g.recv(v)
    eid = g.edge_ids(u, v)
    expected[eid] = 0
    assert th.equal(g._msg_index.tousertensor(), expected)
    h2 = g.ndata['h']
    assert U.allclose(h1, h2)
def test_multi_recv_0deg():
    """Check recv on a 0-degree node and on a node with no pending message.

    The graph has two nodes and the single edge 0->1, so node 0 never
    receives a message. The assertions expect that such nodes only get the
    apply function (which doubles 'h') on recv.
    """
    # test recv with 0deg nodes;
    g = DGLGraph()
    def _message(edges):
        return {'m' : edges.src['h']}
    def _reduce(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h'] * 2}
    def _init2(shape, dtype, ctx, ids):
        # initializer that fills new features with the constant 2
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    g.set_n_initializer(_init2)
    g.add_nodes(2)
    g.add_edge(0, 1)
    # recv both 0deg and non-0deg nodes
    old = th.randn((2, 5))
    g.ndata['h'] = old
    g.send((0, 1))
    g.recv([0, 1])
    new = g.ndata['h']
    # 0deg check: initialized with the func and got applied
    assert U.allclose(new[0], th.full((5,), 4))
    # non-0deg check
    assert U.allclose(new[1], th.sum(old, 0) * 2)
    # recv again on zero degree node
    g.recv([0])
    assert U.allclose(g.nodes[0].data['h'], th.full((5,), 8))
    # recv again on node with no incoming message
    g.recv([1])
    assert U.allclose(g.nodes[1].data['h'], th.sum(old, 0) * 4)
def test_send_twice_different_shape():
    """Send twice with the same message field 'h' but different widths.

    The second message concatenates the source feature with the edge
    feature along dim 1. There are no assertions; the test passes if
    neither send raises.
    """
    g = generate_graph()
    def _message_1(edges):
        return {'h': edges.src['h']}
    def _message_2(edges):
        return {'h': th.cat((edges.src['h'], edges.data['w']), dim=1)}
    g.send(message_func=_message_1)
    g.send(message_func=_message_2)
def test_send_twice_different_msg():
    """Send two different messages before one recv and check the result.

    Case 1 sends twice on the same edge (0, 1); the assertion expects only
    the second message (source feature * 3) to be reduced.
    Case 2 sends on two different edges into node 1; the assertion expects
    the element-wise max over both messages.
    """
    g = DGLGraph()
    g.set_n_initializer(dgl.init.zero_initializer)
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(2, 1)
    def _message_a(edges):
        return {'a': edges.src['a']}
    def _message_b(edges):
        return {'a': edges.src['a'] * 3}
    def _reduce(nodes):
        # max over the message dimension; [0] takes the values, not the indices
        return {'a': nodes.mailbox['a'].max(1)[0]}
    old_repr = th.randn(3, 5)
    # case 1: two sends on the same edge
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], old_repr[0] * 3)
    # case 2: two sends on different edges with the same destination
    g.ndata['a'] = old_repr
    g.send((0, 1), _message_a)
    g.send((2, 1), _message_b)
    g.recv(1, _reduce)
    new_repr = g.ndata['a']
    assert U.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_twice_different_field():
    """Send two messages with different field names on the same edge.

    The first send writes field 'a', the second field 'b'. The reduce
    function reads both from the mailbox, and node 1 must end up with node
    0's original 'a' and 'b' values.
    """
    g = DGLGraph()
    g.set_n_initializer(dgl.init.zero_initializer)
    g.add_nodes(2)
    g.add_edge(0, 1)
    def _message_a(edges):
        return {'a': edges.src['a']}
    def _message_b(edges):
        return {'b': edges.src['b']}
    def _reduce(nodes):
        return {'a': nodes.mailbox['a'].sum(1), 'b': nodes.mailbox['b'].sum(1)}
    old_a = th.randn(2, 5)
    old_b = th.randn(2, 5)
    g.set_n_repr({'a': old_a, 'b': old_b})
    g.send((0, 1), _message_a)
    g.send((0, 1), _message_b)
    g.recv([1], _reduce)
    new_repr = g.get_n_repr()
    assert th.allclose(new_repr['a'][1], old_a[0])
    assert th.allclose(new_repr['b'][1], old_b[0])
def test_dynamic_addition():
    """Interleave send/recv with adding nodes and edges.

    Checks that the message indicator covers all edges after a full send,
    then that partial send/recv rounds mixed with edge additions give the
    same node feature 'h' as one complete send + recv round.
    """
    N = 3
    D = 1
    g = DGLGraph()
    def _init(shape, dtype, ctx, ids):
        return th.randn(shape, dtype=dtype, device=ctx)
    g.set_n_initializer(_init)
    g.set_e_initializer(_init)
    def _message(edges):
        return {'m' : edges.src['h1'] + edges.dst['h2'] + edges.data['h1'] +
                edges.data['h2']}
    def _reduce(nodes):
        return {'h' : nodes.mailbox['m'].sum(1)}
    def _apply(nodes):
        return {'h' : nodes.data['h']}
    g.register_message_func(_message)
    g.register_reduce_func(_reduce)
    g.register_apply_node_func(_apply)
    # NOTE(review): these calls come after the _init initializers set above,
    # so they presumably replace them -- confirm that is intended
    g.set_n_initializer(dgl.init.zero_initializer)
    g.set_e_initializer(dgl.init.zero_initializer)
    # add nodes and edges
    g.add_nodes(N)
    g.ndata.update({'h1': th.randn(N, D),
                    'h2': th.randn(N, D)})
    g.add_nodes(3)
    g.add_edge(0, 1)
    g.add_edge(1, 0)
    g.edata.update({'h1': th.randn(2, D),
                    'h2': th.randn(2, D)})
    # send on all edges: every entry of the message indicator must be 1
    g.send()
    expected = th.ones((g.number_of_edges(),), dtype=th.int64)
    assert th.equal(g._msg_index.tousertensor(), expected)
    # add more edges
    # only 'h1' is given for the new edges; 'h2' is left to the initializer
    g.add_edges([0, 2], [2, 0], {'h1': th.randn(2, D)})
    g.send(([0, 2], [2, 0]))
    g.recv(0)
    g.add_edge(1, 2)
    g.edges[4].data['h1'] = th.randn(1, D)
    g.send((1, 2))
    g.recv([1, 2])
    # take the incrementally computed result out of the node frame
    h = g.ndata.pop('h')
    # a complete round of send and recv
    g.send()
    g.recv()
    assert U.allclose(h, g.ndata['h'])
def test_recv_no_send():
    """Check recv without a prior send, and send/recv after clear().

    The first recv only has to run without raising. After the graph is
    cleared and rebuilt with two edges, the message indicator must be 1 on
    the second edge after the send and all zeros after the recv.
    """
    g = generate_graph()
    # no send has been issued on this graph yet
    g.recv(1, reduce_func)
    # test recv after clear
    g.clear()
    g.add_nodes(3)
    g.add_edges([0, 1], [1, 2])
    g.set_n_initializer(dgl.init.zero_initializer)
    g.ndata['h'] = th.randn(3, D)
    g.send((1, 2), message_func)
    expected = th.zeros((2,), dtype=th.int64)
    # position 1 is the second edge added above, (1, 2)
    expected[1] = 1
    assert th.equal(g._msg_index.tousertensor(), expected)
    g.recv(2, reduce_func)
    expected[1] = 0
    assert th.equal(g._msg_index.tousertensor(), expected)
def test_send_recv_after_conversion():
    """Check send/recv on graphs loaded from networkx and scipy sparse data.

    g1 and g2 already hold some nodes and edges before they are overwritten
    by from_networkx / from_scipy_sparse_matrix. The same send + two recv
    sequence is then run on the original graph and on both converted graphs,
    and the resulting node features 'h' must match.
    """
    # test send and recv after converting from a graph with edges
    g = generate_graph()
    # nx graph
    nxg = g.to_networkx(node_attrs=['h'])
    g1 = DGLGraph()
    # some random node and edges
    g1.add_nodes(4)
    g1.add_edges([1, 2], [2, 3])
    g1.set_n_initializer(dgl.init.zero_initializer)
    g1.from_networkx(nxg, node_attrs=['h'])
    # sparse matrix
    row, col= g.all_edges()
    data = range(len(row))
    n = g.number_of_nodes()
    a = sp.coo_matrix((data, (row, col)), shape=(n, n))
    g2 = DGLGraph()
    # some random node and edges
    g2.add_nodes(5)
    g2.add_edges([1, 2, 4], [2, 3, 0])
    g2.set_n_initializer(dgl.init.zero_initializer)
    g2.from_scipy_sparse_matrix(a)
    # the sparse matrix carries no node features, so copy 'h' explicitly
    g2.ndata['h'] = g.ndata['h']
    # on dgl graph
    g.send(message_func=message_func)
    g.recv([0, 1, 3, 5], reduce_func=reduce_func,
           apply_node_func=apply_node_func)
    g.recv([0, 2, 4, 8], reduce_func=reduce_func,
           apply_node_func=apply_node_func)
    # nx
    g1.send(message_func=message_func)
    g1.recv([0, 1, 3, 5], reduce_func=reduce_func,
            apply_node_func=apply_node_func)
    g1.recv([0, 2, 4, 8], reduce_func=reduce_func,
            apply_node_func=apply_node_func)
    # sparse matrix
    g2.send(message_func=message_func)
    g2.recv([0, 1, 3, 5], reduce_func=reduce_func,
            apply_node_func=apply_node_func)
    g2.recv([0, 2, 4, 8], reduce_func=reduce_func,
            apply_node_func=apply_node_func)
    assert U.allclose(g.ndata['h'], g1.ndata['h'])
    assert U.allclose(g.ndata['h'], g2.ndata['h'])
if __name__ == '__main__':
    # run every test of this module, in the listed order, when executed as a
    # script
    for _case in (
            test_multi_send,
            test_multi_recv,
            test_multi_recv_0deg,
            test_dynamic_addition,
            test_send_twice_different_shape,
            test_send_twice_different_msg,
            test_send_twice_different_field,
            test_recv_no_send,
            test_send_recv_after_conversion,
    ):
        _case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment