Unverified Commit b2c1c4fa authored by Gan Quan's avatar Gan Quan Committed by GitHub
Browse files

node/edge filtering (#80)

* node/edge filtering

* changing to tensor operations (what did i do???)

* ???
parent 16da76c4
......@@ -22,3 +22,7 @@ def unpack(a, split_size_or_sections=None):
def shape(a):
    """Return the ``shape`` attribute of the given array."""
    dims = a.shape
    return dims
def nonzero_1d(a):
    """Return the indices of the nonzero elements of a 1D array.

    Parameters
    ----------
    a : numpy.ndarray
        A one-dimensional array (for example a boolean mask).

    Returns
    -------
    numpy.ndarray
        1D array holding the positions of the nonzero entries of ``a``.

    Raises
    ------
    AssertionError
        If ``a`` is not one-dimensional.
    """
    # The helper is only meaningful for vectors: np.nonzero returns one index
    # array per dimension and only the first one is used below.
    assert a.ndim == 1
    return np.nonzero(a)[0]
......@@ -143,3 +143,8 @@ def zerocopy_from_numpy(np_data):
arr.ctx = get_context(data)
return arr
'''
def nonzero_1d(arr):
    """Return a 1D tensor holding the indices of the nonzero entries of ``arr``.

    ``arr`` must be one-dimensional; this is checked with an assertion.
    """
    assert arr.dim() == 1
    # th.nonzero yields one row per nonzero element; keep the first column of
    # that result as the flat index vector.
    hits = th.nonzero(arr)
    return hits[:, 0]
......@@ -3,6 +3,7 @@
from __future__ import absolute_import
import networkx as nx
import numpy as np
import dgl
from .base import ALL, is_all, __MSG__, __REPR__
......@@ -1285,3 +1286,57 @@ class DGLGraph(object):
graph_data = self._graph.line_graph(backtracking)
node_frame = self._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
def filter_nodes(self, predicate, nodes=ALL):
    """Return a tensor of the node IDs for which the predicate holds.

    Parameters
    ----------
    predicate : callable
        Called with the node representations returned by
        ``get_n_repr(nodes)``. It must return a boolean mask with one
        entry per considered node, true where the node is selected.
    nodes : container or tensor
        The candidate nodes to filter. Defaults to all nodes.

    Returns
    -------
    tensor
        The IDs of the nodes that satisfy the predicate.
    """
    mask = predicate(self.get_n_repr(nodes))
    if not is_all(nodes):
        # Explicit candidates: convert them and select the entries flagged
        # by the mask.
        return F.Tensor(nodes)[mask]
    # All nodes were considered, so the positions of the true mask entries
    # are returned as the node IDs.
    return F.nonzero_1d(mask)
def filter_edges(self, predicate, edges=ALL):
    """Return a tensor of the edge IDs for which the predicate holds.

    Parameters
    ----------
    predicate : callable
        Called with the edge representations returned by
        ``get_e_repr_by_id(edges)``. It must return a boolean mask with
        one entry per considered edge, true where the edge is selected.
    edges : container or tensor
        The candidate edges to filter. Defaults to all edges.

    Returns
    -------
    tensor
        The IDs of the edges that satisfy the predicate.
    """
    mask = predicate(self.get_e_repr_by_id(edges))
    if not is_all(edges):
        # Explicit candidates: convert them and select the entries flagged
        # by the mask.
        return F.Tensor(edges)[mask]
    # All edges were considered, so the positions of the true mask entries
    # are returned as the edge IDs.
    return F.nonzero_1d(mask)
import torch as th
import numpy as np
from dgl.graph import DGLGraph
def test_filter():
    """Check filter_nodes and filter_edges on a directed 4-node cycle."""
    # Rows 1 and 3 are set to one; every other row stays zero, so only
    # items 1 and 3 should pass the predicate below.
    node_feat = th.zeros(4, 5)
    node_feat[[1, 3]] = 1
    edge_feat = th.zeros(4, 5)
    edge_feat[[1, 3]] = 1

    graph = DGLGraph()
    graph.add_nodes(4)
    graph.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    graph.set_n_repr({'a': node_feat})
    graph.set_e_repr({'a': edge_feat})

    def has_positive_row(reprs):
        # True for every row whose largest value in field 'a' is positive.
        row_max = reprs['a'].max(1)[0]
        return row_max > 0

    # Filter over all nodes.
    picked = graph.filter_nodes(has_positive_row)
    assert set(picked.numpy()) == {1, 3}

    # Filter over a subset of the nodes.
    picked = graph.filter_nodes(has_positive_row, [0, 1])
    assert set(picked.numpy()) == {1}

    # Filter over all edges.
    picked = graph.filter_edges(has_positive_row)
    assert set(picked.numpy()) == {1, 3}

    # Filter over a subset of the edges.
    picked = graph.filter_edges(has_positive_row, [0, 1])
    assert set(picked.numpy()) == {1}
# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_filter()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment