Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2caac086
Unverified
Commit
2caac086
authored
Jan 06, 2021
by
Quan (Andy) Gan
Committed by
GitHub
Jan 06, 2021
Browse files
[Bug] send_and_recv and pull may write to wrong places (#2497)
* fix * fix test
parent
4507bebc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
7 deletions
+23
-7
python/dgl/heterograph.py
python/dgl/heterograph.py
+8
-7
tests/compute/test_basics.py
tests/compute/test_basics.py
+15
-0
No files found.
python/dgl/heterograph.py
View file @
2caac086
...
...
@@ -4385,9 +4385,9 @@ class DGLHeteroGraph(object):
u
,
v
=
self
.
find_edges
(
eid
,
etype
=
etype
)
# call message passing onsubgraph
g
=
self
if
etype
is
None
else
self
[
etype
]
ndata
=
core
.
message_passing
(
_create_compute_graph
(
g
,
u
,
v
,
eid
)
,
message_func
,
reduce_func
,
apply_node_func
)
dstnodes
=
F
.
unique
(
v
)
compute_graph
,
_
,
dstnodes
,
_
=
_create_compute_graph
(
g
,
u
,
v
,
eid
)
ndata
=
core
.
message_passing
(
compute_graph
,
message_func
,
reduce_func
,
apply_node_func
)
self
.
_set_n_repr
(
dtid
,
dstnodes
,
ndata
)
def
pull
(
self
,
...
...
@@ -4489,9 +4489,10 @@ class DGLHeteroGraph(object):
g
=
self
if
etype
is
None
else
self
[
etype
]
# call message passing on subgraph
src
,
dst
,
eid
=
g
.
in_edges
(
v
,
form
=
'all'
)
ndata
=
core
.
message_passing
(
_create_compute_graph
(
g
,
src
,
dst
,
eid
,
v
),
message_func
,
reduce_func
,
apply_node_func
)
self
.
_set_n_repr
(
dtid
,
v
,
ndata
)
compute_graph
,
_
,
dstnodes
,
_
=
_create_compute_graph
(
g
,
src
,
dst
,
eid
,
v
)
ndata
=
core
.
message_passing
(
compute_graph
,
message_func
,
reduce_func
,
apply_node_func
)
self
.
_set_n_repr
(
dtid
,
dstnodes
,
ndata
)
def
push
(
self
,
u
,
...
...
@@ -6060,6 +6061,6 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
return
DGLHeteroGraph
(
hgidx
,
([
srctype
],
[
dsttype
]),
[
etype
],
node_frames
=
[
srcframe
,
dstframe
],
edge_frames
=
[
eframe
])
edge_frames
=
[
eframe
])
,
unique_src
,
unique_dst
,
eid
_init_api
(
"dgl.heterograph"
)
tests/compute/test_basics.py
View file @
2caac086
...
...
@@ -657,3 +657,18 @@ def test_degree_bucket_edge_ordering(idtype):
assert
np
.
array_equal
(
eid
,
np
.
sort
(
eid
,
1
))
return
{
'n'
:
F
.
sum
(
nodes
.
mailbox
[
'eid'
],
1
)}
g
.
update_all
(
fn
.
copy_e
(
'eid'
,
'eid'
),
reducer
)
@
parametrize_dtype
def
test_issue_2484
(
idtype
):
import
dgl.function
as
fn
g
=
dgl
.
graph
(([
0
,
1
,
2
],
[
1
,
2
,
3
]),
idtype
=
idtype
,
device
=
F
.
ctx
())
x
=
F
.
copy_to
(
F
.
randn
((
4
,)),
F
.
ctx
())
g
.
ndata
[
'x'
]
=
x
g
.
pull
([
2
,
1
],
fn
.
u_add_v
(
'x'
,
'x'
,
'm'
),
fn
.
sum
(
'm'
,
'x'
))
y1
=
g
.
ndata
[
'x'
]
g
.
ndata
[
'x'
]
=
x
g
.
pull
([
1
,
2
],
fn
.
u_add_v
(
'x'
,
'x'
,
'm'
),
fn
.
sum
(
'm'
,
'x'
))
y2
=
g
.
ndata
[
'x'
]
assert
F
.
allclose
(
y1
,
y2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment