Unverified Commit 8ea359d1 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[API] Message propagation APIs (#127)

* add examples in traversal.py

* message propagate methods

* use the new message propagation for tree-lstm

* update to the new name

* update propagate API doc

* update doc

* add propagate utest
parent 72efb427
...@@ -8,3 +8,4 @@ API Reference ...@@ -8,3 +8,4 @@ API Reference
batch batch
function function
traversal traversal
propagate
Message Propagation
===================
.. automodule:: dgl.propagate
.. autosummary::
:toctree: ../../generated/
prop_nodes
prop_edges
prop_nodes_bfs
prop_nodes_topo
prop_edges_dfs
...@@ -12,23 +12,6 @@ import dgl.ndarray as nd ...@@ -12,23 +12,6 @@ import dgl.ndarray as nd
from tree_lstm import TreeLSTM from tree_lstm import TreeLSTM
def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes()
if cuda:
adjmat = g._graph.adjacency_matrix().get(th.device('cuda:{}'.format(cuda)))
mask = th.ones((n, 1)).cuda()
else:
adjmat = g._graph.adjacency_matrix().get(th.device('cpu'))
mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.:
v = (degree == 0.).float()
v = v * mask
mask = mask - v
frontier = th.squeeze(th.squeeze(v).nonzero(), 1)
yield frontier
degree -= th.spmm(adjmat, v)
def main(args): def main(args):
cuda = args.gpu >= 0 cuda = args.gpu >= 0
if cuda: if cuda:
...@@ -74,8 +57,7 @@ def main(args): ...@@ -74,8 +57,7 @@ def main(args):
t0 = time.time() t0 = time.time()
label = graph.ndata.pop('y') label = graph.ndata.pop('y')
# traverse graph # traverse graph
giter = list(tensor_topo_traverse(graph, False, args)) logits = model(graph, zero_initializer, train=True)
logits = model(graph, zero_initializer, iterator=giter, train=True)
logp = F.log_softmax(logits, 1) logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, label) loss = F.nll_loss(logp, label)
optimizer.zero_grad() optimizer.zero_grad()
......
...@@ -67,7 +67,7 @@ class TreeLSTM(nn.Module): ...@@ -67,7 +67,7 @@ class TreeLSTM(nn.Module):
else: else:
raise RuntimeError('Unknown cell type:', cell_type) raise RuntimeError('Unknown cell type:', cell_type)
def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True): def forward(self, graph, zero_initializer, h=None, c=None, train=True):
"""Compute tree-lstm prediction given a batch. """Compute tree-lstm prediction given a batch.
Parameters Parameters
...@@ -107,12 +107,7 @@ class TreeLSTM(nn.Module): ...@@ -107,12 +107,7 @@ class TreeLSTM(nn.Module):
c = zero_initializer((n, self.h_size)) c = zero_initializer((n, self.h_size))
g.ndata['c'] = c g.ndata['c'] = c
g.ndata['c_tild'] = zero_initializer((n, self.h_size)) g.ndata['c_tild'] = zero_initializer((n, self.h_size))
# TODO(minjie): potential bottleneck dgl.prop_nodes_topo(g)
if iterator is None:
g.propagate('topo')
else:
for frontier in iterator:
g.pull(frontier)
# compute logits # compute logits
h = g.ndata.pop('h') h = g.ndata.pop('h')
h = self.dropout(h) h = self.dropout(h)
......
...@@ -13,4 +13,5 @@ from .batched_graph import * ...@@ -13,4 +13,5 @@ from .batched_graph import *
from .graph import DGLGraph from .graph import DGLGraph
from .subgraph import DGLSubGraph from .subgraph import DGLSubGraph
from .traversal import * from .traversal import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
...@@ -1367,11 +1367,11 @@ class DGLGraph(object): ...@@ -1367,11 +1367,11 @@ class DGLGraph(object):
---------- ----------
node_generators : generator node_generators : generator
The generator of node frontiers. The generator of node frontiers.
message_func : str or callable, optional message_func : callable, optional
The message function. The message function.
reduce_func : str or callable, optional reduce_func : callable, optional
The reduce function. The reduce function.
apply_node_func : str or callable, optional apply_node_func : callable, optional
The update function. The update function.
""" """
for node_frontier in nodes_generator: for node_frontier in nodes_generator:
...@@ -1379,7 +1379,7 @@ class DGLGraph(object): ...@@ -1379,7 +1379,7 @@ class DGLGraph(object):
message_func, reduce_func, apply_node_func) message_func, reduce_func, apply_node_func)
def prop_edges(self, def prop_edges(self,
edge_generator, edges_generator,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default"): apply_node_func="default"):
...@@ -1393,16 +1393,16 @@ class DGLGraph(object): ...@@ -1393,16 +1393,16 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
edge_generators : generator edges_generator : generator
The generator of edge frontiers. The generator of edge frontiers.
message_func : str or callable, optional message_func : callable, optional
The message function. The message function.
reduce_func : str or callable, optional reduce_func : callable, optional
The reduce function. The reduce function.
apply_node_func : str or callable, optional apply_node_func : callable, optional
The update function. The update function.
""" """
for edge_frontier in edge_generator: for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier, self.send_and_recv(edge_frontier,
message_func, reduce_func, apply_node_func) message_func, reduce_func, apply_node_func)
......
"""Module for message propagation."""
from __future__ import absolute_import
from . import traversal as trv
__all__ = ['prop_nodes', 'prop_nodes_bfs', 'prop_nodes_topo',
'prop_edges', 'prop_edges_dfs']
def prop_nodes(graph,
nodes_generator,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Functional method for :func:`dgl.DGLGraph.prop_nodes`.
Parameters
----------
node_generators : generator
The generator of node frontiers.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.DGLGraph.prop_nodes
"""
graph.prop_nodes(nodes_generator, message_func, reduce_func, apply_node_func)
def prop_edges(graph,
edges_generator,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Functional method for :func:`dgl.DGLGraph.prop_edges`.
Parameters
----------
edges_generator : generator
The generator of edge frontiers.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.DGLGraph.prop_edges
"""
graph.prop_edges(edges_generator, message_func, reduce_func, apply_node_func)
def prop_nodes_bfs(graph,
source,
reversed=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Message propagation using node frontiers generated by BFS.
Parameters
----------
graph : DGLGraph
The graph object.
source : list, tensor of nodes
Source nodes.
reversed : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.traversal.bfs_nodes_generator
"""
nodes_gen = trv.bfs_nodes_generator(graph, source, reversed)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(graph,
reversed=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Message propagation using node frontiers generated by topolocial order.
Parameters
----------
graph : DGLGraph
The graph object.
reversed : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.traversal.topological_nodes_generator
"""
nodes_gen = trv.topological_nodes_generator(graph, reversed)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(graph,
source,
reversed=False,
has_reverse_edge=False,
has_nontree_edge=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Message propagation using edge frontiers generated by labeled DFS.
Parameters
----------
graph : DGLGraph
The graph object.
source : list, tensor of nodes
Source nodes.
reversed : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.traversal.dfs_labeled_edges_generator
"""
edges_gen = trv.dfs_labeled_edges_generator(
graph, source, reversed, has_reverse_edge, has_nontree_edge,
return_labels=False)
prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)
...@@ -24,6 +24,19 @@ def bfs_nodes_generator(graph, source, reversed=False): ...@@ -24,6 +24,19 @@ def bfs_nodes_generator(graph, source, reversed=False):
------- -------
list of node frontiers list of node frontiers
Each node frontier is a list, tensor of nodes. Each node frontier is a list, tensor of nodes.
Examples
--------
Given a graph (directed, edges from small node id to large):
::
2 - 4
/ \
0 - 1 - 3 - 5
>>> g = ... # the graph above
>>> list(dgl.bfs_nodes_generator(g, 0))
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor() source = utils.toindex(source).todgltensor()
...@@ -47,6 +60,19 @@ def topological_nodes_generator(graph, reversed=False): ...@@ -47,6 +60,19 @@ def topological_nodes_generator(graph, reversed=False):
------- -------
list of node frontiers list of node frontiers
Each node frontier is a list, tensor of nodes. Each node frontier is a list, tensor of nodes.
Examples
--------
Given a graph (directed, edges from small node id to large):
::
2 - 4
/ \
0 - 1 - 3 - 5
>>> g = ... # the graph above
>>> list(dgl.topological_nodes_generator(g))
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
ret = _CAPI_DGLTopologicalNodes(ghandle, reversed) ret = _CAPI_DGLTopologicalNodes(ghandle, reversed)
...@@ -76,6 +102,21 @@ def dfs_edges_generator(graph, source, reversed=False): ...@@ -76,6 +102,21 @@ def dfs_edges_generator(graph, source, reversed=False):
------- -------
list of edge frontiers list of edge frontiers
Each edge frontier is a list, tensor of edges. Each edge frontier is a list, tensor of edges.
Examples
--------
Given a graph (directed, edges from small node id to large):
::
2 - 4
/ \
0 - 1 - 3 - 5
Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]
>>> g = ... # the graph above
>>> list(dgl.dfs_edges_generator(g))
[tensor([0]), tensor([1]), tensor([4]), tensor([3]), tensor([5]), tensor([2])]
""" """
ghandle = graph._graph._handle ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor() source = utils.toindex(source).todgltensor()
...@@ -101,6 +142,10 @@ def dfs_labeled_edges_generator( ...@@ -101,6 +142,10 @@ def dfs_labeled_edges_generator(
edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v`
have been visisted but the edge is NOT in the DFS tree. have been visisted but the edge is NOT in the DFS tree.
See ``networkx``'s :func:`dfs_labeled_edges
<networkx.algorithms.traversal.depth_first_search.dfs_labeled_edges>`
for more details.
Multiple source nodes can be specified to start the DFS traversal. One Multiple source nodes can be specified to start the DFS traversal. One
needs to make sure that each source node belongs to different connected needs to make sure that each source node belongs to different connected
component, so the frontiers can be easily merged. Otherwise, the behavior component, so the frontiers can be easily merged. Otherwise, the behavior
......
import dgl
import networkx as nx
import torch as th
import utils as U
def mfunc(edges):
return {'m' : edges.src['x']}
def rfunc(nodes):
msg = th.sum(nodes.mailbox['m'], 1)
return {'x' : nodes.data['x'] + msg}
def test_prop_nodes_bfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.ndata['x'] = th.ones((5, 2))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
dgl.prop_nodes_bfs(g, 0)
# pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1]
assert U.allclose(g.ndata['x'],
th.tensor([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
g.ndata['x'] = th.ones((5, 2))
dgl.prop_edges_dfs(g, 0)
# snr using dfs results in a cumsum
assert U.allclose(g.ndata['x'],
th.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]))
g.ndata['x'] = th.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_reverse_edge=True)
# result is cumsum[i] + cumsum[i-1]
assert U.allclose(g.ndata['x'],
th.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]))
g.ndata['x'] = th.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_nontree_edge=True)
# result is cumsum[i] + cumsum[i+1]
assert U.allclose(g.ndata['x'],
th.tensor([[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]))
def test_prop_nodes_topo():
# bi-directional chain
g = dgl.DGLGraph(nx.path_graph(5))
assert U.check_fail(dgl.prop_nodes_topo, g) # has loop
# tree
tree = dgl.DGLGraph()
tree.add_nodes(5)
tree.add_edge(1, 0)
tree.add_edge(2, 0)
tree.add_edge(3, 2)
tree.add_edge(4, 2)
tree.register_message_func(mfunc)
tree.register_reduce_func(rfunc)
# init node feature data
tree.ndata['x'] = th.zeros((5, 2))
# set all leaf nodes to be ones
tree.nodes[[1, 3, 4]].data['x'] = th.ones((3, 2))
dgl.prop_nodes_topo(tree)
# root node get the sum
assert U.allclose(tree.nodes[0].data['x'], th.tensor([[3., 3.]]))
if __name__ == '__main__':
test_prop_nodes_bfs()
test_prop_edges_dfs()
test_prop_nodes_topo()
...@@ -2,3 +2,10 @@ import torch as th ...@@ -2,3 +2,10 @@ import torch as th
def allclose(a, b): def allclose(a, b):
return th.allclose(a, b, rtol=1e-4, atol=1e-4) return th.allclose(a, b, rtol=1e-4, atol=1e-4)
def check_fail(fn, *args, **kwargs):
try:
fn(*args, **kwargs)
return False
except:
return True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment