Unverified Commit 9219349a authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Use edge update to impl sendto; fix examples with missing reduce func (#22)

parent 68fb5f7e
......@@ -10,8 +10,8 @@ K = 10
def message_func(src, dst, edge):
return src['pv'] / src['deg']
def update_func(node, msgs):
pv = (1 - DAMP) / N + DAMP * sum(msgs)
def update_func(node, accum):
pv = (1 - DAMP) / N + DAMP * accum
return {'pv' : pv}
def compute_pagerank(g):
......@@ -19,6 +19,7 @@ def compute_pagerank(g):
print(g.number_of_edges(), g.number_of_nodes())
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
# init pv value
for n in g.nodes():
g.node[n]['pv'] = 1 / N
......
......@@ -16,6 +16,7 @@ __MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__"
__RFUNC__ = "__rfunc__"
__READOUT__ = "__readout__"
class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs.
......@@ -31,10 +32,7 @@ class DGLGraph(DiGraph):
"""
def __init__(self, graph_data=None, **attr):
super(DGLGraph, self).__init__(graph_data, **attr)
self.m_func = None
self.u_func = None
self.e_func = None
self.readout_func = None
self._glb_func = {}
def init_reprs(self, h_init=None):
print("[DEPRECATED]: please directly set node attrs "
......@@ -55,12 +53,16 @@ class DGLGraph(DiGraph):
assert u in self.nodes
return self.nodes[u][name]
def register_message_func(self, message_func, edges='all', batchable=False):
def register_message_func(self,
message_func,
edges='all',
batchable=False,
name=__MFUNC__):
"""Register computation on edges.
The message function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
(node_reprs, node_reprs, edge_reprs) -> msg
It computes the representation of a message
using the representations of the source node, target node and the edge
......@@ -76,6 +78,8 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples
--------
......@@ -91,13 +95,15 @@ class DGLGraph(DiGraph):
>>> v = [v1, v2, v3, ...]
>>> g.register_message_func(mfunc, (u, v))
"""
if edges == 'all':
self.m_func = message_func
else:
for e in edges:
self.edges[e][__MFUNC__] = message_func
def register_edge_func(self, edge_func, edges='all', batchable=False):
def _msg_edge_func(u, v, e_uv):
return {__MSG__ : message_func(u, v, e_uv)}
self.register_edge_func(_msg_edge_func, edges, batchable, name)
def register_edge_func(self,
edge_func,
edges='all',
batchable=False,
name=__EFUNC__):
"""Register computation on edges.
The edge function should be compatible with following signature:
......@@ -118,6 +124,8 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples
--------
......@@ -134,12 +142,16 @@ class DGLGraph(DiGraph):
>>> g.register_edge_func(mfunc, (u, v))
"""
if edges == 'all':
self.e_func = edge_func
self._glb_func[name] = edge_func
else:
for e in edges:
self.edges[e][__EFUNC__] = edge_func
self.edges[e][name] = edge_func
def register_reduce_func(self, reduce_func, nodes='all', batchable=False):
def register_reduce_func(self,
reduce_func,
nodes='all',
batchable=False,
name=__RFUNC__):
"""Register message reduce function on incoming edges.
The reduce function should be compatible with following signature:
......@@ -163,6 +175,8 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided reduce function allows batch computing.
name : str
The name of the function.
Examples
--------
......@@ -187,12 +201,16 @@ class DGLGraph(DiGraph):
raise NotImplementedError(
"Built-in function %s not implemented" % reduce_func)
if nodes == 'all':
self.r_func = reduce_func
self._glb_func[name] = reduce_func
else:
for n in nodes:
self.nodes[n][__RFUNC__] = reduce_func
self.nodes[n][name] = reduce_func
def register_update_func(self, update_func, nodes='all', batchable=False):
def register_update_func(self,
update_func,
nodes='all',
batchable=False,
name=__UFUNC__):
"""Register computation on nodes.
The update function should be compatible with following signature:
......@@ -213,6 +231,8 @@ class DGLGraph(DiGraph):
supported.
batchable : bool
Whether the provided update function allows batch computing.
name : str
The name of the function.
Examples
--------
......@@ -228,12 +248,12 @@ class DGLGraph(DiGraph):
>>> g.register_update_func(ufunc, u)
"""
if nodes == 'all':
self.u_func = update_func
self._glb_func[name] = update_func
else:
for n in nodes:
self.nodes[n][__UFUNC__] = update_func
self.nodes[n][name] = update_func
def register_readout_func(self, readout_func):
def register_readout_func(self, readout_func, name=__READOUT__):
"""Register computation on the whole graph.
The readout_func should be compatible with following signature:
......@@ -251,14 +271,20 @@ class DGLGraph(DiGraph):
----------
readout_func : callable
The readout function.
name : str
The name of the function.
See Also
--------
readout
"""
self.readout_func = readout_func
self._glb_func[name] = readout_func
def readout(self, nodes='all', edges='all', **kwargs):
def readout(self,
nodes='all',
edges='all',
name=__READOUT__,
**kwargs):
"""Trigger the readout function on the specified nodes/edges.
Parameters
......@@ -267,19 +293,21 @@ class DGLGraph(DiGraph):
The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from.
name : str
The name of the function.
kwargs : keyword arguments, optional
Arguments for the readout function.
"""
nodes = self._nodes_or_all(nodes)
edges = self._edges_or_all(edges)
assert self.readout_func is not None, \
"Readout function is not registered."
assert name in self._glb_func, \
"Readout function \"%s\" has not been registered." % name
# TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges]
return self.readout_func(nstates, estates, **kwargs)
return self._glb_func[name](nstates, estates, **kwargs)
def sendto(self, u, v):
def sendto(self, u, v, name=__MFUNC__):
"""Trigger the message function on edge u->v
Parameters
......@@ -288,16 +316,12 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
name : str
The name of the function.
"""
# TODO(minjie): tensorize the loop.
for uu, vv in utils.edge_iter(u, v):
f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func)
assert f_msg is not None, \
"message function not registered for edge (%s->%s)" % (uu, vv)
m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__MSG__] = m
self.update_edge(u, v, name)
def update_edge(self, u, v):
def update_edge(self, u, v, name=__EFUNC__):
"""Update representation on edge u->v
Parameters
......@@ -306,16 +330,19 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
name : str
The name of the function.
"""
# TODO(minjie): tensorize the loop.
efunc = self._glb_func.get(name)
for uu, vv in utils.edge_iter(u, v):
f_edge = self.edges[uu, vv].get(__EFUNC__, self.m_func)
f_edge = self.edges[uu, vv].get(name, efunc)
assert f_edge is not None, \
"edge function not registered for edge (%s->%s)" % (uu, vv)
"edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__E_REPR__] = m
self.edges[uu, vv].update(m)
def recvfrom(self, u, preds=None):
def recvfrom(self, u, preds=None, rname=__RFUNC__, uname=__UFUNC__):
"""Trigger the update function on node u.
It computes the new node state using the messages and edge
......@@ -330,9 +357,15 @@ class DGLGraph(DiGraph):
preds : container
Nodes with pre-computed messages to u. Default is all
the predecessors.
rname : str
The name of reduce function.
uname : str
The name of update function.
"""
u_is_container = isinstance(u, list)
u_is_tensor = isinstance(u, Tensor)
rfunc = self._glb_func.get(rname)
ufunc = self._glb_func.get(uname)
# TODO(minjie): tensorize the loop.
for i, uu in enumerate(utils.node_iter(u)):
if preds is None:
......@@ -342,12 +375,12 @@ class DGLGraph(DiGraph):
else:
v = preds
# TODO(minjie): tensorize the message batching
m = [self.edges[vv, uu][__MSG__] for vv in v]
f_reduce = self.nodes[uu].get(__RFUNC__, self.r_func)
f_reduce = self.nodes[uu].get(rname, rfunc)
assert f_reduce is not None, \
"Reduce function not registered for node %s" % uu
m = [self.edges[vv, uu][__MSG__] for vv in v]
msgs_reduced_repr = f_reduce(m)
f_update = self.nodes[uu].get(__UFUNC__, self.u_func)
f_update = self.nodes[uu].get(uname, ufunc)
assert f_update is not None, \
"Update function not registered for node %s" % uu
self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr))
......@@ -413,9 +446,6 @@ class DGLGraph(DiGraph):
v = [vv for _, vv in self.edges]
self.sendto(u, v)
self.recvfrom(list(self.nodes()))
# TODO(zz): this is a hack
if self.e_func:
self.update_edge(u, v)
def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator.
......
......@@ -3,9 +3,8 @@ from dgl.graph import DGLGraph
def message_func(src, dst, edge):
return src['h']
def update_func(node, msgs):
m = sum(msgs)
return {'h' : node['h'] + m}
def update_func(node, accum):
return {'h' : node['h'] + accum}
def generate_graph():
g = DGLGraph()
......@@ -29,6 +28,7 @@ def test_sendrecv():
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
g.sendto(0, 1)
g.recvfrom(1, [0])
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
......@@ -42,6 +42,7 @@ def test_multi_sendrecv():
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
# one-many
g.sendto(0, [1, 2, 3])
g.recvfrom([1, 2, 3], [[0], [0], [0]])
......@@ -60,6 +61,7 @@ def test_update_routines():
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
g.update_by_edge(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment