Commit f310e586 authored by Minjie Wang's avatar Minjie Wang
Browse files

Add test/example; fix bug in graph.py

parent 15a2c22c
from __future__ import division
import networkx as nx
from dgl.graph import DGLGraph
DAMP = 0.85
N = 100
K = 10
def message_func(src, dst, edge):
return src['pv'] / src['deg']
def update_func(node, msgs):
pv = (1 - DAMP) / N + DAMP * sum(msgs)
return {'pv' : pv}
def compute_pagerank(g):
g = DGLGraph(g)
print(g.number_of_edges(), g.number_of_nodes())
g.register_message_func(message_func)
g.register_update_func(update_func)
# init pv value
for n in g.nodes():
g.node[n]['pv'] = 1 / N
g.node[n]['deg'] = g.out_degree(n)
# pagerank
for k in range(K):
g.update_all()
return [g.node[n]['pv'] for n in g.nodes()]
if __name__ == '__main__':
g = nx.erdos_renyi_graph(N, 0.05)
pv = compute_pagerank(g)
print(pv)
"""Base graph class specialized for neural networks on graphs.
"""
from collections import defaultdict
import networkx as nx
from networkx.classes.digraph import DiGraph
......@@ -170,7 +171,7 @@ class DGLGraph(DiGraph):
"""
nodes = self._nodes_or_all(nodes)
edges = self._nodes_or_all(nodes)
assert self.readout_func is not None,
assert self.readout_func is not None, \
"Readout function is not registered."
# TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes]
......@@ -190,7 +191,7 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the loop.
for uu, vv in utils.edge_iter(u, v):
f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func)
assert f_msg is not None,
assert f_msg is not None, \
"message function not registered for edge (%s->%s)" % (uu, vv)
m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__MSG__] = m
......@@ -224,9 +225,9 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the message batching
m = [self.edges[vv, uu][__MSG__] for vv in v]
f_update = self.nodes[uu].get(__UFUNC__, self.u_func)
assert f_update is not None,
assert f_update is not None, \
"Update function not registered for node %s" % uu
self.nodes[uu] = f_update(self.nodes[uu], m)
self.node[uu].update(f_update(self.nodes[uu], m))
def update_by_edge(self, u, v):
"""Trigger the message function on u->v and update v.
......@@ -283,9 +284,9 @@ class DGLGraph(DiGraph):
u = [uu for uu, _ in self.edges]
v = [vv for _, vv in self.edges]
self.sendto(u, v)
self.recvfrom(v)
self.recvfrom(list(self.nodes()))
def propagate(self, iterator='bfs'):
def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator.
A convenient function for passing messages and updating
......@@ -299,6 +300,8 @@ class DGLGraph(DiGraph):
----------
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
Arguments for pre-defined iterators.
"""
if isinstance(iterator, str):
# TODO Call pre-defined routine to unroll the computation.
......
from dgl.graph import DGLGraph
def message_func(src, dst, edge):
return src['h']
def update_func(node, msgs):
m = sum(msgs)
return {'h' : node['h'] + m}
def generate_graph():
g = DGLGraph()
for i in range(10):
g.add_node(i, h=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
return g
def check(g, h):
nh = [str(g.nodes[i]['h']) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def test_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.sendto(0, 1)
g.recvfrom(1, [0])
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9)
g.sendto(6, 9)
g.recvfrom(9, [5, 6])
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def test_multi_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
# one-many
g.sendto(0, [1, 2, 3])
g.recvfrom([1, 2, 3], [[0], [0], [0]])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
# many-one
g.sendto([6, 7, 8], 9)
g.recvfrom(9, [6, 7, 8])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
# many-many
g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
g.recvfrom([4, 5, 9], [[0], [0], [4, 5]])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def test_update_routines():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.update_by_edge(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
g.update_from(0)
check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
if __name__ == '__main__':
test_sendrecv()
test_multi_sendrecv()
test_update_routines()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment