"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c72e34308509ea7ac09f52b0440be469eb3f650c"
Commit 7d04c8c9 authored by Minjie Wang's avatar Minjie Wang
Browse files

remove nonbatchable mode

parent 3a3e5d48
...@@ -50,11 +50,11 @@ class DGLGraph(object): ...@@ -50,11 +50,11 @@ class DGLGraph(object):
self._msg_frame = FrameRef() self._msg_frame = FrameRef()
self.reset_messages() self.reset_messages()
# registered functions # registered functions
self._message_func = (None, None) self._message_func = None
self._reduce_func = (None, None) self._reduce_func = None
self._edge_func = (None, None) self._edge_func = None
self._apply_node_func = (None, None) self._apply_node_func = None
self._apply_edge_func = (None, None) self._apply_edge_func = None
def add_nodes(self, num, reprs=None): def add_nodes(self, num, reprs=None):
"""Add nodes. """Add nodes.
...@@ -710,77 +710,57 @@ class DGLGraph(object): ...@@ -710,77 +710,57 @@ class DGLGraph(object):
else: else:
return self._edge_frame.select_rows(eid) return self._edge_frame.select_rows(eid)
def register_edge_func(self, def register_edge_func(self, edge_func):
edge_func,
batchable=False):
"""Register global edge update function. """Register global edge update function.
Parameters Parameters
---------- ----------
edge_func : callable edge_func : callable
Message function on the edge. Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
""" """
self._edge_func = (edge_func, batchable) self._edge_func = edge_func
def register_message_func(self, def register_message_func(self, message_func):
message_func,
batchable=False):
"""Register global message function. """Register global message function.
Parameters Parameters
---------- ----------
message_func : callable message_func : callable
Message function on the edge. Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
""" """
self._message_func = (message_func, batchable) self._message_func = message_func
def register_reduce_func(self, def register_reduce_func(self, reduce_func):
reduce_func,
batchable=False):
"""Register global message reduce function. """Register global message reduce function.
Parameters Parameters
---------- ----------
reduce_func : str or callable reduce_func : str or callable
Reduce function on incoming edges. Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
""" """
self._reduce_func = (reduce_func, batchable) self._reduce_func = reduce_func
def register_apply_node_func(self, def register_apply_node_func(self, apply_node_func):
apply_node_func,
batchable=False):
"""Register global node apply function. """Register global node apply function.
Parameters Parameters
---------- ----------
apply_node_func : callable apply_node_func : callable
Apply function on the node. Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
""" """
self._apply_node_func = (apply_node_func, batchable) self._apply_node_func = apply_node_func
def register_apply_edge_func(self, def register_apply_edge_func(self, apply_edge_func):
apply_edge_func,
batchable=False):
"""Register global edge apply function. """Register global edge apply function.
Parameters Parameters
---------- ----------
apply_edge_func : callable apply_edge_func : callable
Apply function on the edge. Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
""" """
self._apply_edge_func = (apply_edge_func, batchable) self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default", batchable=False): def apply_nodes(self, v, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
Parameters Parameters
...@@ -789,27 +769,16 @@ class DGLGraph(object): ...@@ -789,27 +769,16 @@ class DGLGraph(object):
The node id(s). The node id(s).
apply_node_func : callable apply_node_func : callable
The apply node function. The apply node function.
batchable : bool
Whether the provided function allows batch computing.
""" """
if apply_node_func == "default": if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func apply_node_func = self._apply_node_func
if not apply_node_func: if not apply_node_func:
# Skip none function call. # Skip none function call.
return return
if batchable: new_repr = apply_node_func(self.get_n_repr(v))
new_repr = apply_node_func(self.get_n_repr(v)) self.set_n_repr(new_repr, v)
self.set_n_repr(new_repr, v)
else:
raise RuntimeError('Disabled')
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False): def apply_edges(self, u, v, apply_edge_func="default"):
"""Apply the function on edge representations. """Apply the function on edge representations.
Parameters Parameters
...@@ -820,27 +789,16 @@ class DGLGraph(object): ...@@ -820,27 +789,16 @@ class DGLGraph(object):
The dst node id(s). The dst node id(s).
apply_edge_func : callable apply_edge_func : callable
The apply edge function. The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
""" """
if apply_edge_func == "default": if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func apply_edge_func = self._apply_edge_func
if not apply_edge_func: if not apply_edge_func:
# Skip none function call. # Skip none function call.
return return
if batchable: new_repr = apply_edge_func(self.get_e_repr(u, v))
new_repr = apply_edge_func(self.get_e_repr(u, v)) self.set_e_repr(new_repr, u, v)
self.set_e_repr(new_repr, u, v)
else:
if is_all(u) == is_all(v):
u, v = zip(*self.edges)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def send(self, u, v, message_func="default", batchable=False): def send(self, u, v, message_func="default"):
"""Trigger the message function on edge u->v """Trigger the message function on edge u->v
The message function should be compatible with following signature: The message function should be compatible with following signature:
...@@ -861,30 +819,13 @@ class DGLGraph(object): ...@@ -861,30 +819,13 @@ class DGLGraph(object):
The destination node(s). The destination node(s).
message_func : callable message_func : callable
The message function. The message function.
batchable : bool
Whether the function allows batched computation.
""" """
if message_func == "default": if message_func == "default":
message_func, batchable = self._message_func message_func = self._message_func
assert message_func is not None assert message_func is not None
if isinstance(message_func, (tuple, list)): if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func) message_func = BundledMessageFunction(message_func)
if batchable: self._batch_send(u, v, message_func)
self._batch_send(u, v, message_func)
else:
self._nonbatch_send(u, v, message_func)
def _nonbatch_send(self, u, v, message_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret
def _batch_send(self, u, v, message_func): def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
...@@ -908,7 +849,7 @@ class DGLGraph(object): ...@@ -908,7 +849,7 @@ class DGLGraph(object):
else: else:
self._msg_frame.append({__MSG__ : msgs}) self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False): def update_edge(self, u=ALL, v=ALL, edge_func="default"):
"""Update representation on edge u->v """Update representation on edge u->v
The edge function should be compatible with following signature: The edge function should be compatible with following signature:
...@@ -927,29 +868,11 @@ class DGLGraph(object): ...@@ -927,29 +868,11 @@ class DGLGraph(object):
The destination node(s). The destination node(s).
edge_func : callable edge_func : callable
The update function. The update function.
batchable : bool
Whether the function allows batched computation.
""" """
if edge_func == "default": if edge_func == "default":
edge_func, batchable = self._edge_func edge_func = self._edge_func
assert edge_func is not None assert edge_func is not None
if batchable: self._batch_update_edge(u, v, edge_func)
self._batch_update_edge(u, v, edge_func)
else:
self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def _batch_update_edge(self, u, v, edge_func): def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
...@@ -975,8 +898,7 @@ class DGLGraph(object): ...@@ -975,8 +898,7 @@ class DGLGraph(object):
def recv(self, def recv(self,
u, u,
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Receive and reduce in-coming messages and update representation on node u. """Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors It computes the new node state using the messages sent from the predecessors
...@@ -1006,34 +928,15 @@ class DGLGraph(object): ...@@ -1006,34 +928,15 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
""" """
if reduce_func == "default": if reduce_func == "default":
reduce_func, batchable = self._reduce_func reduce_func = self._reduce_func
assert reduce_func is not None assert reduce_func is not None
if isinstance(reduce_func, (list, tuple)): if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func) reduce_func = BundledReduceFunction(reduce_func)
if batchable: self._batch_recv(u, reduce_func)
self._batch_recv(u, reduce_func)
else:
self._nonbatch_recv(u, reduce_func)
# optional apply nodes # optional apply nodes
self.apply_nodes(u, apply_node_func, batchable) self.apply_nodes(u, apply_node_func)
def _nonbatch_recv(self, u, reduce_func):
raise RuntimeError('Disabled')
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) != 0:
new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
_set_repr(self.nodes[uu], new_repr)
def _batch_recv(self, v, reduce_func): def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0: if self._msg_frame.num_rows == 0:
...@@ -1105,8 +1008,7 @@ class DGLGraph(object): ...@@ -1105,8 +1008,7 @@ class DGLGraph(object):
u, v, u, v,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Trigger the message function on u->v and update v. """Trigger the message function on u->v and update v.
Parameters Parameters
...@@ -1121,8 +1023,6 @@ class DGLGraph(object): ...@@ -1121,8 +1023,6 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
...@@ -1132,34 +1032,28 @@ class DGLGraph(object): ...@@ -1132,34 +1032,28 @@ class DGLGraph(object):
return return
unique_v = utils.toindex(F.unique(v.tousertensor())) unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default": if message_func == "default":
message_func, batchable = self._message_func message_func = self._message_func
if reduce_func == "default": if reduce_func == "default":
reduce_func, _ = self._reduce_func reduce_func = self._reduce_func
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
if batchable: executor = scheduler.get_executor(
executor = scheduler.get_executor( 'send_and_recv', self, src=u, dst=v,
'send_and_recv', self, src=u, dst=v, message_func=message_func, reduce_func=reduce_func)
message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor: if executor:
executor.run() executor.run()
else: else:
self.send(u, v, message_func, batchable=batchable) self.send(u, v, message_func)
self.recv(unique_v, reduce_func, None, batchable=batchable) self.recv(unique_v, reduce_func, None)
self.apply_nodes(unique_v, apply_node_func, batchable=batchable) self.apply_nodes(unique_v, apply_node_func)
def pull(self, def pull(self,
v, v,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Pull messages from the node's predecessors and then update it. """Pull messages from the node's predecessors and then update it.
Parameters Parameters
...@@ -1172,24 +1066,20 @@ class DGLGraph(object): ...@@ -1172,24 +1066,20 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
v = utils.toindex(v) v = utils.toindex(v)
if len(v) == 0: if len(v) == 0:
return return
uu, vv, _ = self._graph.in_edges(v) uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
apply_node_func=None, batchable=batchable)
unique_v = F.unique(v.tousertensor()) unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable) self.apply_nodes(unique_v, apply_node_func)
def push(self, def push(self,
u, u,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Send message from the node to its successors and update them. """Send message from the node to its successors and update them.
Parameters Parameters
...@@ -1202,21 +1092,18 @@ class DGLGraph(object): ...@@ -1202,21 +1092,18 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable apply_node_func : callable
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
u = utils.toindex(u) u = utils.toindex(u)
if len(u) == 0: if len(u) == 0:
return return
uu, vv, _ = self._graph.out_edges(u) uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func, self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable) reduce_func, apply_node_func)
def update_all(self, def update_all(self,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Send messages through all the edges and update all nodes. """Send messages through all the edges and update all nodes.
Parameters Parameters
...@@ -1227,35 +1114,28 @@ class DGLGraph(object): ...@@ -1227,35 +1114,28 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
if message_func == "default": if message_func == "default":
message_func, batchable = self._message_func message_func = self._message_func
if reduce_func == "default": if reduce_func == "default":
reduce_func, _ = self._reduce_func reduce_func = self._reduce_func
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
if batchable: executor = scheduler.get_executor(
executor = scheduler.get_executor( "update_all", self, message_func=message_func, reduce_func=reduce_func)
"update_all", self, message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor: if executor:
executor.run() executor.run()
else: else:
self.send(ALL, ALL, message_func, batchable=batchable) self.send(ALL, ALL, message_func)
self.recv(ALL, reduce_func, None, batchable=batchable) self.recv(ALL, reduce_func, None)
self.apply_nodes(ALL, apply_node_func, batchable=batchable) self.apply_nodes(ALL, apply_node_func)
def propagate(self, def propagate(self,
iterator='bfs', iterator='bfs',
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default",
batchable=False,
**kwargs): **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
...@@ -1274,8 +1154,6 @@ class DGLGraph(object): ...@@ -1274,8 +1154,6 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : str or callable apply_node_func : str or callable
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps. iterator : str or generator of steps.
The iterator of the graph. The iterator of the graph.
kwargs : keyword arguments, optional kwargs : keyword arguments, optional
...@@ -1288,7 +1166,7 @@ class DGLGraph(object): ...@@ -1288,7 +1166,7 @@ class DGLGraph(object):
# NOTE: the iteration can return multiple edges at each step. # NOTE: the iteration can return multiple edges at each step.
for u, v in iterator: for u, v in iterator:
self.send_and_recv(u, v, self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func, batchable) message_func, reduce_func, apply_node_func)
def subgraph(self, nodes): def subgraph(self, nodes):
"""Generate the subgraph among the given nodes. """Generate the subgraph among the given nodes.
...@@ -1350,15 +1228,3 @@ class DGLGraph(object): ...@@ -1350,15 +1228,3 @@ class DGLGraph(object):
[sg._parent_eid for sg in to_merge], [sg._parent_eid for sg in to_merge],
self._edge_frame.num_rows, self._edge_frame.num_rows,
reduce_func) reduce_func)
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
else:
return attr_dict
def _set_repr(attr_dict, attr):
if utils.is_dict_like(attr):
attr_dict.update(attr)
else:
attr_dict[__REPR__] = attr
...@@ -133,7 +133,7 @@ def test_batch_send(): ...@@ -133,7 +133,7 @@ def test_batch_send():
def _fmsg(src, edge): def _fmsg(src, edge):
assert src['h'].shape == (5, D) assert src['h'].shape == (5, D)
return {'m' : src['h']} return {'m' : src['h']}
g.register_message_func(_fmsg, batchable=True) g.register_message_func(_fmsg)
# many-many send # many-many send
u = th.tensor([0, 0, 0, 0, 0]) u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5]) v = th.tensor([1, 2, 3, 4, 5])
...@@ -150,9 +150,9 @@ def test_batch_send(): ...@@ -150,9 +150,9 @@ def test_batch_send():
def test_batch_recv(): def test_batch_recv():
# basic recv test # basic recv test
g = generate_graph() g = generate_graph()
g.register_message_func(message_func, batchable=True) g.register_message_func(message_func)
g.register_reduce_func(reduce_func, batchable=True) g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func, batchable=True) g.register_apply_node_func(apply_node_func)
u = th.tensor([0, 0, 0, 4, 5, 6]) u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9]) v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
...@@ -163,9 +163,9 @@ def test_batch_recv(): ...@@ -163,9 +163,9 @@ def test_batch_recv():
def test_update_routines(): def test_update_routines():
g = generate_graph() g = generate_graph()
g.register_message_func(message_func, batchable=True) g.register_message_func(message_func)
g.register_reduce_func(reduce_func, batchable=True) g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func, batchable=True) g.register_apply_node_func(apply_node_func)
# send_and_recv # send_and_recv
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
...@@ -209,7 +209,7 @@ def test_reduce_0deg(): ...@@ -209,7 +209,7 @@ def test_reduce_0deg():
return node + msgs.sum(1) return node + msgs.sum(1)
old_repr = th.randn(5, 5) old_repr = th.randn(5, 5)
g.set_n_repr(old_repr) g.set_n_repr(old_repr)
g.update_all(_message, _reduce, batchable=True) g.update_all(_message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()
assert th.allclose(new_repr[1:], old_repr[1:]) assert th.allclose(new_repr[1:], old_repr[1:])
...@@ -227,17 +227,17 @@ def test_pull_0deg(): ...@@ -227,17 +227,17 @@ def test_pull_0deg():
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr(old_repr) g.set_n_repr(old_repr)
g.pull(0, _message, _reduce, batchable=True) g.pull(0, _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1]) assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce, batchable=True) g.pull(1, _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr(old_repr) g.set_n_repr(old_repr)
g.pull([0, 1], _message, _reduce, batchable=True) g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
......
...@@ -129,7 +129,7 @@ def test_batch_send(): ...@@ -129,7 +129,7 @@ def test_batch_send():
def _fmsg(hu, edge): def _fmsg(hu, edge):
assert hu.shape == (5, D) assert hu.shape == (5, D)
return hu return hu
g.register_message_func(_fmsg, batchable=True) g.register_message_func(_fmsg)
# many-many send # many-many send
u = th.tensor([0, 0, 0, 0, 0]) u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5]) v = th.tensor([1, 2, 3, 4, 5])
...@@ -145,8 +145,8 @@ def test_batch_send(): ...@@ -145,8 +145,8 @@ def test_batch_send():
def test_batch_recv(): def test_batch_recv():
g = generate_graph() g = generate_graph()
g.register_message_func(message_func, batchable=True) g.register_message_func(message_func)
g.register_reduce_func(reduce_func, batchable=True) g.register_reduce_func(reduce_func)
u = th.tensor([0, 0, 0, 4, 5, 6]) u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9]) v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
...@@ -157,8 +157,8 @@ def test_batch_recv(): ...@@ -157,8 +157,8 @@ def test_batch_recv():
def test_update_routines(): def test_update_routines():
g = generate_graph() g = generate_graph()
g.register_message_func(message_func, batchable=True) g.register_message_func(message_func)
g.register_reduce_func(reduce_func, batchable=True) g.register_reduce_func(reduce_func)
# send_and_recv # send_and_recv
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
......
...@@ -51,32 +51,32 @@ def reducer_none(node, msgs): ...@@ -51,32 +51,32 @@ def reducer_none(node, msgs):
def test_copy_src(): def test_copy_src():
# copy_src with both fields # copy_src with both fields
g = generate_graph() g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=True) g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both, batchable=True) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with only src field; the out field should use anonymous repr # copy_src with only src field; the out field should use anonymous repr
g = generate_graph() g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=True) g.register_message_func(fn.copy_src(src='h'))
g.register_reduce_func(reducer_out, batchable=True) g.register_reduce_func(reducer_out)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with no src field; should use anonymous repr # copy_src with no src field; should use anonymous repr
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=True) g.register_message_func(fn.copy_src(out='m'))
g.register_reduce_func(reducer_both, batchable=True) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy src with no fields; # copy src with no fields;
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=True) g.register_message_func(fn.copy_src())
g.register_reduce_func(reducer_out, batchable=True) g.register_reduce_func(reducer_out)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
...@@ -84,32 +84,32 @@ def test_copy_src(): ...@@ -84,32 +84,32 @@ def test_copy_src():
def test_copy_edge(): def test_copy_edge():
# copy_edge with both fields # copy_edge with both fields
g = generate_graph() g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=True) g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both, batchable=True) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with only edge field; the out field should use anonymous repr # copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph() g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=True) g.register_message_func(fn.copy_edge(edge='h'))
g.register_reduce_func(reducer_out, batchable=True) g.register_reduce_func(reducer_out)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with no edge field; should use anonymous repr # copy_edge with no edge field; should use anonymous repr
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=True) g.register_message_func(fn.copy_edge(out='m'))
g.register_reduce_func(reducer_both, batchable=True) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy edge with no fields; # copy edge with no fields;
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=True) g.register_message_func(fn.copy_edge())
g.register_reduce_func(reducer_out, batchable=True) g.register_reduce_func(reducer_out)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
...@@ -117,36 +117,36 @@ def test_copy_edge(): ...@@ -117,36 +117,36 @@ def test_copy_edge():
def test_src_mul_edge(): def test_src_mul_edge():
# src_mul_edge with all fields # src_mul_edge with all fields
g = generate_graph() g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=True) g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both, batchable=True) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph() g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=True) g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
g.register_reduce_func(reducer_out, batchable=True) g.register_reduce_func(reducer_out)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=True) g.register_message_func(fn.src_mul_edge(out='m'))
g.register_reduce_func(reducer_both, batchable=True) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True) g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_out, batchable=True) g.register_reduce_func(reducer_out)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1() g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True) g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_none, batchable=True) g.register_reduce_func(reducer_none)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr(), assert th.allclose(g.get_n_repr(),
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
......
...@@ -71,8 +71,8 @@ def test_batch_sendrecv(): ...@@ -71,8 +71,8 @@ def test_batch_sendrecv():
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True) bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True) bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
e1 = [(3, 1), (4, 1)] e1 = [(3, 1), (4, 1)]
e2 = [(2, 4), (0, 4)] e2 = [(2, 4), (0, 4)]
...@@ -94,8 +94,8 @@ def test_batch_propagate(): ...@@ -94,8 +94,8 @@ def test_batch_propagate():
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True) bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True) bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
# get leaves. # get leaves.
order = [] order = []
......
...@@ -38,23 +38,23 @@ def test_update_all(): ...@@ -38,23 +38,23 @@ def test_update_all():
g = generate_graph() g = generate_graph()
# update all # update all
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func, batchable=True) g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func, batchable=True) g.update_all(message_func, reduce_func, apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
# update all with edge weights # update all with edge weights
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1'), g.update_all(fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True) fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2'), g.update_all(fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True) fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func, batchable=True) g.update_all(message_func_edge, reduce_func, apply_func)
v4 = g.get_n_repr()[fld] v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
assert th.allclose(v3, v4) assert th.allclose(v3, v4)
...@@ -85,25 +85,25 @@ def test_send_and_recv(): ...@@ -85,25 +85,25 @@ def test_send_and_recv():
# send and recv # send and recv
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.copy_src(src=fld), g.send_and_recv(u, v, fn.copy_src(src=fld),
fn.sum(out=fld), apply_func, batchable=True) fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func, g.send_and_recv(u, v, message_func,
reduce_func, apply_func, batchable=True) reduce_func, apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
# send and recv with edge weights # send and recv with edge weights
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'), g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True) fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'), g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True) fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func_edge, g.send_and_recv(u, v, message_func_edge,
reduce_func, apply_func, batchable=True) reduce_func, apply_func)
v4 = g.get_n_repr()[fld] v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
assert th.allclose(v3, v4) assert th.allclose(v3, v4)
...@@ -127,18 +127,18 @@ def test_update_all_multi_fn(): ...@@ -127,18 +127,18 @@ def test_update_all_multi_fn():
# update all, mix of builtin and UDF # update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func], g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func], [fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# run builtin with single message and reduce # run builtin with single message and reduce
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True) g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr # 1 message, 2 reduces, using anonymous repr
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True) g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3'] v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
...@@ -147,7 +147,7 @@ def test_update_all_multi_fn(): ...@@ -147,7 +147,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces # update all with edge weights, 2 message, 3 reduces
g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')], g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')], [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3'] v3 = g.get_n_repr()['v3']
...@@ -155,7 +155,7 @@ def test_update_all_multi_fn(): ...@@ -155,7 +155,7 @@ def test_update_all_multi_fn():
assert th.allclose(v1, v3) assert th.allclose(v1, v3)
# run UDF with single message and reduce # run UDF with single message and reduce
g.update_all(message_func_edge, reduce_func, None, batchable=True) g.update_all(message_func_edge, reduce_func, None)
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
...@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn(): ...@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn():
g.send_and_recv(u, v, g.send_and_recv(u, v,
[fn.copy_src(src=fld, out='m1'), message_func], [fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func], [fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# run builtin with single message and reduce # run builtin with single message and reduce
g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'), g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
None, batchable=True) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr # 1 message, 2 reduces, using anonymous repr
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True) g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3'] v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
...@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn(): ...@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn():
g.send_and_recv(u, v, g.send_and_recv(u, v,
[fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')], [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')], [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3'] v3 = g.get_n_repr()['v3']
...@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn(): ...@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn():
# run UDF with single message and reduce # run UDF with single message and reduce
g.send_and_recv(u, v, message_func_edge, g.send_and_recv(u, v, message_func_edge,
reduce_func, None, batchable=True) reduce_func, None)
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
......
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, e_uv):
return hu + e_uv
def reduce_func(h, msgs):
return h + sum(msgs)
def generate_graph():
g = DGLGraph()
for i in range(10):
g.add_node(i, __REPR__=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i, __REPR__=1)
g.add_edge(i, 9, __REPR__=1)
# add a back flow from 9 to 0
g.add_edge(9, 0)
return g
def check(g, h):
nh = [str(g.nodes[i][__REPR__]) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def test_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.send(0, 1)
g.recv(1)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
g.send(5, 9)
g.send(6, 9)
g.recv(9)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
def message_func_hybrid(src, edge):
return src[__REPR__] + edge
def reduce_func_hybrid(node, msgs):
return node[__REPR__] + sum(msgs)
def test_hybridrepr():
g = generate_graph()
for i in range(10):
g.nodes[i]['id'] = -i
g.register_message_func(message_func_hybrid)
g.register_reduce_func(reduce_func_hybrid)
g.send(0, 1)
g.recv(1)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
g.send(5, 9)
g.send(6, 9)
g.recv(9)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
if __name__ == '__main__':
test_sendrecv()
test_hybridrepr()
from dgl.graph import DGLGraph
def message_func(src, edge):
return src['h']
def reduce_func(node, msgs):
return {'m' : sum(msgs)}
def apply_func(node):
return {'h' : node['h'] + node['m']}
def message_dict_func(src, edge):
return {'m' : src['h']}
def reduce_dict_func(node, msgs):
return {'m' : sum([msg['m'] for msg in msgs])}
def apply_dict_func(node):
return {'h' : node['h'] + node['m']}
def generate_graph():
g = DGLGraph()
for i in range(10):
g.add_node(i, h=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
return g
def check(g, h):
nh = [str(g.nodes[i]['h']) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def register1(g):
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_func)
def register2(g):
g.register_message_func(message_dict_func)
g.register_reduce_func(reduce_dict_func)
g.register_apply_node_func(apply_dict_func)
def _test_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.send(0, 1)
g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.send(5, 9)
g.send(6, 9)
g.recv(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def _test_multi_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# one-many
g.send(0, [1, 2, 3])
g.recv([1, 2, 3])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
# many-one
g.send([6, 7, 8], 9)
g.recv(9)
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
# many-many
g.send([0, 0, 4, 5], [4, 5, 9, 9])
g.recv([4, 5, 9])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def _test_update_routines(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.send_and_recv(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.pull(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
g.push(0)
check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def test_sendrecv():
g = generate_graph()
register1(g)
_test_sendrecv(g)
g = generate_graph()
register2(g)
_test_sendrecv(g)
def test_multi_sendrecv():
g = generate_graph()
register1(g)
_test_multi_sendrecv(g)
g = generate_graph()
register2(g)
_test_multi_sendrecv(g)
def test_update_routines():
g = generate_graph()
register1(g)
_test_update_routines(g)
g = generate_graph()
register2(g)
_test_update_routines(g)
if __name__ == '__main__':
test_sendrecv()
test_multi_sendrecv()
test_update_routines()
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, e_uv):
return hu
def message_not_called(hu, e_uv):
assert False
return hu
def reduce_not_called(h, msgs):
assert False
return 0
def reduce_func(h, msgs):
return h + sum(msgs)
def check(g, h):
nh = [str(g.nodes[i][__REPR__]) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def generate_graph():
g = DGLGraph()
for i in range(10):
g.add_node(i, __REPR__=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
return g
def test_no_msg_recv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_not_called)
g.register_reduce_func(reduce_not_called)
g.register_apply_node_func(lambda h : h + 1)
for i in range(10):
g.recv(i)
check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
def test_double_recv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.send(1, 9)
g.send(2, 9)
g.recv(9)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
g.register_reduce_func(reduce_not_called)
g.recv(9)
def test_pull_0deg():
g = DGLGraph()
g.add_node(0, h=2)
g.add_node(1, h=1)
g.add_edge(0, 1)
def _message(src, edge):
assert False
return src
def _reduce(node, msgs):
assert False
return node
def _update(node):
return {'h': node['h'] * 2}
g.pull(0, _message, _reduce, _update)
assert g.nodes[0]['h'] == 4
if __name__ == '__main__':
test_no_msg_recv()
test_double_recv()
test_pull_0deg()
import dgl
import dgl.function as fn
from dgl.graph import __REPR__
def generate_graph():
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i, h=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i, h=1)
g.add_edge(i, 9, h=i+1)
# add a back flow from 9 to 0
g.add_edge(9, 0, h=10)
return g
def check(g, h, fld):
nh = [str(g.nodes[i][fld]) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def generate_graph1():
"""graph with anonymous repr"""
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i, __REPR__=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i, __REPR__=1)
g.add_edge(i, 9, __REPR__=i+1)
# add a back flow from 9 to 0
g.add_edge(9, 0, __REPR__=10)
return g
def test_copy_src():
# copy_src with both fields
g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
def test_copy_edge():
# copy_edge with both fields
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
def test_src_mul_edge():
# src_mul_edge with all fields
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=False)
g.register_reduce_func(fn.sum(), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], __REPR__)
if __name__ == '__main__':
test_copy_src()
test_copy_edge()
test_src_mul_edge()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment