Commit f310e586 authored by Minjie Wang's avatar Minjie Wang
Browse files

Add test/example; fix bug in graph.py

parent 15a2c22c
from __future__ import division
import networkx as nx
from dgl.graph import DGLGraph
DAMP = 0.85
N = 100
K = 10
def message_func(src, dst, edge):
return src['pv'] / src['deg']
def update_func(node, msgs):
pv = (1 - DAMP) / N + DAMP * sum(msgs)
return {'pv' : pv}
def compute_pagerank(g):
g = DGLGraph(g)
print(g.number_of_edges(), g.number_of_nodes())
g.register_message_func(message_func)
g.register_update_func(update_func)
# init pv value
for n in g.nodes():
g.node[n]['pv'] = 1 / N
g.node[n]['deg'] = g.out_degree(n)
# pagerank
for k in range(K):
g.update_all()
return [g.node[n]['pv'] for n in g.nodes()]
if __name__ == '__main__':
g = nx.erdos_renyi_graph(N, 0.05)
pv = compute_pagerank(g)
print(pv)
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
""" """
from collections import defaultdict
import networkx as nx import networkx as nx
from networkx.classes.digraph import DiGraph from networkx.classes.digraph import DiGraph
...@@ -170,7 +171,7 @@ class DGLGraph(DiGraph): ...@@ -170,7 +171,7 @@ class DGLGraph(DiGraph):
""" """
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
edges = self._nodes_or_all(nodes) edges = self._nodes_or_all(nodes)
assert self.readout_func is not None, assert self.readout_func is not None, \
"Readout function is not registered." "Readout function is not registered."
# TODO(minjie): tensorize following loop. # TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes] nstates = [self.nodes[n] for n in nodes]
...@@ -190,7 +191,7 @@ class DGLGraph(DiGraph): ...@@ -190,7 +191,7 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the loop. # TODO(minjie): tensorize the loop.
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func) f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func)
assert f_msg is not None, assert f_msg is not None, \
"message function not registered for edge (%s->%s)" % (uu, vv) "message function not registered for edge (%s->%s)" % (uu, vv)
m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv]) m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__MSG__] = m self.edges[uu, vv][__MSG__] = m
...@@ -224,9 +225,9 @@ class DGLGraph(DiGraph): ...@@ -224,9 +225,9 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the message batching # TODO(minjie): tensorize the message batching
m = [self.edges[vv, uu][__MSG__] for vv in v] m = [self.edges[vv, uu][__MSG__] for vv in v]
f_update = self.nodes[uu].get(__UFUNC__, self.u_func) f_update = self.nodes[uu].get(__UFUNC__, self.u_func)
assert f_update is not None, assert f_update is not None, \
"Update function not registered for node %s" % uu "Update function not registered for node %s" % uu
self.nodes[uu] = f_update(self.nodes[uu], m) self.node[uu].update(f_update(self.nodes[uu], m))
def update_by_edge(self, u, v): def update_by_edge(self, u, v):
"""Trigger the message function on u->v and update v. """Trigger the message function on u->v and update v.
...@@ -283,9 +284,9 @@ class DGLGraph(DiGraph): ...@@ -283,9 +284,9 @@ class DGLGraph(DiGraph):
u = [uu for uu, _ in self.edges] u = [uu for uu, _ in self.edges]
v = [vv for _, vv in self.edges] v = [vv for _, vv in self.edges]
self.sendto(u, v) self.sendto(u, v)
self.recvfrom(v) self.recvfrom(list(self.nodes()))
def propagate(self, iterator='bfs'): def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
A convenient function for passing messages and updating A convenient function for passing messages and updating
...@@ -299,6 +300,8 @@ class DGLGraph(DiGraph): ...@@ -299,6 +300,8 @@ class DGLGraph(DiGraph):
---------- ----------
iterator : str or generator of steps. iterator : str or generator of steps.
The iterator of the graph. The iterator of the graph.
kwargs : keyword arguments, optional
Arguments for pre-defined iterators.
""" """
if isinstance(iterator, str): if isinstance(iterator, str):
# TODO Call pre-defined routine to unroll the computation. # TODO Call pre-defined routine to unroll the computation.
......
from dgl.graph import DGLGraph
def message_func(src, dst, edge):
return src['h']
def update_func(node, msgs):
m = sum(msgs)
return {'h' : node['h'] + m}
def generate_graph():
g = DGLGraph()
for i in range(10):
g.add_node(i, h=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
return g
def check(g, h):
nh = [str(g.nodes[i]['h']) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def test_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.sendto(0, 1)
g.recvfrom(1, [0])
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9)
g.sendto(6, 9)
g.recvfrom(9, [5, 6])
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def test_multi_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
# one-many
g.sendto(0, [1, 2, 3])
g.recvfrom([1, 2, 3], [[0], [0], [0]])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
# many-one
g.sendto([6, 7, 8], 9)
g.recvfrom(9, [6, 7, 8])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
# many-many
g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
g.recvfrom([4, 5, 9], [[0], [0], [4, 5]])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def test_update_routines():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.update_by_edge(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
g.update_from(0)
check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
if __name__ == '__main__':
test_sendrecv()
test_multi_sendrecv()
test_update_routines()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment