Unverified Commit c7f6cf62 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] Consistent order for cross-type stack reducer (#1267)

* WIP

* fix shape of stack reducer

* apply merge order by edge id for stack reducer
parent 537d37c2
...@@ -2653,6 +2653,7 @@ class DGLHeteroGraph(object): ...@@ -2653,6 +2653,7 @@ class DGLHeteroGraph(object):
# TODO(minjie): currently loop over each edge type and reuse the old schedule. # TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel. # Should replace it with fused kernel.
all_out = [] all_out = []
merge_order = []
with ir.prog() as prog: with ir.prog() as prog:
for ety, args in reducer_dict.items(): for ety, args in reducer_dict.items():
outframe = FrameRef(frame_like(self._node_frames[ntid]._frame)) outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
...@@ -2667,9 +2668,10 @@ class DGLHeteroGraph(object): ...@@ -2667,9 +2668,10 @@ class DGLHeteroGraph(object):
v, rfunc, afunc, v, rfunc, afunc,
inplace=inplace, outframe=outframe) inplace=inplace, outframe=outframe)
all_out.append(outframe) all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog) Runtime.run(prog)
# merge by cross_reducer # merge by cross_reducer
self._node_frames[ntid].update(merge_frames(all_out, cross_reducer)) self._node_frames[ntid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
self.apply_nodes(apply_node_func, v, ntype, inplace) self.apply_nodes(apply_node_func, v, ntype, inplace)
...@@ -2855,6 +2857,7 @@ class DGLHeteroGraph(object): ...@@ -2855,6 +2857,7 @@ class DGLHeteroGraph(object):
# Should replace it with fused kernel. # Should replace it with fused kernel.
all_out = [] all_out = []
all_vs = [] all_vs = []
merge_order = []
with ir.prog() as prog: with ir.prog() as prog:
for etype, args in etype_dict.items(): for etype, args in etype_dict.items():
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
...@@ -2883,9 +2886,10 @@ class DGLHeteroGraph(object): ...@@ -2883,9 +2886,10 @@ class DGLHeteroGraph(object):
mfunc, rfunc, afunc, mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe) inplace=inplace, outframe=outframe)
all_out.append(outframe) all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog) Runtime.run(prog)
# merge by cross_reducer # merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer)) self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0)) dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
...@@ -3043,6 +3047,7 @@ class DGLHeteroGraph(object): ...@@ -3043,6 +3047,7 @@ class DGLHeteroGraph(object):
# TODO(minjie): currently loop over each edge type and reuse the old schedule. # TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel. # Should replace it with fused kernel.
all_out = [] all_out = []
merge_order = []
with ir.prog() as prog: with ir.prog() as prog:
for etype, args in etype_dict.items(): for etype, args in etype_dict.items():
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
...@@ -3058,9 +3063,10 @@ class DGLHeteroGraph(object): ...@@ -3058,9 +3063,10 @@ class DGLHeteroGraph(object):
mfunc, rfunc, afunc, mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe) inplace=inplace, outframe=outframe)
all_out.append(outframe) all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog) Runtime.run(prog)
# merge by cross_reducer # merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer)) self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
self.apply_nodes(apply_node_func, v, ntype, inplace) self.apply_nodes(apply_node_func, v, ntype, inplace)
...@@ -3263,6 +3269,7 @@ class DGLHeteroGraph(object): ...@@ -3263,6 +3269,7 @@ class DGLHeteroGraph(object):
# TODO(minjie): currently loop over each edge type and reuse the old schedule. # TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel. # Should replace it with fused kernel.
all_out = defaultdict(list) all_out = defaultdict(list)
merge_order = defaultdict(list)
with ir.prog() as prog: with ir.prog() as prog:
for etype, args in etype_dict.items(): for etype, args in etype_dict.items():
etid = self.get_etype_id(etype) etid = self.get_etype_id(etype)
...@@ -3277,10 +3284,12 @@ class DGLHeteroGraph(object): ...@@ -3277,10 +3284,12 @@ class DGLHeteroGraph(object):
mfunc, rfunc, afunc, mfunc, rfunc, afunc,
outframe=outframe) outframe=outframe)
all_out[dtid].append(outframe) all_out[dtid].append(outframe)
merge_order[dtid].append(etid) # use edge type id as merge order hint
Runtime.run(prog) Runtime.run(prog)
for dtid, frames in all_out.items(): for dtid, frames in all_out.items():
# merge by cross_reducer # merge by cross_reducer
self._node_frames[dtid].update(merge_frames(frames, cross_reducer)) self._node_frames[dtid].update(
merge_frames(frames, cross_reducer, merge_order[dtid]))
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False) self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False)
...@@ -3813,28 +3822,38 @@ def pad_tuple(tup, length, pad_val=None): ...@@ -3813,28 +3822,38 @@ def pad_tuple(tup, length, pad_val=None):
else: else:
return tup + (pad_val,) * (length - len(tup)) return tup + (pad_val,) * (length - len(tup))
def merge_frames(frames, reducer): def merge_frames(frames, reducer, order=None):
"""Merge input frames into one. Resolve conflict fields using reducer. """Merge input frames into one. Resolve conflict fields using reducer.
Parameters Parameters
---------- ----------
frames : list of FrameRef frames : list[FrameRef]
Input frames Input frames
reducer : str reducer : str
One of "sum", "max", "min", "mean", "stack" One of "sum", "max", "min", "mean", "stack"
order : list[Int], optional
Merge order hint. Useful for "stack" reducer.
If provided, each integer indicates the relative order
of the ``frames`` list. Frames are sorted according to this list
in ascending order. Tie is not handled so make sure the order values
are distinct.
Returns Returns
------- -------
FrameRef FrameRef
Merged frame Merged frame
""" """
if len(frames) == 1: if len(frames) == 1 and reducer != 'stack':
# Directly return the only one input. Stack reducer requires
# modifying tensor shape.
return frames[0] return frames[0]
if reducer == 'stack': if reducer == 'stack':
# TODO(minjie): Stack order does not matter. However, it must # Stack order does not matter. However, it must be consistent!
# be consistent! Need to enforce one type of order. if order:
assert len(order) == len(frames)
sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])
frames = list(zip(*sorted_with_key))[0]
def merger(flist): def merger(flist):
flist = [F.unsqueeze(f, 1) for f in flist]
return F.stack(flist, 1) return F.stack(flist, 1)
else: else:
redfn = getattr(F, reducer, None) redfn = getattr(F, reducer, None)
...@@ -3842,7 +3861,7 @@ def merge_frames(frames, reducer): ...@@ -3842,7 +3861,7 @@ def merge_frames(frames, reducer):
raise DGLError('Invalid cross type reducer. Must be one of ' raise DGLError('Invalid cross type reducer. Must be one of '
'"sum", "max", "min", "mean" or "stack".') '"sum", "max", "min", "mean" or "stack".')
def merger(flist): def merger(flist):
return redfn(F.stack(flist, 0), 0) return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]
ret = FrameRef(frame_like(frames[0]._frame)) ret = FrameRef(frame_like(frames[0]._frame))
keys = set() keys = set()
for frm in frames: for frm in frames:
...@@ -3852,10 +3871,7 @@ def merge_frames(frames, reducer): ...@@ -3852,10 +3871,7 @@ def merge_frames(frames, reducer):
for frm in frames: for frm in frames:
if k in frm: if k in frm:
flist.append(frm[k]) flist.append(frm[k])
if len(flist) > 1: ret[k] = merger(flist)
ret[k] = merger(flist)
else:
ret[k] = flist[0]
return ret return ret
def combine_frames(frames, ids): def combine_frames(frames, ids):
......
...@@ -1249,12 +1249,10 @@ def test_level2(): ...@@ -1249,12 +1249,10 @@ def test_level2():
g['wishes'].update_all(mfunc, rfunc2) g['wishes'].update_all(mfunc, rfunc2)
y2 = g.nodes['game'].data['y'] y2 = g.nodes['game'].data['y']
if cred == 'stack': if cred == 'stack':
# stack has two both correct outcomes # stack has an internal order by edge type id
yy1 = F.stack([F.unsqueeze(y1, 1), F.unsqueeze(y2, 1)], 1) yy = F.stack([y1, y2], 1)
yy1 = yy1 + 1 # final afunc yy = yy + 1 # final afunc
yy2 = F.stack([F.unsqueeze(y2, 1), F.unsqueeze(y1, 1)], 1) assert F.array_equal(y, yy)
yy2 = yy2 + 1 # final afunc
assert F.array_equal(y, yy1) or F.array_equal(y, yy2)
else: else:
yy = get_redfn(cred)(F.stack([y1, y2], 0), 0) yy = get_redfn(cred)(F.stack([y1, y2], 0), 0)
yy = yy + 1 # final afunc yy = yy + 1 # final afunc
...@@ -1469,6 +1467,32 @@ def test_types_in_function(): ...@@ -1469,6 +1467,32 @@ def test_types_in_function():
g.filter_nodes(filter_nodes2, ntype='game') g.filter_nodes(filter_nodes2, ntype='game')
g.filter_edges(filter_edges2) g.filter_edges(filter_edges2)
def test_stack_reduce():
#edges = {
# 'follows': ([0, 1], [1, 2]),
# 'plays': ([0, 1, 2, 1], [0, 0, 1, 1]),
# 'wishes': ([0, 2], [1, 0]),
# 'develops': ([0, 1], [0, 1]),
#}
g = create_test_heterograph()
g.nodes['user'].data['h'] = F.randn((3, 200))
def rfunc(nodes):
return {'y': F.sum(nodes.mailbox['m'], 1)}
def rfunc2(nodes):
return {'y': F.max(nodes.mailbox['m'], 1)}
def mfunc(edges):
return {'m': edges.src['h']}
g.multi_update_all(
{'plays' : (mfunc, rfunc),
'wishes': (mfunc, rfunc2)},
'stack')
assert g.nodes['game'].data['y'].shape == (g.number_of_nodes('game'), 2, 200)
# only one type-wise update_all, stack still adds one dimension
g.multi_update_all(
{'plays' : (mfunc, rfunc)},
'stack')
assert g.nodes['game'].data['y'].shape == (g.number_of_nodes('game'), 1, 200)
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
test_query() test_query()
...@@ -1491,3 +1515,4 @@ if __name__ == '__main__': ...@@ -1491,3 +1515,4 @@ if __name__ == '__main__':
test_empty_heterograph() test_empty_heterograph()
test_compact() test_compact()
test_types_in_function() test_types_in_function()
test_stack_reduce()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment