Unverified Commit 9219349a authored by Minjie Wang, committed by GitHub
Browse files

Use edge update to impl sendto; fix examples with missing reduce func (#22)

parent 68fb5f7e
...@@ -10,8 +10,8 @@ K = 10 ...@@ -10,8 +10,8 @@ K = 10
def message_func(src, dst, edge): def message_func(src, dst, edge):
return src['pv'] / src['deg'] return src['pv'] / src['deg']
def update_func(node, msgs): def update_func(node, accum):
pv = (1 - DAMP) / N + DAMP * sum(msgs) pv = (1 - DAMP) / N + DAMP * accum
return {'pv' : pv} return {'pv' : pv}
def compute_pagerank(g): def compute_pagerank(g):
...@@ -19,6 +19,7 @@ def compute_pagerank(g): ...@@ -19,6 +19,7 @@ def compute_pagerank(g):
print(g.number_of_edges(), g.number_of_nodes()) print(g.number_of_edges(), g.number_of_nodes())
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum')
# init pv value # init pv value
for n in g.nodes(): for n in g.nodes():
g.node[n]['pv'] = 1 / N g.node[n]['pv'] = 1 / N
......
...@@ -16,6 +16,7 @@ __MFUNC__ = "__mfunc__" ...@@ -16,6 +16,7 @@ __MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__" __EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__" __UFUNC__ = "__ufunc__"
__RFUNC__ = "__rfunc__" __RFUNC__ = "__rfunc__"
__READOUT__ = "__readout__"
class DGLGraph(DiGraph): class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
...@@ -31,10 +32,7 @@ class DGLGraph(DiGraph): ...@@ -31,10 +32,7 @@ class DGLGraph(DiGraph):
""" """
def __init__(self, graph_data=None, **attr): def __init__(self, graph_data=None, **attr):
super(DGLGraph, self).__init__(graph_data, **attr) super(DGLGraph, self).__init__(graph_data, **attr)
self.m_func = None self._glb_func = {}
self.u_func = None
self.e_func = None
self.readout_func = None
def init_reprs(self, h_init=None): def init_reprs(self, h_init=None):
print("[DEPRECATED]: please directly set node attrs " print("[DEPRECATED]: please directly set node attrs "
...@@ -55,12 +53,16 @@ class DGLGraph(DiGraph): ...@@ -55,12 +53,16 @@ class DGLGraph(DiGraph):
assert u in self.nodes assert u in self.nodes
return self.nodes[u][name] return self.nodes[u][name]
def register_message_func(self, message_func, edges='all', batchable=False): def register_message_func(self,
message_func,
edges='all',
batchable=False,
name=__MFUNC__):
"""Register computation on edges. """Register computation on edges.
The message function should be compatible with following signature: The message function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs (node_reprs, node_reprs, edge_reprs) -> msg
It computes the representation of a message It computes the representation of a message
using the representations of the source node, target node and the edge using the representations of the source node, target node and the edge
...@@ -76,6 +78,8 @@ class DGLGraph(DiGraph): ...@@ -76,6 +78,8 @@ class DGLGraph(DiGraph):
supported. supported.
batchable : bool batchable : bool
Whether the provided message function allows batch computing. Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples Examples
-------- --------
...@@ -91,13 +95,15 @@ class DGLGraph(DiGraph): ...@@ -91,13 +95,15 @@ class DGLGraph(DiGraph):
>>> v = [v1, v2, v3, ...] >>> v = [v1, v2, v3, ...]
>>> g.register_message_func(mfunc, (u, v)) >>> g.register_message_func(mfunc, (u, v))
""" """
if edges == 'all': def _msg_edge_func(u, v, e_uv):
self.m_func = message_func return {__MSG__ : message_func(u, v, e_uv)}
else: self.register_edge_func(_msg_edge_func, edges, batchable, name)
for e in edges:
self.edges[e][__MFUNC__] = message_func def register_edge_func(self,
edge_func,
def register_edge_func(self, edge_func, edges='all', batchable=False): edges='all',
batchable=False,
name=__EFUNC__):
"""Register computation on edges. """Register computation on edges.
The edge function should be compatible with following signature: The edge function should be compatible with following signature:
...@@ -118,6 +124,8 @@ class DGLGraph(DiGraph): ...@@ -118,6 +124,8 @@ class DGLGraph(DiGraph):
supported. supported.
batchable : bool batchable : bool
Whether the provided message function allows batch computing. Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples Examples
-------- --------
...@@ -134,12 +142,16 @@ class DGLGraph(DiGraph): ...@@ -134,12 +142,16 @@ class DGLGraph(DiGraph):
>>> g.register_edge_func(mfunc, (u, v)) >>> g.register_edge_func(mfunc, (u, v))
""" """
if edges == 'all': if edges == 'all':
self.e_func = edge_func self._glb_func[name] = edge_func
else: else:
for e in edges: for e in edges:
self.edges[e][__EFUNC__] = edge_func self.edges[e][name] = edge_func
def register_reduce_func(self, reduce_func, nodes='all', batchable=False): def register_reduce_func(self,
reduce_func,
nodes='all',
batchable=False,
name=__RFUNC__):
"""Register message reduce function on incoming edges. """Register message reduce function on incoming edges.
The reduce function should be compatible with following signature: The reduce function should be compatible with following signature:
...@@ -163,6 +175,8 @@ class DGLGraph(DiGraph): ...@@ -163,6 +175,8 @@ class DGLGraph(DiGraph):
supported. supported.
batchable : bool batchable : bool
Whether the provided reduce function allows batch computing. Whether the provided reduce function allows batch computing.
name : str
The name of the function.
Examples Examples
-------- --------
...@@ -187,12 +201,16 @@ class DGLGraph(DiGraph): ...@@ -187,12 +201,16 @@ class DGLGraph(DiGraph):
raise NotImplementedError( raise NotImplementedError(
"Built-in function %s not implemented" % reduce_func) "Built-in function %s not implemented" % reduce_func)
if nodes == 'all': if nodes == 'all':
self.r_func = reduce_func self._glb_func[name] = reduce_func
else: else:
for n in nodes: for n in nodes:
self.nodes[n][__RFUNC__] = reduce_func self.nodes[n][name] = reduce_func
def register_update_func(self, update_func, nodes='all', batchable=False): def register_update_func(self,
update_func,
nodes='all',
batchable=False,
name=__UFUNC__):
"""Register computation on nodes. """Register computation on nodes.
The update function should be compatible with following signature: The update function should be compatible with following signature:
...@@ -213,6 +231,8 @@ class DGLGraph(DiGraph): ...@@ -213,6 +231,8 @@ class DGLGraph(DiGraph):
supported. supported.
batchable : bool batchable : bool
Whether the provided update function allows batch computing. Whether the provided update function allows batch computing.
name : str
The name of the function.
Examples Examples
-------- --------
...@@ -228,12 +248,12 @@ class DGLGraph(DiGraph): ...@@ -228,12 +248,12 @@ class DGLGraph(DiGraph):
>>> g.register_update_func(ufunc, u) >>> g.register_update_func(ufunc, u)
""" """
if nodes == 'all': if nodes == 'all':
self.u_func = update_func self._glb_func[name] = update_func
else: else:
for n in nodes: for n in nodes:
self.nodes[n][__UFUNC__] = update_func self.nodes[n][name] = update_func
def register_readout_func(self, readout_func): def register_readout_func(self, readout_func, name=__READOUT__):
"""Register computation on the whole graph. """Register computation on the whole graph.
The readout_func should be compatible with following signature: The readout_func should be compatible with following signature:
...@@ -251,14 +271,20 @@ class DGLGraph(DiGraph): ...@@ -251,14 +271,20 @@ class DGLGraph(DiGraph):
---------- ----------
readout_func : callable readout_func : callable
The readout function. The readout function.
name : str
The name of the function.
See Also See Also
-------- --------
readout readout
""" """
self.readout_func = readout_func self._glb_func[name] = readout_func
def readout(self, nodes='all', edges='all', **kwargs): def readout(self,
nodes='all',
edges='all',
name=__READOUT__,
**kwargs):
"""Trigger the readout function on the specified nodes/edges. """Trigger the readout function on the specified nodes/edges.
Parameters Parameters
...@@ -267,19 +293,21 @@ class DGLGraph(DiGraph): ...@@ -267,19 +293,21 @@ class DGLGraph(DiGraph):
The nodes to get reprs from. The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from. The edges to get reprs from.
name : str
The name of the function.
kwargs : keyword arguments, optional kwargs : keyword arguments, optional
Arguments for the readout function. Arguments for the readout function.
""" """
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
edges = self._edges_or_all(edges) edges = self._edges_or_all(edges)
assert self.readout_func is not None, \ assert name in self._glb_func, \
"Readout function is not registered." "Readout function \"%s\" has not been registered." % name
# TODO(minjie): tensorize following loop. # TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes] nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges] estates = [self.edges[e] for e in edges]
return self.readout_func(nstates, estates, **kwargs) return self._glb_func[name](nstates, estates, **kwargs)
def sendto(self, u, v): def sendto(self, u, v, name=__MFUNC__):
"""Trigger the message function on edge u->v """Trigger the message function on edge u->v
Parameters Parameters
...@@ -288,16 +316,12 @@ class DGLGraph(DiGraph): ...@@ -288,16 +316,12 @@ class DGLGraph(DiGraph):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
name : str
The name of the function.
""" """
# TODO(minjie): tensorize the loop. self.update_edge(u, v, name)
for uu, vv in utils.edge_iter(u, v):
f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func)
assert f_msg is not None, \
"message function not registered for edge (%s->%s)" % (uu, vv)
m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__MSG__] = m
def update_edge(self, u, v): def update_edge(self, u, v, name=__EFUNC__):
"""Update representation on edge u->v """Update representation on edge u->v
Parameters Parameters
...@@ -306,16 +330,19 @@ class DGLGraph(DiGraph): ...@@ -306,16 +330,19 @@ class DGLGraph(DiGraph):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
name : str
The name of the function.
""" """
# TODO(minjie): tensorize the loop. # TODO(minjie): tensorize the loop.
efunc = self._glb_func.get(name)
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
f_edge = self.edges[uu, vv].get(__EFUNC__, self.m_func) f_edge = self.edges[uu, vv].get(name, efunc)
assert f_edge is not None, \ assert f_edge is not None, \
"edge function not registered for edge (%s->%s)" % (uu, vv) "edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv]) m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__E_REPR__] = m self.edges[uu, vv].update(m)
def recvfrom(self, u, preds=None): def recvfrom(self, u, preds=None, rname=__RFUNC__, uname=__UFUNC__):
"""Trigger the update function on node u. """Trigger the update function on node u.
It computes the new node state using the messages and edge It computes the new node state using the messages and edge
...@@ -330,9 +357,15 @@ class DGLGraph(DiGraph): ...@@ -330,9 +357,15 @@ class DGLGraph(DiGraph):
preds : container preds : container
Nodes with pre-computed messages to u. Default is all Nodes with pre-computed messages to u. Default is all
the predecessors. the predecessors.
rname : str
The name of reduce function.
uname : str
The name of update function.
""" """
u_is_container = isinstance(u, list) u_is_container = isinstance(u, list)
u_is_tensor = isinstance(u, Tensor) u_is_tensor = isinstance(u, Tensor)
rfunc = self._glb_func.get(rname)
ufunc = self._glb_func.get(uname)
# TODO(minjie): tensorize the loop. # TODO(minjie): tensorize the loop.
for i, uu in enumerate(utils.node_iter(u)): for i, uu in enumerate(utils.node_iter(u)):
if preds is None: if preds is None:
...@@ -342,12 +375,12 @@ class DGLGraph(DiGraph): ...@@ -342,12 +375,12 @@ class DGLGraph(DiGraph):
else: else:
v = preds v = preds
# TODO(minjie): tensorize the message batching # TODO(minjie): tensorize the message batching
m = [self.edges[vv, uu][__MSG__] for vv in v] f_reduce = self.nodes[uu].get(rname, rfunc)
f_reduce = self.nodes[uu].get(__RFUNC__, self.r_func)
assert f_reduce is not None, \ assert f_reduce is not None, \
"Reduce function not registered for node %s" % uu "Reduce function not registered for node %s" % uu
m = [self.edges[vv, uu][__MSG__] for vv in v]
msgs_reduced_repr = f_reduce(m) msgs_reduced_repr = f_reduce(m)
f_update = self.nodes[uu].get(__UFUNC__, self.u_func) f_update = self.nodes[uu].get(uname, ufunc)
assert f_update is not None, \ assert f_update is not None, \
"Update function not registered for node %s" % uu "Update function not registered for node %s" % uu
self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr)) self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr))
...@@ -413,9 +446,6 @@ class DGLGraph(DiGraph): ...@@ -413,9 +446,6 @@ class DGLGraph(DiGraph):
v = [vv for _, vv in self.edges] v = [vv for _, vv in self.edges]
self.sendto(u, v) self.sendto(u, v)
self.recvfrom(list(self.nodes())) self.recvfrom(list(self.nodes()))
# TODO(zz): this is a hack
if self.e_func:
self.update_edge(u, v)
def propagate(self, iterator='bfs', **kwargs): def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
......
...@@ -3,9 +3,8 @@ from dgl.graph import DGLGraph ...@@ -3,9 +3,8 @@ from dgl.graph import DGLGraph
def message_func(src, dst, edge): def message_func(src, dst, edge):
return src['h'] return src['h']
def update_func(node, msgs): def update_func(node, accum):
m = sum(msgs) return {'h' : node['h'] + accum}
return {'h' : node['h'] + m}
def generate_graph(): def generate_graph():
g = DGLGraph() g = DGLGraph()
...@@ -29,6 +28,7 @@ def test_sendrecv(): ...@@ -29,6 +28,7 @@ def test_sendrecv():
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum')
g.sendto(0, 1) g.sendto(0, 1)
g.recvfrom(1, [0]) g.recvfrom(1, [0])
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
...@@ -42,6 +42,7 @@ def test_multi_sendrecv(): ...@@ -42,6 +42,7 @@ def test_multi_sendrecv():
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum')
# one-many # one-many
g.sendto(0, [1, 2, 3]) g.sendto(0, [1, 2, 3])
g.recvfrom([1, 2, 3], [[0], [0], [0]]) g.recvfrom([1, 2, 3], [[0], [0], [0]])
...@@ -60,6 +61,7 @@ def test_update_routines(): ...@@ -60,6 +61,7 @@ def test_update_routines():
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum')
g.update_by_edge(0, 1) g.update_by_edge(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9) g.update_to(9)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment