"vscode:/vscode.git/clone" did not exist on "de292c7a82b8fb36531ffbf4724d1e68cb1785ee"
Unverified Commit 8ea359d1 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[API] Message propagation APIs (#127)

* add examples in traversal.py

* message propagate methods

* use the new message propagation for tree-lstm

* update to the new name

* update propagate API doc

* update doc

* add propagate utest
parent 72efb427
......@@ -8,3 +8,4 @@ API Reference
batch
function
traversal
propagate
Message Propagation
===================
.. automodule:: dgl.propagate
.. autosummary::
:toctree: ../../generated/
prop_nodes
prop_edges
prop_nodes_bfs
prop_nodes_topo
prop_edges_dfs
......@@ -12,23 +12,6 @@ import dgl.ndarray as nd
from tree_lstm import TreeLSTM
def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes()
if cuda:
adjmat = g._graph.adjacency_matrix().get(th.device('cuda:{}'.format(cuda)))
mask = th.ones((n, 1)).cuda()
else:
adjmat = g._graph.adjacency_matrix().get(th.device('cpu'))
mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.:
v = (degree == 0.).float()
v = v * mask
mask = mask - v
frontier = th.squeeze(th.squeeze(v).nonzero(), 1)
yield frontier
degree -= th.spmm(adjmat, v)
def main(args):
cuda = args.gpu >= 0
if cuda:
......@@ -74,8 +57,7 @@ def main(args):
t0 = time.time()
label = graph.ndata.pop('y')
# traverse graph
giter = list(tensor_topo_traverse(graph, False, args))
logits = model(graph, zero_initializer, iterator=giter, train=True)
logits = model(graph, zero_initializer, train=True)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, label)
optimizer.zero_grad()
......
......@@ -67,7 +67,7 @@ class TreeLSTM(nn.Module):
else:
raise RuntimeError('Unknown cell type:', cell_type)
def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
def forward(self, graph, zero_initializer, h=None, c=None, train=True):
"""Compute tree-lstm prediction given a batch.
Parameters
......@@ -107,12 +107,7 @@ class TreeLSTM(nn.Module):
c = zero_initializer((n, self.h_size))
g.ndata['c'] = c
g.ndata['c_tild'] = zero_initializer((n, self.h_size))
# TODO(minjie): potential bottleneck
if iterator is None:
g.propagate('topo')
else:
for frontier in iterator:
g.pull(frontier)
dgl.prop_nodes_topo(g)
# compute logits
h = g.ndata.pop('h')
h = self.dropout(h)
......
......@@ -13,4 +13,5 @@ from .batched_graph import *
from .graph import DGLGraph
from .subgraph import DGLSubGraph
from .traversal import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch
......@@ -1367,11 +1367,11 @@ class DGLGraph(object):
----------
node_generators : generator
The generator of node frontiers.
message_func : str or callable, optional
message_func : callable, optional
The message function.
reduce_func : str or callable, optional
reduce_func : callable, optional
The reduce function.
apply_node_func : str or callable, optional
apply_node_func : callable, optional
The update function.
"""
for node_frontier in nodes_generator:
......@@ -1379,7 +1379,7 @@ class DGLGraph(object):
message_func, reduce_func, apply_node_func)
def prop_edges(self,
edge_generator,
edges_generator,
message_func="default",
reduce_func="default",
apply_node_func="default"):
......@@ -1393,16 +1393,16 @@ class DGLGraph(object):
Parameters
----------
edge_generators : generator
edges_generator : generator
The generator of edge frontiers.
message_func : str or callable, optional
message_func : callable, optional
The message function.
reduce_func : str or callable, optional
reduce_func : callable, optional
The reduce function.
apply_node_func : str or callable, optional
apply_node_func : callable, optional
The update function.
"""
for edge_frontier in edge_generator:
for edge_frontier in edges_generator:
self.send_and_recv(edge_frontier,
message_func, reduce_func, apply_node_func)
......
"""Module for message propagation."""
from __future__ import absolute_import
from . import traversal as trv
__all__ = ['prop_nodes', 'prop_nodes_bfs', 'prop_nodes_topo',
'prop_edges', 'prop_edges_dfs']
def prop_nodes(graph,
nodes_generator,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Functional method for :func:`dgl.DGLGraph.prop_nodes`.
Parameters
----------
node_generators : generator
The generator of node frontiers.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.DGLGraph.prop_nodes
"""
graph.prop_nodes(nodes_generator, message_func, reduce_func, apply_node_func)
def prop_edges(graph,
edges_generator,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Functional method for :func:`dgl.DGLGraph.prop_edges`.
Parameters
----------
edges_generator : generator
The generator of edge frontiers.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.DGLGraph.prop_edges
"""
graph.prop_edges(edges_generator, message_func, reduce_func, apply_node_func)
def prop_nodes_bfs(graph,
source,
reversed=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Message propagation using node frontiers generated by BFS.
Parameters
----------
graph : DGLGraph
The graph object.
source : list, tensor of nodes
Source nodes.
reversed : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.traversal.bfs_nodes_generator
"""
nodes_gen = trv.bfs_nodes_generator(graph, source, reversed)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(graph,
reversed=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Message propagation using node frontiers generated by topolocial order.
Parameters
----------
graph : DGLGraph
The graph object.
reversed : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.traversal.topological_nodes_generator
"""
nodes_gen = trv.topological_nodes_generator(graph, reversed)
prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(graph,
source,
reversed=False,
has_reverse_edge=False,
has_nontree_edge=False,
message_func='default',
reduce_func='default',
apply_node_func='default'):
"""Message propagation using edge frontiers generated by labeled DFS.
Parameters
----------
graph : DGLGraph
The graph object.
source : list, tensor of nodes
Source nodes.
reversed : bool, optional
If true, traverse following the in-edge direction.
message_func : callable, optional
The message function.
reduce_func : callable, optional
The reduce function.
apply_node_func : callable, optional
The update function.
See Also
--------
dgl.traversal.dfs_labeled_edges_generator
"""
edges_gen = trv.dfs_labeled_edges_generator(
graph, source, reversed, has_reverse_edge, has_nontree_edge,
return_labels=False)
prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)
......@@ -24,6 +24,19 @@ def bfs_nodes_generator(graph, source, reversed=False):
-------
list of node frontiers
Each node frontier is a list, tensor of nodes.
Examples
--------
Given a graph (directed, edges from small node id to large):
::
2 - 4
/ \
0 - 1 - 3 - 5
>>> g = ... # the graph above
>>> list(dgl.bfs_nodes_generator(g, 0))
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
......@@ -47,6 +60,19 @@ def topological_nodes_generator(graph, reversed=False):
-------
list of node frontiers
Each node frontier is a list, tensor of nodes.
Examples
--------
Given a graph (directed, edges from small node id to large):
::
2 - 4
/ \
0 - 1 - 3 - 5
>>> g = ... # the graph above
>>> list(dgl.topological_nodes_generator(g))
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
"""
ghandle = graph._graph._handle
ret = _CAPI_DGLTopologicalNodes(ghandle, reversed)
......@@ -76,6 +102,21 @@ def dfs_edges_generator(graph, source, reversed=False):
-------
list of edge frontiers
Each edge frontier is a list, tensor of edges.
Examples
--------
Given a graph (directed, edges from small node id to large):
::
2 - 4
/ \
0 - 1 - 3 - 5
Edge addition order [(0, 1), (1, 2), (1, 3), (2, 3), (2, 4), (3, 5)]
>>> g = ... # the graph above
>>> list(dgl.dfs_edges_generator(g))
[tensor([0]), tensor([1]), tensor([4]), tensor([3]), tensor([5]), tensor([2])]
"""
ghandle = graph._graph._handle
source = utils.toindex(source).todgltensor()
......@@ -101,6 +142,10 @@ def dfs_labeled_edges_generator(
edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v`
have been visisted but the edge is NOT in the DFS tree.
See ``networkx``'s :func:`dfs_labeled_edges
<networkx.algorithms.traversal.depth_first_search.dfs_labeled_edges>`
for more details.
Multiple source nodes can be specified to start the DFS traversal. One
needs to make sure that each source node belongs to different connected
component, so the frontiers can be easily merged. Otherwise, the behavior
......
import dgl
import networkx as nx
import torch as th
import utils as U
def mfunc(edges):
return {'m' : edges.src['x']}
def rfunc(nodes):
msg = th.sum(nodes.mailbox['m'], 1)
return {'x' : nodes.data['x'] + msg}
def test_prop_nodes_bfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.ndata['x'] = th.ones((5, 2))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
dgl.prop_nodes_bfs(g, 0)
# pull nodes using bfs order will result in a cumsum[i] + data[i] + data[i+1]
assert U.allclose(g.ndata['x'],
th.tensor([[2., 2.], [4., 4.], [6., 6.], [8., 8.], [9., 9.]]))
def test_prop_edges_dfs():
g = dgl.DGLGraph(nx.path_graph(5))
g.register_message_func(mfunc)
g.register_reduce_func(rfunc)
g.ndata['x'] = th.ones((5, 2))
dgl.prop_edges_dfs(g, 0)
# snr using dfs results in a cumsum
assert U.allclose(g.ndata['x'],
th.tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]))
g.ndata['x'] = th.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_reverse_edge=True)
# result is cumsum[i] + cumsum[i-1]
assert U.allclose(g.ndata['x'],
th.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.], [9., 9.]]))
g.ndata['x'] = th.ones((5, 2))
dgl.prop_edges_dfs(g, 0, has_nontree_edge=True)
# result is cumsum[i] + cumsum[i+1]
assert U.allclose(g.ndata['x'],
th.tensor([[3., 3.], [5., 5.], [7., 7.], [9., 9.], [5., 5.]]))
def test_prop_nodes_topo():
# bi-directional chain
g = dgl.DGLGraph(nx.path_graph(5))
assert U.check_fail(dgl.prop_nodes_topo, g) # has loop
# tree
tree = dgl.DGLGraph()
tree.add_nodes(5)
tree.add_edge(1, 0)
tree.add_edge(2, 0)
tree.add_edge(3, 2)
tree.add_edge(4, 2)
tree.register_message_func(mfunc)
tree.register_reduce_func(rfunc)
# init node feature data
tree.ndata['x'] = th.zeros((5, 2))
# set all leaf nodes to be ones
tree.nodes[[1, 3, 4]].data['x'] = th.ones((3, 2))
dgl.prop_nodes_topo(tree)
# root node get the sum
assert U.allclose(tree.nodes[0].data['x'], th.tensor([[3., 3.]]))
if __name__ == '__main__':
test_prop_nodes_bfs()
test_prop_edges_dfs()
test_prop_nodes_topo()
......@@ -2,3 +2,10 @@ import torch as th
def allclose(a, b):
return th.allclose(a, b, rtol=1e-4, atol=1e-4)
def check_fail(fn, *args, **kwargs):
try:
fn(*args, **kwargs)
return False
except:
return True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment