Unverified Commit 25a8e3f1 authored by Mingjian Wen's avatar Mingjian Wen Committed by GitHub
Browse files

Fix segment_reduce() ignoring tailing 0 segments (#2228)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 40caf1ab
...@@ -50,7 +50,8 @@ def segment_reduce(seglen, value, reducer='sum'): ...@@ -50,7 +50,8 @@ def segment_reduce(seglen, value, reducer='sum'):
if len(u) != len(v): if len(u) != len(v):
raise DGLError("Invalid seglen array:", seglen, raise DGLError("Invalid seglen array:", seglen,
". Its summation must be equal to value.shape[0].") ". Its summation must be equal to value.shape[0].")
g = convert.heterograph({('_U', '_E', '_V'): (u, v)}) num_nodes = {'_U': len(u), '_V': len(seglen)}
g = convert.heterograph({('_U', '_E', '_V'): (u, v)}, num_nodes_dict=num_nodes)
g.srcdata['h'] = value g.srcdata['h'] = value
g.update_all(fn.copy_u('h', 'm'), getattr(fn, reducer)('m', 'h')) g.update_all(fn.copy_u('h', 'm'), getattr(fn, reducer)('m', 'h'))
return g.dstdata['h'] return g.dstdata['h']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment