Commit 23e2e83b authored by Zihao Ye's avatar Zihao Ye Committed by Minjie Wang
Browse files

[API] change the signature of node/edge filter (#166)

* change the signature of node/edge filter

* upd filter
parent deb653f8
...@@ -12,6 +12,7 @@ from .graph_index import GraphIndex, create_graph_index ...@@ -12,6 +12,7 @@ from .graph_index import GraphIndex, create_graph_index
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime
from . import utils from . import utils
from .view import NodeView, EdgeView from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph'] __all__ = ['DGLGraph']
...@@ -1563,10 +1564,9 @@ class DGLGraph(object): ...@@ -1563,10 +1564,9 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
predicate : callable predicate : callable
The predicate should take in a dict of tensors whose values The predicate should take in a NodeBatch object, and return a
are concatenation of node representations by node ID (same as boolean tensor with N elements indicating which node satisfy
get_n_repr()), and return a boolean tensor with N elements the predicate.
indicating which node satisfy the predicate.
nodes : container or tensor nodes : container or tensor
The nodes to filter on The nodes to filter on
...@@ -1575,8 +1575,14 @@ class DGLGraph(object): ...@@ -1575,8 +1575,14 @@ class DGLGraph(object):
tensor tensor
The filtered nodes The filtered nodes
""" """
n_repr = self.get_n_repr(nodes) if is_all(nodes):
n_mask = predicate(n_repr) v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(nodes)
n_repr = self.get_n_repr(v)
nb = NodeBatch(self, v, n_repr)
n_mask = predicate(nb)
if is_all(nodes): if is_all(nodes):
return F.nonzero_1d(n_mask) return F.nonzero_1d(n_mask)
...@@ -1590,10 +1596,9 @@ class DGLGraph(object): ...@@ -1590,10 +1596,9 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
predicate : callable predicate : callable
The predicate should take in a dict of tensors whose values The predicate should take in an EdgeBatch object, and return a
are concatenation of edge representations by edge ID, boolean tensor with E elements indicating which edge satisfy
and return a boolean tensor with N elements indicating which the predicate.
node satisfy the predicate.
edges : edges edges : edges
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids. The default value is all the edges.
...@@ -1603,8 +1608,26 @@ class DGLGraph(object): ...@@ -1603,8 +1608,26 @@ class DGLGraph(object):
tensor tensor
The filtered edges The filtered edges
""" """
e_repr = self.get_e_repr(edges) if is_all(edges):
e_mask = predicate(e_repr) eid = ALL
u, v, _ = self._graph.edges()
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
e_mask = predicate(eb)
if is_all(edges): if is_all(edges):
return F.nonzero_1d(e_mask) return F.nonzero_1d(e_mask)
......
...@@ -17,7 +17,7 @@ def test_filter(): ...@@ -17,7 +17,7 @@ def test_filter():
g.edata['a'] = e_repr g.edata['a'] = e_repr
def predicate(r): def predicate(r):
return r['a'].max(1)[0] > 0 return r.data['a'].max(1)[0] > 0
# full node filter # full node filter
n_idx = g.filter_nodes(predicate) n_idx = g.filter_nodes(predicate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment