Unverified commit d24ccfdd authored by Minjie Wang, committed by GitHub

Anonymous node/edge repr (#24)

* Remove name args in register/sendto/update_edge/recvfrom

* Support anonymous repr

* Test anonymous edge repr

* Test node states with both anonymous and explicit reprs
parent 345a472d
@@ -10,8 +10,7 @@ from dgl.backend import Tensor
import dgl.utils as utils
__MSG__ = "__msg__"
__E_REPR__ = "__e_repr__"
__N_REPR__ = "__n_repr__"
__REPR__ = "__repr__"
__MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__"
@@ -40,24 +39,27 @@ class DGLGraph(DiGraph):
for n in self.nodes:
self.set_repr(n, h_init)
def set_repr(self, u, h_u, name=__N_REPR__):
print("[DEPRECATED]: please directly set node attrs "
"(e.g. g.nodes[node]['x'] = val).")
def set_n_repr(self, u, h_u):
assert u in self.nodes
kwarg = {name: h_u}
kwarg = {__REPR__: h_u}
self.add_node(u, **kwarg)
def get_repr(self, u, name=__N_REPR__):
print("[DEPRECATED]: please directly get node attrs "
"(e.g. g.nodes[node]['x']).")
def get_n_repr(self, u):
assert u in self.nodes
return self.nodes[u][name]
return self.nodes[u][__REPR__]
def set_e_repr(self, u, v, h_uv):
assert (u, v) in self.edges
self.edges[u, v][__REPR__] = h_uv
def get_e_repr(self, u, v):
assert (u, v) in self.edges
return self.edges[u, v][__REPR__]
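Taken together, these accessors define the anonymous repr convention: a single value stored under the internal __REPR__ key of the node or edge state dict. A minimal doctest-style sketch of the new API (the graph and values here are hypothetical):
>>> g = DGLGraph()
>>> g.add_node(0); g.add_node(1)
>>> g.add_edge(0, 1)
>>> g.set_n_repr(0, 5)     # stored as g.nodes[0][__REPR__]
>>> g.get_n_repr(0)
5
>>> g.set_e_repr(0, 1, 3)  # stored as g.edges[0, 1][__REPR__]
>>> g.get_e_repr(0, 1)
3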
def register_message_func(self,
message_func,
edges='all',
batchable=False,
name=__MFUNC__):
batchable=False):
"""Register computation on edges.
The message function should be compatible with the following signature:
@@ -78,8 +80,6 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples
--------
@@ -97,13 +97,12 @@ class DGLGraph(DiGraph):
"""
def _msg_edge_func(u, v, e_uv):
return {__MSG__ : message_func(u, v, e_uv)}
self.register_edge_func(_msg_edge_func, edges, batchable, name)
self._internal_register_edge(__MFUNC__, _msg_edge_func, edges, batchable)
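The wrapper above stores whatever the user's message function returns under the internal __MSG__ key of the edge, so with anonymous reprs the function simply maps bare values to a bare value. A sketch mirroring message_func from the test file below:
>>> def mfunc(hu, hv, e_uv):
...     return hu + e_uv   # bare reprs in, bare message out
>>> g.register_message_func(mfunc)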
def register_edge_func(self,
edge_func,
edges='all',
batchable=False,
name=__EFUNC__):
batchable=False):
"""Register computation on edges.
The edge function should be compatible with the following signature:
@@ -124,8 +123,6 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided edge function allows batch computing.
name : str
The name of the function.
Examples
--------
@@ -141,17 +138,12 @@ class DGLGraph(DiGraph):
>>> v = [v1, v2, v3, ...]
>>> g.register_edge_func(mfunc, (u, v))
"""
if edges == 'all':
self._glb_func[name] = edge_func
else:
for e in edges:
self.edges[e][name] = edge_func
self._internal_register_edge(__EFUNC__, edge_func, edges, batchable)
def register_reduce_func(self,
reduce_func,
nodes='all',
batchable=False,
name=__RFUNC__):
batchable=False):
"""Register message reduce function on incoming edges.
The reduce function should be compatible with the following signature:
@@ -175,8 +167,6 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided reduce function allows batch computing.
name : str
The name of the function.
Examples
--------
@@ -200,17 +190,12 @@ class DGLGraph(DiGraph):
else:
raise NotImplementedError(
"Built-in function %s not implemented" % reduce_func)
if nodes == 'all':
self._glb_func[name] = reduce_func
else:
for n in nodes:
self.nodes[n][name] = reduce_func
self._internal_register_node(__RFUNC__, reduce_func, nodes, batchable)
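Per the branch shown above, a string argument selects a built-in reducer (the tests use 'sum'; unimplemented names raise NotImplementedError), while a callable receives the list of incoming messages directly. A hedged sketch:
>>> g.register_reduce_func('sum')                    # built-in reducer
>>> g.register_reduce_func(lambda msgs: max(msgs))   # or any callable over the message list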
def register_update_func(self,
update_func,
nodes='all',
batchable=False,
name=__UFUNC__):
batchable=False):
"""Register computation on nodes.
The update function should be compatible with the following signature:
@@ -247,13 +232,9 @@ class DGLGraph(DiGraph):
>>> u = [u1, u2, u3, ...]
>>> g.register_update_func(ufunc, u)
"""
if nodes == 'all':
self._glb_func[name] = update_func
else:
for n in nodes:
self.nodes[n][name] = update_func
self._internal_register_node(__UFUNC__, update_func, nodes, batchable)
def register_readout_func(self, readout_func, name=__READOUT__):
def register_readout_func(self, readout_func):
"""Register computation on the whole graph.
The readout_func should be compatible with the following signature:
@@ -271,19 +252,16 @@ class DGLGraph(DiGraph):
----------
readout_func : callable
The readout function.
name : str
The name of the function.
See Also
--------
readout
"""
self._glb_func[name] = readout_func
self._glb_func[__READOUT__] = readout_func
def readout(self,
nodes='all',
edges='all',
name=__READOUT__,
**kwargs):
"""Trigger the readout function on the specified nodes/edges.
@@ -293,21 +271,19 @@ class DGLGraph(DiGraph):
The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from.
name : str
The name of the function.
kwargs : keyword arguments, optional
Arguments for the readout function.
"""
nodes = self._nodes_or_all(nodes)
edges = self._edges_or_all(edges)
assert name in self._glb_func, \
"Readout function \"%s\" has not been registered." % name
assert __READOUT__ in self._glb_func, \
"Readout function has not been registered."
# TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges]
return self._glb_func[name](nstates, estates, **kwargs)
return self._glb_func[__READOUT__](nstates, estates, **kwargs)
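readout hands the raw per-node and per-edge state dictionaries to the registered function, so with anonymous reprs each state holds its value under __REPR__. A hypothetical readout that sums node reprs (rfunc is illustrative, not part of this diff):
>>> def rfunc(nstates, estates):
...     return sum(n[__REPR__] for n in nstates)
>>> g.register_readout_func(rfunc)
>>> g.readout()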
def sendto(self, u, v, name=__MFUNC__):
def sendto(self, u, v):
"""Trigger the message function on edge u->v
Parameters
@@ -316,12 +292,10 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
name : str
The name of the function.
"""
self.update_edge(u, v, name)
self._internal_trigger_edges(u, v, __MFUNC__)
def update_edge(self, u, v, name=__EFUNC__):
def update_edge(self, u, v):
"""Update representation on edge u->v
Parameters
@@ -330,19 +304,10 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
name : str
The name of the function.
"""
# TODO(minjie): tensorize the loop.
efunc = self._glb_func.get(name)
for uu, vv in utils.edge_iter(u, v):
f_edge = self.edges[uu, vv].get(name, efunc)
assert f_edge is not None, \
"edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv].update(m)
self._internal_trigger_edges(u, v, __EFUNC__)
def recvfrom(self, u, preds=None, rname=__RFUNC__, uname=__UFUNC__):
def recvfrom(self, u, preds=None):
"""Trigger the update function on node u.
It computes the new node state using the messages and edge
@@ -357,15 +322,11 @@ class DGLGraph(DiGraph):
preds : container
Nodes with pre-computed messages to u. Default is all
the predecessors.
rname : str
The name of reduce function.
uname : str
The name of update function.
"""
u_is_container = isinstance(u, list)
u_is_tensor = isinstance(u, Tensor)
rfunc = self._glb_func.get(rname)
ufunc = self._glb_func.get(uname)
rfunc = self._glb_func.get(__RFUNC__)
ufunc = self._glb_func.get(__UFUNC__)
# TODO(minjie): tensorize the loop.
for i, uu in enumerate(utils.node_iter(u)):
if preds is None:
@@ -375,15 +336,18 @@ class DGLGraph(DiGraph):
else:
v = preds
# TODO(minjie): tensorize the message batching
f_reduce = self.nodes[uu].get(rname, rfunc)
# reduce phase
f_reduce = self.nodes[uu].get(__RFUNC__, rfunc)
assert f_reduce is not None, \
"Reduce function not registered for node %s" % uu
m = [self.edges[vv, uu][__MSG__] for vv in v]
msgs_reduced_repr = f_reduce(m)
f_update = self.nodes[uu].get(uname, ufunc)
msgs_batch = [self.edges[vv, uu][__MSG__] for vv in v]
msgs_reduced = f_reduce(msgs_batch)
# update phase
f_update = self.nodes[uu].get(__UFUNC__, ufunc)
assert f_update is not None, \
"Update function not registered for node %s" % uu
self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr))
ret = f_update(self._get_repr(self.nodes[uu]), msgs_reduced)
self._set_repr(self.nodes[uu], ret)
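With the name arguments removed, a send/receive round trip needs no bookkeeping from the caller: sendto fires the registered message function on the edge and stores the result under __MSG__, and recvfrom reduces the pending messages and routes the node's anonymous repr through the update function. Mirroring the tests below:
>>> g.sendto(0, 1)       # message function fires on edge 0->1
>>> g.recvfrom(1, [0])   # reduce messages from node 0, then update node 1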
def update_by_edge(self, u, v):
"""Trigger the message function on u->v and update v.
@@ -484,3 +448,45 @@ class DGLGraph(DiGraph):
def _edges_or_all(self, edges='all'):
return self.edges() if edges == 'all' else edges
def _get_repr(self, states):
    if len(states) == 1 and __REPR__ in states:
        return states[__REPR__]
    else:
        return states

def _set_repr(self, states, val):
    if isinstance(val, dict):
        states.update(val)
    else:
        states[__REPR__] = val
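These two helpers implement the anonymous/hybrid convention: on read, a state dict whose only key is __REPR__ is unwrapped to the bare value, while any richer dict is passed through whole; on write, a dict result is merged into the state and a bare result is stored under __REPR__. An illustrative sketch:
>>> states = {__REPR__: 3}
>>> g._get_repr(states)      # anonymous state: unwrap to the bare value
3
>>> states['id'] = 7
>>> g._get_repr(states)      # hybrid state: hand back the whole dict
{'__repr__': 3, 'id': 7}
>>> g._set_repr(states, 5)   # bare result: stored under __REPR__
>>> states[__REPR__]
5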
def _internal_register_node(self, name, func, nodes, batchable):
    # TODO(minjie): handle batchable
    # TODO(minjie): group nodes based on their registered func
    if nodes == 'all':
        self._glb_func[name] = func
    else:
        for n in nodes:
            self.nodes[n][name] = func

def _internal_register_edge(self, name, func, edges, batchable):
    # TODO(minjie): handle batchable
    # TODO(minjie): group edges based on their registered func
    if edges == 'all':
        self._glb_func[name] = func
    else:
        for e in edges:
            self.edges[e][name] = func

def _internal_trigger_edges(self, u, v, name):
    # TODO(minjie): tensorize the loop.
    efunc = self._glb_func.get(name)
    for uu, vv in utils.edge_iter(u, v):
        f_edge = self.edges[uu, vv].get(name, efunc)
        assert f_edge is not None, \
            "edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
        ret = f_edge(self._get_repr(self.nodes[uu]),
                     self._get_repr(self.nodes[vv]),
                     self._get_repr(self.edges[uu, vv]))
        self._set_repr(self.edges[uu, vv], ret)
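Both sendto and update_edge route through this helper; a function registered on the specific edge overrides the global one, and the return value replaces the edge's anonymous repr via _set_repr. A hypothetical edge update (the lambda is illustrative):
>>> g.register_edge_func(lambda hu, hv, e_uv: hu * e_uv)
>>> g.update_edge(0, 1)
>>> g.get_e_repr(0, 1)   # now holds hu * e_uv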
from dgl import DGLGraph
from dgl.graph import __REPR__

def message_func(hu, hv, e_uv):
    return hu + e_uv

def update_func(h, accum):
    return h + accum

def generate_graph():
    g = DGLGraph()
    for i in range(10):
        g.add_node(i) # 10 nodes.
        g.set_n_repr(i, i+1)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.set_e_repr(0, i, 1)
        g.add_edge(i, 9)
        g.set_e_repr(i, 9, 1)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    return g

def check(g, h):
    nh = [str(g.get_n_repr(i)) for i in range(10)]
    h = [str(x) for x in h]
    assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))

def test_sendrecv():
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
    g.recvfrom(1, [0])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
    g.recvfrom(9, [5, 6])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])

def message_func_hybrid(src, dst, edge):
    return src[__REPR__] + edge

def update_func_hybrid(node, accum):
    return node[__REPR__] + accum

def test_hybridrepr():
    g = generate_graph()
    for i in range(10):
        g.nodes[i]['id'] = -i
    g.register_message_func(message_func_hybrid)
    g.register_update_func(update_func_hybrid)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
    g.recvfrom(1, [0])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
    g.recvfrom(9, [5, 6])
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])

if __name__ == '__main__':
    test_sendrecv()
    test_hybridrepr()