Unverified Commit 8f459407 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Hotfix] Fix degree bucket edge ordering (#2176)

* fix degree bucket edge ordering

* unit test

* fix
parent cbd55eb1
"""Implementation for core graph computation.""" """Implementation for core graph computation."""
# pylint: disable=not-callable # pylint: disable=not-callable
import numpy as np
from .base import DGLError, is_all, NID, EID, ALL from .base import DGLError, is_all, NID, EID, ALL
from . import backend as F from . import backend as F
...@@ -121,8 +122,13 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None): ...@@ -121,8 +122,13 @@ def invoke_udf_reduce(graph, func, msgdata, *, orig_nid=None):
continue continue
bkt_nodes.append(node_bkt) bkt_nodes.append(node_bkt)
ndata_bkt = dstdata.subframe(node_bkt) ndata_bkt = dstdata.subframe(node_bkt)
eid_bkt = graph.in_edges(node_bkt, form='eid')
# order the incoming edges per node by edge ID
eid_bkt = F.zerocopy_to_numpy(graph.in_edges(node_bkt, form='eid'))
assert len(eid_bkt) == deg * len(node_bkt) assert len(eid_bkt) == deg * len(node_bkt)
eid_bkt = np.sort(eid_bkt.reshape((len(node_bkt), deg)), 1)
eid_bkt = F.zerocopy_from_numpy(eid_bkt.flatten())
msgdata_bkt = msgdata.subframe(eid_bkt) msgdata_bkt = msgdata.subframe(eid_bkt)
# reshape all msg tensors to (num_nodes_bkt, degree, feat_size) # reshape all msg tensors to (num_nodes_bkt, degree, feat_size)
maildata = {} maildata = {}
......
...@@ -644,3 +644,16 @@ def test_issue_1088(idtype): ...@@ -644,3 +644,16 @@ def test_issue_1088(idtype):
g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])}, idtype=idtype, device=F.ctx()) g = dgl.heterograph({('U', 'E', 'V'): ([0, 1, 2], [1, 2, 3])}, idtype=idtype, device=F.ctx())
g.nodes['U'].data['x'] = F.randn((3, 3)) g.nodes['U'].data['x'] = F.randn((3, 3))
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y')) g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
@parametrize_dtype
def test_degree_bucket_edge_ordering(idtype):
import dgl.function as fn
g = dgl.graph(
([1, 3, 5, 0, 4, 2, 3, 3, 4, 5], [1, 1, 0, 0, 1, 2, 2, 0, 3, 3]),
idtype=idtype, device=F.ctx())
g.edata['eid'] = F.copy_to(F.arange(0, 10), F.ctx())
def reducer(nodes):
eid = F.asnumpy(F.copy_to(nodes.mailbox['eid'], F.cpu()))
assert np.array_equal(eid, np.sort(eid, 1))
return {'n': F.sum(nodes.mailbox['eid'], 1)}
g.update_all(fn.copy_e('eid', 'eid'), reducer)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment