Unverified Commit 2caac086 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] send_and_recv and pull may write to wrong places (#2497)

* fix

* fix test
parent 4507bebc
......@@ -4385,9 +4385,9 @@ class DGLHeteroGraph(object):
u, v = self.find_edges(eid, etype=etype)
# call message passing onsubgraph
g = self if etype is None else self[etype]
ndata = core.message_passing(_create_compute_graph(g, u, v, eid),
message_func, reduce_func, apply_node_func)
dstnodes = F.unique(v)
compute_graph, _, dstnodes, _ = _create_compute_graph(g, u, v, eid)
ndata = core.message_passing(
compute_graph, message_func, reduce_func, apply_node_func)
self._set_n_repr(dtid, dstnodes, ndata)
def pull(self,
......@@ -4489,9 +4489,10 @@ class DGLHeteroGraph(object):
g = self if etype is None else self[etype]
# call message passing on subgraph
src, dst, eid = g.in_edges(v, form='all')
ndata = core.message_passing(_create_compute_graph(g, src, dst, eid, v),
message_func, reduce_func, apply_node_func)
self._set_n_repr(dtid, v, ndata)
compute_graph, _, dstnodes, _ = _create_compute_graph(g, src, dst, eid, v)
ndata = core.message_passing(
compute_graph, message_func, reduce_func, apply_node_func)
self._set_n_repr(dtid, dstnodes, ndata)
def push(self,
u,
......@@ -6060,6 +6061,6 @@ def _create_compute_graph(graph, u, v, eid, recv_nodes=None):
return DGLHeteroGraph(hgidx, ([srctype], [dsttype]), [etype],
node_frames=[srcframe, dstframe],
edge_frames=[eframe])
edge_frames=[eframe]), unique_src, unique_dst, eid
_init_api("dgl.heterograph")
......@@ -657,3 +657,18 @@ def test_degree_bucket_edge_ordering(idtype):
assert np.array_equal(eid, np.sort(eid, 1))
return {'n': F.sum(nodes.mailbox['eid'], 1)}
g.update_all(fn.copy_e('eid', 'eid'), reducer)
@parametrize_dtype
def test_issue_2484(idtype):
import dgl.function as fn
g = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
x = F.copy_to(F.randn((4,)), F.ctx())
g.ndata['x'] = x
g.pull([2, 1], fn.u_add_v('x', 'x', 'm'), fn.sum('m', 'x'))
y1 = g.ndata['x']
g.ndata['x'] = x
g.pull([1, 2], fn.u_add_v('x', 'x', 'm'), fn.sum('m', 'x'))
y2 = g.ndata['x']
assert F.allclose(y1, y2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment