Commit 23e2e83b authored by Zihao Ye's avatar Zihao Ye Committed by Minjie Wang
Browse files

[API] change the signature of node/edge filter (#166)

* change the signature of node/edge filter

* upd filter
parent deb653f8
......@@ -12,6 +12,7 @@ from .graph_index import GraphIndex, create_graph_index
from .runtime import ir, scheduler, Runtime
from . import utils
from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph']
......@@ -1563,10 +1564,9 @@ class DGLGraph(object):
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of node representations by node ID (same as
get_n_repr()), and return a boolean tensor with N elements
indicating which node satisfy the predicate.
The predicate should take in a NodeBatch object, and return a
boolean tensor with N elements indicating which node satisfy
the predicate.
nodes : container or tensor
The nodes to filter on
......@@ -1575,8 +1575,14 @@ class DGLGraph(object):
tensor
The filtered nodes
"""
n_repr = self.get_n_repr(nodes)
n_mask = predicate(n_repr)
if is_all(nodes):
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(nodes)
n_repr = self.get_n_repr(v)
nb = NodeBatch(self, v, n_repr)
n_mask = predicate(nb)
if is_all(nodes):
return F.nonzero_1d(n_mask)
......@@ -1590,10 +1596,9 @@ class DGLGraph(object):
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of edge representations by edge ID,
and return a boolean tensor with N elements indicating which
node satisfy the predicate.
The predicate should take in an EdgeBatch object, and return a
boolean tensor with E elements indicating which edge satisfy
the predicate.
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
......@@ -1603,8 +1608,26 @@ class DGLGraph(object):
tensor
The filtered edges
"""
e_repr = self.get_e_repr(edges)
e_mask = predicate(e_repr)
if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges()
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
e_mask = predicate(eb)
if is_all(edges):
return F.nonzero_1d(e_mask)
......
......@@ -17,7 +17,7 @@ def test_filter():
g.edata['a'] = e_repr
def predicate(r):
return r['a'].max(1)[0] > 0
return r.data['a'].max(1)[0] > 0
# full node filter
n_idx = g.filter_nodes(predicate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment