Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
23e2e83b
Commit
23e2e83b
authored
Nov 23, 2018
by
Zihao Ye
Committed by
Minjie Wang
Nov 22, 2018
Browse files
[API] change the signature of node/edge filter (#166)
* change the signature of node/edge filter * upd filter
parent
deb653f8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
13 deletions
+36
-13
python/dgl/graph.py
python/dgl/graph.py
+35
-12
tests/pytorch/test_filter.py
tests/pytorch/test_filter.py
+1
-1
No files found.
python/dgl/graph.py
View file @
23e2e83b
...
@@ -12,6 +12,7 @@ from .graph_index import GraphIndex, create_graph_index
...
@@ -12,6 +12,7 @@ from .graph_index import GraphIndex, create_graph_index
from
.runtime
import
ir
,
scheduler
,
Runtime
from
.runtime
import
ir
,
scheduler
,
Runtime
from
.
import
utils
from
.
import
utils
from
.view
import
NodeView
,
EdgeView
from
.view
import
NodeView
,
EdgeView
from
.udf
import
NodeBatch
,
EdgeBatch
__all__
=
[
'DGLGraph'
]
__all__
=
[
'DGLGraph'
]
...
@@ -1563,10 +1564,9 @@ class DGLGraph(object):
...
@@ -1563,10 +1564,9 @@ class DGLGraph(object):
Parameters
Parameters
----------
----------
predicate : callable
predicate : callable
The predicate should take in a dict of tensors whose values
The predicate should take in a NodeBatch object, and return a
are concatenation of node representations by node ID (same as
boolean tensor with N elements indicating which node satisfy
get_n_repr()), and return a boolean tensor with N elements
the predicate.
indicating which node satisfy the predicate.
nodes : container or tensor
nodes : container or tensor
The nodes to filter on
The nodes to filter on
...
@@ -1575,8 +1575,14 @@ class DGLGraph(object):
...
@@ -1575,8 +1575,14 @@ class DGLGraph(object):
tensor
tensor
The filtered nodes
The filtered nodes
"""
"""
n_repr
=
self
.
get_n_repr
(
nodes
)
if
is_all
(
nodes
):
n_mask
=
predicate
(
n_repr
)
v
=
utils
.
toindex
(
slice
(
0
,
self
.
number_of_nodes
()))
else
:
v
=
utils
.
toindex
(
nodes
)
n_repr
=
self
.
get_n_repr
(
v
)
nb
=
NodeBatch
(
self
,
v
,
n_repr
)
n_mask
=
predicate
(
nb
)
if
is_all
(
nodes
):
if
is_all
(
nodes
):
return
F
.
nonzero_1d
(
n_mask
)
return
F
.
nonzero_1d
(
n_mask
)
...
@@ -1590,10 +1596,9 @@ class DGLGraph(object):
...
@@ -1590,10 +1596,9 @@ class DGLGraph(object):
Parameters
Parameters
----------
----------
predicate : callable
predicate : callable
The predicate should take in a dict of tensors whose values
The predicate should take in an EdgeBatch object, and return a
are concatenation of edge representations by edge ID,
boolean tensor with E elements indicating which edge satisfy
and return a boolean tensor with N elements indicating which
the predicate.
node satisfy the predicate.
edges : edges
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
tensor of edge ids. The default value is all the edges.
...
@@ -1603,8 +1608,26 @@ class DGLGraph(object):
...
@@ -1603,8 +1608,26 @@ class DGLGraph(object):
tensor
tensor
The filtered edges
The filtered edges
"""
"""
e_repr
=
self
.
get_e_repr
(
edges
)
if
is_all
(
edges
):
e_mask
=
predicate
(
e_repr
)
eid
=
ALL
u
,
v
,
_
=
self
.
_graph
.
edges
()
elif
isinstance
(
edges
,
tuple
):
u
,
v
=
edges
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
# Rewrite u, v to handle edge broadcasting and multigraph.
u
,
v
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
else
:
eid
=
utils
.
toindex
(
edges
)
u
,
v
,
_
=
self
.
_graph
.
find_edges
(
eid
)
src_data
=
self
.
get_n_repr
(
u
)
edge_data
=
self
.
get_e_repr
(
eid
)
dst_data
=
self
.
get_n_repr
(
v
)
eb
=
EdgeBatch
(
self
,
(
u
,
v
,
eid
),
src_data
,
edge_data
,
dst_data
)
e_mask
=
predicate
(
eb
)
if
is_all
(
edges
):
if
is_all
(
edges
):
return
F
.
nonzero_1d
(
e_mask
)
return
F
.
nonzero_1d
(
e_mask
)
...
...
tests/pytorch/test_filter.py
View file @
23e2e83b
...
@@ -17,7 +17,7 @@ def test_filter():
...
@@ -17,7 +17,7 @@ def test_filter():
g
.
edata
[
'a'
]
=
e_repr
g
.
edata
[
'a'
]
=
e_repr
def
predicate
(
r
):
def
predicate
(
r
):
return
r
[
'a'
].
max
(
1
)[
0
]
>
0
return
r
.
data
[
'a'
].
max
(
1
)[
0
]
>
0
# full node filter
# full node filter
n_idx
=
g
.
filter_nodes
(
predicate
)
n_idx
=
g
.
filter_nodes
(
predicate
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment