"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0a42d863b740e3e13d79ee081d3792a4a04aed87"
Unverified commit d24ccfdd authored by Minjie Wang, committed by GitHub

Anonymous node/edge repr (#24)

* Remove name args in register/sendto/update_edge/recvfrom

* Support anonymous repr

* Test anonymous edge repr

* Test node states with both anonymous and explicit reprs
parent 345a472d
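For orientation, here is a minimal usage sketch of the API after this change, pieced together from the new dgl/graph.py diff and the tests below; the toy graph, the integer reprs, and the 'sum' built-in reducer mirror the test file rather than a verified run.

    from dgl import DGLGraph

    # Build a tiny graph and attach anonymous reprs (stored under the internal
    # __REPR__ key) instead of user-named attributes.
    g = DGLGraph()
    g.add_node(0)
    g.add_node(1)
    g.add_edge(0, 1)
    g.set_n_repr(0, 5)
    g.set_n_repr(1, 10)
    g.set_e_repr(0, 1, 2)

    # Message/reduce/update functions now receive the reprs directly, and the
    # register/sendto/recvfrom calls no longer take a name argument.
    g.register_message_func(lambda hu, hv, e_uv: hu + e_uv)
    g.register_reduce_func('sum')
    g.register_update_func(lambda h, accum: h + accum)

    g.sendto(0, 1)        # message on edge 0->1: 5 + 2 = 7
    g.recvfrom(1, [0])    # node 1 repr becomes 10 + 7
    assert g.get_n_repr(1) == 17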
dgl/graph.py:

@@ -10,8 +10,7 @@ from dgl.backend import Tensor
import dgl.utils as utils

__MSG__ = "__msg__"
-__E_REPR__ = "__e_repr__"
-__N_REPR__ = "__n_repr__"
+__REPR__ = "__repr__"
__MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__"

@@ -40,24 +39,27 @@ class DGLGraph(DiGraph):
        for n in self.nodes:
            self.set_repr(n, h_init)

-    def set_repr(self, u, h_u, name=__N_REPR__):
-        print("[DEPRECATED]: please directly set node attrs "
-              "(e.g. g.nodes[node]['x'] = val).")
+    def set_n_repr(self, u, h_u):
        assert u in self.nodes
-        kwarg = {name: h_u}
+        kwarg = {__REPR__: h_u}
        self.add_node(u, **kwarg)

-    def get_repr(self, u, name=__N_REPR__):
-        print("[DEPRECATED]: please directly get node attrs "
-              "(e.g. g.nodes[node]['x']).")
+    def get_n_repr(self, u):
        assert u in self.nodes
-        return self.nodes[u][name]
+        return self.nodes[u][__REPR__]
+
+    def set_e_repr(self, u, v, h_uv):
+        assert (u, v) in self.edges
+        self.edges[u, v][__REPR__] = h_uv
+
+    def get_e_repr(self, u, v):
+        assert (u, v) in self.edges
+        return self.edges[u, v][__REPR__]

    def register_message_func(self,
                              message_func,
                              edges='all',
-                              batchable=False,
-                              name=__MFUNC__):
+                              batchable=False):
        """Register computation on edges.

        The message function should be compatible with following signature:

@@ -78,8 +80,6 @@ class DGLGraph(DiGraph):
            supported.
        batchable : bool
            Whether the provided message function allows batch computing.
-        name : str
-            The name of the function.

        Examples
        --------

@@ -97,13 +97,12 @@ class DGLGraph(DiGraph):
        """
        def _msg_edge_func(u, v, e_uv):
            return {__MSG__ : message_func(u, v, e_uv)}
-        self.register_edge_func(_msg_edge_func, edges, batchable, name)
+        self._internal_register_edge(__MFUNC__, _msg_edge_func, edges, batchable)

    def register_edge_func(self,
                           edge_func,
                           edges='all',
-                           batchable=False,
-                           name=__EFUNC__):
+                           batchable=False):
        """Register computation on edges.

        The edge function should be compatible with following signature:

@@ -124,8 +123,6 @@ class DGLGraph(DiGraph):
            supported.
        batchable : bool
            Whether the provided message function allows batch computing.
-        name : str
-            The name of the function.

        Examples
        --------

@@ -141,17 +138,12 @@ class DGLGraph(DiGraph):
        >>> v = [v1, v2, v3, ...]
        >>> g.register_edge_func(mfunc, (u, v))
        """
-        if edges == 'all':
-            self._glb_func[name] = edge_func
-        else:
-            for e in edges:
-                self.edges[e][name] = edge_func
+        self._internal_register_edge(__EFUNC__, edge_func, edges, batchable)

    def register_reduce_func(self,
                             reduce_func,
                             nodes='all',
-                             batchable=False,
-                             name=__RFUNC__):
+                             batchable=False):
        """Register message reduce function on incoming edges.

        The reduce function should be compatible with following signature:

@@ -175,8 +167,6 @@ class DGLGraph(DiGraph):
            supported.
        batchable : bool
            Whether the provided reduce function allows batch computing.
-        name : str
-            The name of the function.

        Examples
        --------

@@ -200,17 +190,12 @@ class DGLGraph(DiGraph):
        else:
            raise NotImplementedError(
                "Built-in function %s not implemented" % reduce_func)
-        if nodes == 'all':
-            self._glb_func[name] = reduce_func
-        else:
-            for n in nodes:
-                self.nodes[n][name] = reduce_func
+        self._internal_register_node(__RFUNC__, reduce_func, nodes, batchable)

    def register_update_func(self,
                             update_func,
                             nodes='all',
-                             batchable=False,
-                             name=__UFUNC__):
+                             batchable=False):
        """Register computation on nodes.

        The update function should be compatible with following signature:

@@ -247,13 +232,9 @@ class DGLGraph(DiGraph):
        >>> u = [u1, u2, u3, ...]
        >>> g.register_update_func(ufunc, u)
        """
-        if nodes == 'all':
-            self._glb_func[name] = update_func
-        else:
-            for n in nodes:
-                self.nodes[n][name] = update_func
+        self._internal_register_node(__UFUNC__, update_func, nodes, batchable)

-    def register_readout_func(self, readout_func, name=__READOUT__):
+    def register_readout_func(self, readout_func):
        """Register computation on the whole graph.

        The readout_func should be compatible with following signature:

@@ -271,19 +252,16 @@ class DGLGraph(DiGraph):
        ----------
        readout_func : callable
            The readout function.
-        name : str
-            The name of the function.

        See Also
        --------
        readout
        """
-        self._glb_func[name] = readout_func
+        self._glb_func[__READOUT__] = readout_func

    def readout(self,
                nodes='all',
                edges='all',
-                name=__READOUT__,
                **kwargs):
        """Trigger the readout function on the specified nodes/edges.

@@ -293,21 +271,19 @@ class DGLGraph(DiGraph):
            The nodes to get reprs from.
        edges : str, pair of nodes, pair of containers or pair of tensors
            The edges to get reprs from.
-        name : str
-            The name of the function.
        kwargs : keyword arguments, optional
            Arguments for the readout function.
        """
        nodes = self._nodes_or_all(nodes)
        edges = self._edges_or_all(edges)
-        assert name in self._glb_func, \
-            "Readout function \"%s\" has not been registered." % name
+        assert __READOUT__ in self._glb_func, \
+            "Readout function has not been registered."
        # TODO(minjie): tensorize following loop.
        nstates = [self.nodes[n] for n in nodes]
        estates = [self.edges[e] for e in edges]
-        return self._glb_func[name](nstates, estates, **kwargs)
+        return self._glb_func[__READOUT__](nstates, estates, **kwargs)

-    def sendto(self, u, v, name=__MFUNC__):
+    def sendto(self, u, v):
        """Trigger the message function on edge u->v

        Parameters

@@ -316,12 +292,10 @@ class DGLGraph(DiGraph):
            The source node(s).
        v : node, container or tensor
            The destination node(s).
-        name : str
-            The name of the function.
        """
-        self.update_edge(u, v, name)
+        self._internal_trigger_edges(u, v, __MFUNC__)

-    def update_edge(self, u, v, name=__EFUNC__):
+    def update_edge(self, u, v):
        """Update representation on edge u->v

        Parameters

@@ -330,19 +304,10 @@ class DGLGraph(DiGraph):
            The source node(s).
        v : node, container or tensor
            The destination node(s).
-        name : str
-            The name of the function.
        """
-        # TODO(minjie): tensorize the loop.
-        efunc = self._glb_func.get(name)
-        for uu, vv in utils.edge_iter(u, v):
-            f_edge = self.edges[uu, vv].get(name, efunc)
-            assert f_edge is not None, \
-                "edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
-            m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
-            self.edges[uu, vv].update(m)
+        self._internal_trigger_edges(u, v, __EFUNC__)

-    def recvfrom(self, u, preds=None, rname=__RFUNC__, uname=__UFUNC__):
+    def recvfrom(self, u, preds=None):
        """Trigger the update function on node u.

        It computes the new node state using the messages and edge

@@ -357,15 +322,11 @@ class DGLGraph(DiGraph):
        preds : container
            Nodes with pre-computed messages to u. Default is all
            the predecessors.
-        rname : str
-            The name of reduce function.
-        uname : str
-            The name of update function.
        """
        u_is_container = isinstance(u, list)
        u_is_tensor = isinstance(u, Tensor)
-        rfunc = self._glb_func.get(rname)
-        ufunc = self._glb_func.get(uname)
+        rfunc = self._glb_func.get(__RFUNC__)
+        ufunc = self._glb_func.get(__UFUNC__)
        # TODO(minjie): tensorize the loop.
        for i, uu in enumerate(utils.node_iter(u)):
            if preds is None:

@@ -375,15 +336,18 @@ class DGLGraph(DiGraph):
            else:
                v = preds
            # TODO(minjie): tensorize the message batching
-            f_reduce = self.nodes[uu].get(rname, rfunc)
+            # reduce phase
+            f_reduce = self.nodes[uu].get(__RFUNC__, rfunc)
            assert f_reduce is not None, \
                "Reduce function not registered for node %s" % uu
-            m = [self.edges[vv, uu][__MSG__] for vv in v]
-            msgs_reduced_repr = f_reduce(m)
-            f_update = self.nodes[uu].get(uname, ufunc)
+            msgs_batch = [self.edges[vv, uu][__MSG__] for vv in v]
+            msgs_reduced = f_reduce(msgs_batch)
+            # update phase
+            f_update = self.nodes[uu].get(__UFUNC__, ufunc)
            assert f_update is not None, \
                "Update function not registered for node %s" % uu
-            self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr))
+            ret = f_update(self._get_repr(self.nodes[uu]), msgs_reduced)
+            self._set_repr(self.nodes[uu], ret)

    def update_by_edge(self, u, v):
        """Trigger the message function on u->v and update v.

@@ -484,3 +448,45 @@ class DGLGraph(DiGraph):
    def _edges_or_all(self, edges='all'):
        return self.edges() if edges == 'all' else edges
+
+    def _get_repr(self, states):
+        if len(states) == 1 and __REPR__ in states:
+            return states[__REPR__]
+        else:
+            return states
+
+    def _set_repr(self, states, val):
+        if isinstance(val, dict):
+            states.update(val)
+        else:
+            states[__REPR__] = val
+
+    def _internal_register_node(self, name, func, nodes, batchable):
+        # TODO(minjie): handle batchable
+        # TODO(minjie): group nodes based on their registered func
+        if nodes == 'all':
+            self._glb_func[name] = func
+        else:
+            for n in nodes:
+                self.nodes[n][name] = func
+
+    def _internal_register_edge(self, name, func, edges, batchable):
+        # TODO(minjie): handle batchable
+        # TODO(minjie): group edges based on their registered func
+        if edges == 'all':
+            self._glb_func[name] = func
+        else:
+            for e in edges:
+                self.edges[e][name] = func
+
+    def _internal_trigger_edges(self, u, v, name):
+        # TODO(minjie): tensorize the loop.
+        efunc = self._glb_func.get(name)
+        for uu, vv in utils.edge_iter(u, v):
+            f_edge = self.edges[uu, vv].get(name, efunc)
+            assert f_edge is not None, \
+                "edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
+            ret = f_edge(self._get_repr(self.nodes[uu]),
+                         self._get_repr(self.nodes[vv]),
+                         self._get_repr(self.edges[uu, vv]))
+            self._set_repr(self.edges[uu, vv], ret)
New test file added in this commit:

from dgl import DGLGraph
from dgl.graph import __REPR__

def message_func(hu, hv, e_uv):
    return hu + e_uv

def update_func(h, accum):
    return h + accum

def generate_graph():
    g = DGLGraph()
    for i in range(10):
        g.add_node(i) # 10 nodes.
        g.set_n_repr(i, i+1)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.set_e_repr(0, i, 1)
        g.add_edge(i, 9)
        g.set_e_repr(i, 9, 1)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    return g

def check(g, h):
    nh = [str(g.get_n_repr(i)) for i in range(10)]
    h = [str(x) for x in h]
    assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))

def test_sendrecv():
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
    g.recvfrom(1, [0])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
    g.recvfrom(9, [5, 6])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])

def message_func_hybrid(src, dst, edge):
    return src[__REPR__] + edge

def update_func_hybrid(node, accum):
    return node[__REPR__] + accum

def test_hybridrepr():
    g = generate_graph()
    for i in range(10):
        g.nodes[i]['id'] = -i
    g.register_message_func(message_func_hybrid)
    g.register_update_func(update_func_hybrid)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
    g.recvfrom(1, [0])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
    g.recvfrom(9, [5, 6])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])

if __name__ == '__main__':
    test_sendrecv()
    test_hybridrepr()