Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2c170a8c
Commit
2c170a8c
authored
Dec 03, 2018
by
Da Zheng
Committed by
Minjie Wang
Dec 03, 2018
Browse files
[Graph][Bugfix] Fix the API of map_to_subgraph_nid (#226)
* correct vid mapping API. * fix sse.
parent
419ffbde
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
14 deletions
+25
-14
examples/mxnet/sse/sse_batch.py
examples/mxnet/sse/sse_batch.py
+1
-1
python/dgl/subgraph.py
python/dgl/subgraph.py
+13
-1
tests/mxnet/test_sampler.py
tests/mxnet/test_sampler.py
+11
-12
No files found.
examples/mxnet/sse/sse_batch.py
View file @
2c170a8c
...
...
@@ -283,7 +283,7 @@ def main(args, data):
copy_to_gpu
(
subg
,
ctx
)
with
mx
.
autograd
.
record
():
logits
=
model_train
(
subg
,
subg_seeds
.
tousertensor
()
)
logits
=
model_train
(
subg
,
subg_seeds
)
batch_labels
=
mx
.
nd
.
take
(
labels
,
seeds
).
as_in_context
(
logits
.
context
)
loss
=
mx
.
nd
.
softmax_cross_entropy
(
logits
,
batch_labels
)
loss
.
backward
()
...
...
python/dgl/subgraph.py
View file @
2c170a8c
...
...
@@ -128,4 +128,16 @@ class DGLSubGraph(DGLGraph):
self
.
_parent
.
_edge_frame
[
self
.
_get_parent_eid
()]))
def
map_to_subgraph_nid
(
self
,
parent_vids
):
return
map_to_subgraph_nid
(
self
.
_graph
,
utils
.
toindex
(
parent_vids
))
"""Map the node Ids in the parent graph to the node Ids in the subgraph.
Parameters
----------
parent_vids : list, tensor
The node ID array in the parent graph.
Returns
-------
tensor
The node ID array in the subgraph.
"""
return
map_to_subgraph_nid
(
self
.
_graph
,
utils
.
toindex
(
parent_vids
)).
tousertensor
()
tests/mxnet/test_sampler.py
View file @
2c170a8c
...
...
@@ -16,9 +16,9 @@ def test_1neighbor_sampler_all():
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
1
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
assert
len
(
seed_ids
)
==
1
src
,
dst
,
eid
=
g
.
_graph
.
in_edges
(
utils
.
toindex
(
seed_ids
)
)
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
# Test if there is a self loop
self_loop
=
mx
.
nd
.
sum
(
src
.
tousertensor
()
==
dst
.
tousertensor
()
).
asnumpy
()
==
1
self_loop
=
mx
.
nd
.
sum
(
src
==
dst
).
asnumpy
()
==
1
if
self_loop
:
assert
subg
.
number_of_nodes
()
==
len
(
src
)
else
:
...
...
@@ -26,26 +26,25 @@ def test_1neighbor_sampler_all():
assert
subg
.
number_of_edges
()
>=
len
(
src
)
child_ids
=
subg
.
map_to_subgraph_nid
(
seed_ids
)
child_src
,
child_dst
,
child_eid
=
subg
.
_graph
.
in_edges
(
child_ids
)
child_src
,
child_dst
,
child_eid
=
subg
.
in_edges
(
child_ids
,
form
=
'all'
)
child_src1
=
subg
.
map_to_subgraph_nid
(
src
)
assert
mx
.
nd
.
sum
(
child_src1
.
tousertensor
()
==
child_src
.
tousertensor
()
).
asnumpy
()
==
len
(
src
)
assert
mx
.
nd
.
sum
(
child_src1
==
child_src
).
asnumpy
()
==
len
(
src
)
def
is_sorted
(
arr
):
return
np
.
sum
(
np
.
sort
(
arr
)
==
arr
)
==
len
(
arr
)
def
verify_subgraph
(
g
,
subg
,
seed_id
):
seed_id
=
utils
.
toindex
(
seed_id
)
src
,
dst
,
eid
=
g
.
_graph
.
in_edges
(
utils
.
toindex
(
seed_id
))
src
,
dst
,
eid
=
g
.
in_edges
(
seed_id
,
form
=
'all'
)
child_id
=
subg
.
map_to_subgraph_nid
(
seed_id
)
child_src
,
child_dst
,
child_eid
=
subg
.
_graph
.
in_edges
(
child_id
)
child_src
=
child_src
.
tousertensor
().
asnumpy
()
child_src
,
child_dst
,
child_eid
=
subg
.
in_edges
(
child_id
,
form
=
'all'
)
child_src
=
child_src
.
asnumpy
()
# We don't allow duplicate elements in the neighbor list.
assert
(
len
(
np
.
unique
(
child_src
))
==
len
(
child_src
))
# The neighbor list also needs to be sorted.
assert
(
is_sorted
(
child_src
))
child_src1
=
subg
.
map_to_subgraph_nid
(
src
).
tousertensor
().
asnumpy
()
child_src1
=
subg
.
map_to_subgraph_nid
(
src
).
asnumpy
()
child_src1
=
child_src1
[
child_src1
>=
0
]
for
i
in
child_src
:
assert
i
in
child_src1
...
...
@@ -65,13 +64,13 @@ def test_10neighbor_sampler_all():
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for
subg
,
seed_ids
in
dgl
.
contrib
.
sampling
.
NeighborSampler
(
g
,
10
,
100
,
neighbor_type
=
'in'
,
num_workers
=
4
):
src
,
dst
,
eid
=
g
.
_graph
.
in_edges
(
utils
.
toindex
(
seed_ids
)
)
src
,
dst
,
eid
=
g
.
in_edges
(
seed_ids
,
form
=
'all'
)
child_ids
=
subg
.
map_to_subgraph_nid
(
seed_ids
)
child_src
,
child_dst
,
child_eid
=
subg
.
_graph
.
in_edges
(
child_ids
)
child_src
,
child_dst
,
child_eid
=
subg
.
in_edges
(
child_ids
,
form
=
'all'
)
child_src1
=
subg
.
map_to_subgraph_nid
(
src
)
assert
mx
.
nd
.
sum
(
child_src1
.
tousertensor
()
==
child_src
.
tousertensor
()
).
asnumpy
()
==
len
(
src
)
assert
mx
.
nd
.
sum
(
child_src1
==
child_src
).
asnumpy
()
==
len
(
src
)
def
check_10neighbor_sampler
(
g
,
seeds
):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment