Commit eb1acecd authored by Lingfan Yu's avatar Lingfan Yu Committed by Minjie Wang
Browse files

[BugFix] Handle case when calling send with builtin message function (#449)

* fix send with builtin bug

* test cases

* remove todo comment
parent d1d580ec
...@@ -42,8 +42,10 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -42,8 +42,10 @@ def schedule_send(graph, u, v, eid, message_func):
message_func: callable or list of callable message_func: callable or list of callable
The message function The message function
""" """
# TODO(minjie): support builtin message func
message_func = _standardize_func_usage(message_func, 'message') message_func = _standardize_func_usage(message_func, 'message')
mfunc_is_list = utils.is_iterable(message_func)
if mfunc_is_list:
message_func = BundledFunction(message_func)
# vars # vars
var_nf = var.FEAT_DICT(graph._node_frame) var_nf = var.FEAT_DICT(graph._node_frame)
var_ef = var.FEAT_DICT(graph._edge_frame) var_ef = var.FEAT_DICT(graph._edge_frame)
......
...@@ -19,15 +19,21 @@ def generate_graph(): ...@@ -19,15 +19,21 @@ def generate_graph():
return g return g
def reducer_both(nodes): def reducer_both(nodes):
return {'h' : F.sum(nodes.mailbox['m'], 1)} return {'out' : F.sum(nodes.mailbox['m'], 1)}
def test_copy_src(): def test_copy_src():
# copy_src with both fields # copy_src with both fields
g = generate_graph() g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m')) g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both) g.register_reduce_func(reducer_both)
# test with update_all
g.update_all() g.update_all()
assert F.allclose(g.ndata['h'], assert F.allclose(g.ndata.pop('out'),
F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# test with send and then recv
g.send()
g.recv()
assert F.allclose(g.ndata.pop('out'),
F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge(): def test_copy_edge():
...@@ -35,8 +41,14 @@ def test_copy_edge(): ...@@ -35,8 +41,14 @@ def test_copy_edge():
g = generate_graph() g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m')) g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both) g.register_reduce_func(reducer_both)
# test with update_all
g.update_all() g.update_all()
assert F.allclose(g.ndata['h'], assert F.allclose(g.ndata.pop('out'),
F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# test with send and then recv
g.send()
g.recv()
assert F.allclose(g.ndata.pop('out'),
F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) F.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge(): def test_src_mul_edge():
...@@ -44,8 +56,14 @@ def test_src_mul_edge(): ...@@ -44,8 +56,14 @@ def test_src_mul_edge():
g = generate_graph() g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m')) g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both) g.register_reduce_func(reducer_both)
# test with update_all
g.update_all() g.update_all()
assert F.allclose(g.ndata['h'], assert F.allclose(g.ndata.pop('out'),
F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
# test with send and then recv
g.send()
g.recv()
assert F.allclose(g.ndata.pop('out'),
F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) F.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment