Commit 7d04c8c9 authored by Minjie Wang's avatar Minjie Wang
Browse files

remove nonbatchable mode

parent 3a3e5d48
......@@ -50,11 +50,11 @@ class DGLGraph(object):
self._msg_frame = FrameRef()
self.reset_messages()
# registered functions
self._message_func = (None, None)
self._reduce_func = (None, None)
self._edge_func = (None, None)
self._apply_node_func = (None, None)
self._apply_edge_func = (None, None)
self._message_func = None
self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
......@@ -710,77 +710,57 @@ class DGLGraph(object):
else:
return self._edge_frame.select_rows(eid)
def register_edge_func(self,
edge_func,
batchable=False):
def register_edge_func(self, edge_func):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._edge_func = (edge_func, batchable)
self._edge_func = edge_func
def register_message_func(self,
message_func,
batchable=False):
def register_message_func(self, message_func):
"""Register global message function.
Parameters
----------
message_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._message_func = (message_func, batchable)
self._message_func = message_func
def register_reduce_func(self,
reduce_func,
batchable=False):
def register_reduce_func(self, reduce_func):
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
"""
self._reduce_func = (reduce_func, batchable)
self._reduce_func = reduce_func
def register_apply_node_func(self,
apply_node_func,
batchable=False):
def register_apply_node_func(self, apply_node_func):
"""Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_node_func = (apply_node_func, batchable)
self._apply_node_func = apply_node_func
def register_apply_edge_func(self,
apply_edge_func,
batchable=False):
def register_apply_edge_func(self, apply_edge_func):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_edge_func = (apply_edge_func, batchable)
self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default", batchable=False):
def apply_nodes(self, v, apply_node_func="default"):
"""Apply the function on node representations.
Parameters
......@@ -789,27 +769,16 @@ class DGLGraph(object):
The node id(s).
apply_node_func : callable
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
return
if batchable:
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
else:
raise RuntimeError('Disabled')
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False):
def apply_edges(self, u, v, apply_edge_func="default"):
"""Apply the function on edge representations.
Parameters
......@@ -820,27 +789,16 @@ class DGLGraph(object):
The dst node id(s).
apply_edge_func : callable
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func
apply_edge_func = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if batchable:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
if is_all(u) == is_all(v):
u, v = zip(*self.edges)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def send(self, u, v, message_func="default", batchable=False):
def send(self, u, v, message_func="default"):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
......@@ -861,30 +819,13 @@ class DGLGraph(object):
The destination node(s).
message_func : callable
The message function.
batchable : bool
Whether the function allows batched computation.
"""
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if batchable:
self._batch_send(u, v, message_func)
else:
self._nonbatch_send(u, v, message_func)
def _nonbatch_send(self, u, v, message_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
......@@ -908,7 +849,7 @@ class DGLGraph(object):
else:
self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False):
def update_edge(self, u=ALL, v=ALL, edge_func="default"):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
......@@ -927,29 +868,11 @@ class DGLGraph(object):
The destination node(s).
edge_func : callable
The update function.
batchable : bool
Whether the function allows batched computation.
"""
if edge_func == "default":
edge_func, batchable = self._edge_func
edge_func = self._edge_func
assert edge_func is not None
if batchable:
self._batch_update_edge(u, v, edge_func)
else:
self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
......@@ -975,8 +898,7 @@ class DGLGraph(object):
def recv(self,
u,
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
......@@ -1006,34 +928,15 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
"""
if reduce_func == "default":
reduce_func, batchable = self._reduce_func
reduce_func = self._reduce_func
assert reduce_func is not None
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
if batchable:
self._batch_recv(u, reduce_func)
else:
self._nonbatch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func, batchable)
def _nonbatch_recv(self, u, reduce_func):
raise RuntimeError('Disabled')
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) != 0:
new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
_set_repr(self.nodes[uu], new_repr)
self.apply_nodes(u, apply_node_func)
def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0:
......@@ -1105,8 +1008,7 @@ class DGLGraph(object):
u, v,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Trigger the message function on u->v and update v.
Parameters
......@@ -1121,8 +1023,6 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u = utils.toindex(u)
v = utils.toindex(v)
......@@ -1132,34 +1032,28 @@ class DGLGraph(object):
return
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor:
executor.run()
else:
self.send(u, v, message_func, batchable=batchable)
self.recv(unique_v, reduce_func, None, batchable=batchable)
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
self.send(u, v, message_func)
self.recv(unique_v, reduce_func, None)
self.apply_nodes(unique_v, apply_node_func)
def pull(self,
v,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Pull messages from the node's predecessors and then update it.
Parameters
......@@ -1172,24 +1066,20 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
v = utils.toindex(v)
if len(v) == 0:
return
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func,
apply_node_func=None, batchable=batchable)
self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
self.apply_nodes(unique_v, apply_node_func)
def push(self,
u,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Send message from the node to its successors and update them.
Parameters
......@@ -1202,21 +1092,18 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u = utils.toindex(u)
if len(u) == 0:
return
uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable)
reduce_func, apply_node_func)
def update_all(self,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Send messages through all the edges and update all nodes.
Parameters
......@@ -1227,35 +1114,28 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor:
executor.run()
else:
self.send(ALL, ALL, message_func, batchable=batchable)
self.recv(ALL, reduce_func, None, batchable=batchable)
self.apply_nodes(ALL, apply_node_func, batchable=batchable)
self.send(ALL, ALL, message_func)
self.recv(ALL, reduce_func, None)
self.apply_nodes(ALL, apply_node_func)
def propagate(self,
iterator='bfs',
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False,
**kwargs):
"""Propagate messages and update nodes using iterator.
......@@ -1274,8 +1154,6 @@ class DGLGraph(object):
The reduce function.
apply_node_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
......@@ -1288,7 +1166,7 @@ class DGLGraph(object):
# NOTE: the iteration can return multiple edges at each step.
for u, v in iterator:
self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func, batchable)
message_func, reduce_func, apply_node_func)
def subgraph(self, nodes):
"""Generate the subgraph among the given nodes.
......@@ -1350,15 +1228,3 @@ class DGLGraph(object):
[sg._parent_eid for sg in to_merge],
self._edge_frame.num_rows,
reduce_func)
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
else:
return attr_dict
def _set_repr(attr_dict, attr):
if utils.is_dict_like(attr):
attr_dict.update(attr)
else:
attr_dict[__REPR__] = attr
......@@ -133,7 +133,7 @@ def test_batch_send():
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
g.register_message_func(_fmsg, batchable=True)
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
......@@ -150,9 +150,9 @@ def test_batch_send():
def test_batch_recv():
# basic recv test
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
......@@ -163,9 +163,9 @@ def test_batch_recv():
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
# send_and_recv
reduce_msg_shapes.clear()
......@@ -209,7 +209,7 @@ def test_reduce_0deg():
return node + msgs.sum(1)
old_repr = th.randn(5, 5)
g.set_n_repr(old_repr)
g.update_all(_message, _reduce, batchable=True)
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1:], old_repr[1:])
......@@ -227,17 +227,17 @@ def test_pull_0deg():
old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.pull(0, _message, _reduce, batchable=True)
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce, batchable=True)
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.pull([0, 1], _message, _reduce, batchable=True)
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
......
......@@ -129,7 +129,7 @@ def test_batch_send():
def _fmsg(hu, edge):
assert hu.shape == (5, D)
return hu
g.register_message_func(_fmsg, batchable=True)
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
......@@ -145,8 +145,8 @@ def test_batch_send():
def test_batch_recv():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
......@@ -157,8 +157,8 @@ def test_batch_recv():
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
# send_and_recv
reduce_msg_shapes.clear()
......
......@@ -51,32 +51,32 @@ def reducer_none(node, msgs):
def test_copy_src():
# copy_src with both fields
g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_src(src='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_src(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_src())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
......@@ -84,32 +84,32 @@ def test_copy_src():
def test_copy_edge():
# copy_edge with both fields
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_edge(edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
......@@ -117,36 +117,36 @@ def test_copy_edge():
def test_src_mul_edge():
# src_mul_edge with all fields
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.src_mul_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_none, batchable=True)
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_none)
g.update_all()
assert th.allclose(g.get_n_repr(),
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
......
......@@ -71,8 +71,8 @@ def test_batch_sendrecv():
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
e1 = [(3, 1), (4, 1)]
e2 = [(2, 4), (0, 4)]
......@@ -94,8 +94,8 @@ def test_batch_propagate():
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
# get leaves.
order = []
......
......@@ -38,23 +38,23 @@ def test_update_all():
g = generate_graph()
# update all
v1 = g.get_n_repr()[fld]
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func, batchable=True)
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func, batchable=True)
g.update_all(message_func, reduce_func, apply_func)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# update all with edge weights
v1 = g.get_n_repr()[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func, batchable=True)
g.update_all(message_func_edge, reduce_func, apply_func)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
......@@ -85,25 +85,25 @@ def test_send_and_recv():
# send and recv
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.copy_src(src=fld),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func,
reduce_func, apply_func, batchable=True)
reduce_func, apply_func)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# send and recv with edge weights
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func_edge,
reduce_func, apply_func, batchable=True)
reduce_func, apply_func)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
......@@ -127,18 +127,18 @@ def test_update_all_multi_fn():
# update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
# run builtin with single message and reduce
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True)
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None)
v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
......@@ -147,7 +147,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces
g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
......@@ -155,7 +155,7 @@ def test_update_all_multi_fn():
assert th.allclose(v1, v3)
# run UDF with single message and reduce
g.update_all(message_func_edge, reduce_func, None, batchable=True)
g.update_all(message_func_edge, reduce_func, None)
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
......@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn():
g.send_and_recv(u, v,
[fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
# run builtin with single message and reduce
g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
......@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn():
g.send_and_recv(u, v,
[fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
......@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn():
# run UDF with single message and reduce
g.send_and_recv(u, v, message_func_edge,
reduce_func, None, batchable=True)
reduce_func, None)
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
......
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, e_uv):
    """Message function: sum of the source node repr and the edge repr."""
    msg = hu + e_uv
    return msg
def reduce_func(h, msgs):
    """Reduce function: accumulate all incoming messages onto the node repr."""
    return sum(msgs, h)
def generate_graph():
    """Build a 10-node test graph with anonymous reprs.

    Node i carries repr i+1.  Node 0 feeds nodes 1..8 and nodes 1..8
    feed node 9; those edges carry repr 1.  A single back edge 9->0
    carries no repr.
    """
    g = DGLGraph()
    for i in range(10):
        g.add_node(i, __REPR__=i+1) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i, __REPR__=1)
        g.add_edge(i, 9, __REPR__=1)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    return g
def check(g, h):
    """Assert that every node's anonymous repr matches the expected list."""
    actual = [str(g.nodes[i][__REPR__]) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def test_sendrecv():
    """Basic single-edge send/recv with anonymous reprs."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    # 0 -> 1: message = h[0] + e = 1 + 1 = 2; node 1 becomes 2 + 2 = 4
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    # two pending messages into 9 (6+1 and 7+1), reduced in one recv: 10 + 15 = 25
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
def message_func_hybrid(src, edge):
    """Message from a dict-typed node repr plus an anonymous edge repr."""
    hu = src[__REPR__]
    return hu + edge
def reduce_func_hybrid(node, msgs):
    """Reduce: the node's anonymous repr plus the sum of all messages."""
    return sum(msgs, node[__REPR__])
def test_hybridrepr():
    """Nodes carry both the anonymous repr and an extra named 'id' field."""
    g = generate_graph()
    for i in range(10):
        g.nodes[i]['id'] = -i
    g.register_message_func(message_func_hybrid)
    g.register_reduce_func(reduce_func_hybrid)
    # same arithmetic as test_sendrecv; the 'id' field must not interfere
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
if __name__ == '__main__':
    # Run this module's tests directly when executed as a script.
    test_sendrecv()
    test_hybridrepr()
from dgl.graph import DGLGraph
def message_func(src, edge):
    """Message: forward the source node's 'h' representation unchanged."""
    h = src['h']
    return h
def reduce_func(node, msgs):
    """Reduce incoming messages by summation into the 'm' field."""
    total = sum(msgs)
    return {'m': total}
def apply_func(node):
    """Apply: fold the aggregated message 'm' into the node state 'h'."""
    updated = node['h'] + node['m']
    return {'h': updated}
def message_dict_func(src, edge):
    """Message: wrap the source's 'h' field in a dict under key 'm'."""
    payload = src['h']
    return {'m': payload}
def reduce_dict_func(node, msgs):
    """Reduce: sum the 'm' field across all incoming messages."""
    total = sum(msg['m'] for msg in msgs)
    return {'m': total}
def apply_dict_func(node):
    """Apply: add the reduced message 'm' onto the node state 'h'."""
    new_h = node['h'] + node['m']
    return {'h': new_h}
def generate_graph():
    """Build a 10-node test graph with node field 'h' = i+1.

    Node 0 feeds nodes 1..8, nodes 1..8 feed node 9, plus one back
    edge 9->0.  Edges carry no attributes.
    """
    g = DGLGraph()
    for i in range(10):
        g.add_node(i, h=i+1) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    return g
def check(g, h):
    """Assert that every node's 'h' field matches the expected sequence."""
    actual = [str(g.nodes[i]['h']) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def register1(g):
    # Register the plain-value UDF flavor (message/reduce/apply on scalars).
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_func)
def register2(g):
    # Register the dict-valued UDF flavor; must behave the same as register1.
    g.register_message_func(message_dict_func)
    g.register_reduce_func(reduce_dict_func)
    g.register_apply_node_func(apply_dict_func)
def _test_sendrecv(g):
    """Exercise single-edge send/recv on a freshly generated graph."""
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # 0 -> 1: message is h[0]=1; apply gives node 1 = 2 + 1 = 3
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # two messages into 9 (h[5]=6, h[6]=7) reduced together: 10 + 13 = 23
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def _test_multi_sendrecv(g):
    """Exercise one-many, many-one and many-many send/recv patterns."""
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # one-many
    g.send(0, [1, 2, 3])
    g.recv([1, 2, 3])
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
    # many-one
    g.send([6, 7, 8], 9)
    g.recv(9)
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
    # many-many
    g.send([0, 0, 4, 5], [4, 5, 9, 9])
    g.recv([4, 5, 9])
    check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def _test_update_routines(g):
    """Exercise the high-level routines: send_and_recv, pull, push, update_all."""
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.send_and_recv(0, 1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    g.pull(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
    g.push(0)
    check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
    g.update_all()
    check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def test_sendrecv():
    """Run the send/recv scenario under both UDF flavors."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_sendrecv(g)
def test_multi_sendrecv():
    """Run the multi-edge send/recv scenario under both UDF flavors."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_multi_sendrecv(g)
def test_update_routines():
    """Run the high-level update-routine scenario under both UDF flavors."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_update_routines(g)
if __name__ == '__main__':
    # Run this module's tests directly when executed as a script.
    test_sendrecv()
    test_multi_sendrecv()
    test_update_routines()
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, e_uv):
    """Identity message: forward the source repr, ignore the edge."""
    msg = hu
    return msg
def message_not_called(hu, e_uv):
    # Sentinel UDF: tests register this where no message should ever fire.
    assert False
    return hu
def reduce_not_called(h, msgs):
    # Sentinel UDF: tests register this where no reduce should ever fire.
    assert False
    return 0
def reduce_func(h, msgs):
    """Reduce: add every incoming message onto the node repr."""
    return sum(msgs, h)
def check(g, h):
    """Assert that every node's anonymous repr matches the expected list."""
    actual = [str(g.nodes[i][__REPR__]) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def generate_graph():
    """Build a 10-node graph with anonymous repr i+1 on node i.

    Node 0 feeds nodes 1..8 and nodes 1..8 feed node 9; unlike the
    other test fixtures there is no back edge from 9 to 0.
    """
    g = DGLGraph()
    for i in range(10):
        g.add_node(i, __REPR__=i+1) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    return g
def test_no_msg_recv():
    """recv with an empty mailbox must skip message/reduce; apply still runs."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_apply_node_func(lambda h : h + 1)
    # no send() was issued, so each recv only bumps the repr via the apply func
    for nid in range(10):
        g.recv(nid)
    check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
def test_double_recv():
    """A second recv with no pending messages must not trigger the reducer."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    for src in (1, 2):
        g.send(src, 9)
    g.recv(9)
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
    # messages were consumed above; this recv must not invoke the reducer
    g.register_reduce_func(reduce_not_called)
    g.recv(9)
def test_pull_0deg():
    """Pulling on a node with no incoming edges runs only the apply func."""
    g = DGLGraph()
    g.add_node(0, h=2)
    g.add_node(1, h=1)
    g.add_edge(0, 1)
    def _msg(src, edge):
        # node 0 has in-degree 0, so no message should ever be generated
        assert False
        return src
    def _red(node, msgs):
        assert False
        return node
    def _apply(node):
        return {'h': node['h'] * 2}
    g.pull(0, _msg, _red, _apply)
    assert g.nodes[0]['h'] == 4
if __name__ == '__main__':
    # Entry point when this test module is executed directly: run all scenarios.
    test_no_msg_recv()
    test_double_recv()
    test_pull_0deg()
import dgl
import dgl.function as fn
from dgl.graph import __REPR__
def generate_graph():
    """Build a 10-node graph with 'h' features; 0 fans out to 1..8 which feed 9,
    plus a back edge 9->0."""
    g = dgl.DGLGraph()
    # 10 nodes, feature h = i+1 on node i
    for nid in range(10):
        g.add_node(nid, h=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid, h=1)
        g.add_edge(mid, 9, h=mid + 1)
    # back flow from the sink to the source
    g.add_edge(9, 0, h=10)
    return g
def check(g, h, fld):
    """Assert that field `fld` of nodes 0..9 equals h (compared as strings)."""
    actual = [str(g.nodes[i][fld]) for i in range(10)]
    expected = [str(v) for v in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def generate_graph1():
    """Same topology as generate_graph() but with the anonymous __REPR__ field."""
    g = dgl.DGLGraph()
    # 10 nodes, anonymous repr i+1 on node i
    for nid in range(10):
        g.add_node(nid, __REPR__=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid, __REPR__=1)
        g.add_edge(mid, 9, __REPR__=mid + 1)
    # back flow from the sink to the source
    g.add_edge(9, 0, __REPR__=10)
    return g
def test_copy_src():
    """Test the fn.copy_src builtin message function with the fn.sum reducer.

    Covers every combination of explicit / implicit (anonymous __REPR__)
    source and output fields.

    Fix: the `batchable` keyword was removed from the register_* API
    (see graph.py in this commit), so it must no longer be passed here.
    """
    # copy_src with both fields
    g = generate_graph()
    g.register_message_func(fn.copy_src(src='h', out='m'))
    g.register_reduce_func(fn.sum(msgs='m', out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
    # copy_src with only src field; the out field should use anonymous repr
    g = generate_graph()
    g.register_message_func(fn.copy_src(src='h'))
    g.register_reduce_func(fn.sum(out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
    # copy_src with no src field; should use anonymous repr
    g = generate_graph1()
    g.register_message_func(fn.copy_src(out='m'))
    g.register_reduce_func(fn.sum(msgs='m', out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
    # copy src with no fields;
    g = generate_graph1()
    g.register_message_func(fn.copy_src())
    g.register_reduce_func(fn.sum(out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
def test_copy_edge():
    """Test the fn.copy_edge builtin message function with the fn.sum reducer.

    Covers every combination of explicit / implicit (anonymous __REPR__)
    edge and output fields.

    Fix: the `batchable` keyword was removed from the register_* API
    (see graph.py in this commit), so it must no longer be passed here.
    """
    # copy_edge with both fields
    g = generate_graph()
    g.register_message_func(fn.copy_edge(edge='h', out='m'))
    g.register_reduce_func(fn.sum(msgs='m', out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
    # copy_edge with only edge field; the out field should use anonymous repr
    g = generate_graph()
    g.register_message_func(fn.copy_edge(edge='h'))
    g.register_reduce_func(fn.sum(out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
    # copy_edge with no edge field; should use anonymous repr
    g = generate_graph1()
    g.register_message_func(fn.copy_edge(out='m'))
    g.register_reduce_func(fn.sum(msgs='m', out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
    # copy edge with no fields;
    g = generate_graph1()
    g.register_message_func(fn.copy_edge())
    g.register_reduce_func(fn.sum(out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
def test_src_mul_edge():
    """Test the fn.src_mul_edge builtin message function with the fn.sum reducer.

    Covers explicit fields, implicit (anonymous __REPR__) fields, and a
    fully-anonymous reduce output.

    Fix: the `batchable` keyword was removed from the register_* API
    (see graph.py in this commit), so it must no longer be passed here.
    """
    # src_mul_edge with all fields
    g = generate_graph()
    g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
    g.register_reduce_func(fn.sum(msgs='m', out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
    g.update_all()
    check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
    # src_mul_edge with src/edge fields; out uses anonymous repr
    g = generate_graph()
    g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
    g.register_reduce_func(fn.sum(out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
    g.update_all()
    check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
    # src_mul_edge with only out field; src/edge use anonymous repr
    g = generate_graph1()
    g.register_message_func(fn.src_mul_edge(out='m'))
    g.register_reduce_func(fn.sum(msgs='m', out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
    # src_mul_edge with no fields; message uses anonymous repr throughout
    g = generate_graph1()
    g.register_message_func(fn.src_mul_edge())
    g.register_reduce_func(fn.sum(out='h'))
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
    # fully anonymous: reduce output also lands in __REPR__
    g = generate_graph1()
    g.register_message_func(fn.src_mul_edge())
    g.register_reduce_func(fn.sum())
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
    g.update_all()
    check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], __REPR__)
if __name__ == '__main__':
    # Entry point when this test module is executed directly: run all scenarios.
    test_copy_src()
    test_copy_edge()
    test_src_mul_edge()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment