Unverified commit 6c81634b, authored by Israt Nisa and committed by GitHub
Browse files

[Bugfix] Fixes wrong output in `multi_update_all` for reduce op max/min (Issue #3564) (#3581)



* fixed bug

* added  in test cases

* unittest resolved

* bugfix
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 70a499e3
......@@ -4866,6 +4866,10 @@ class DGLHeteroGraph(object):
_, dtid = self._graph.metagraph.find_edge(etid)
g = self if etype is None else self[etype]
ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max'] and ndata:
# Replace infinity with zero for isolated nodes
key = list(ndata.keys())[0]
ndata[key] = F.replace_inf_with_zero(ndata[key])
self._set_n_repr(dtid, ALL, ndata)
else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(message_func) or not core.is_builtin(reduce_func):
......@@ -4885,6 +4889,8 @@ class DGLHeteroGraph(object):
for _, _, dsttype in g.canonical_etypes:
dtid = g.get_ntype_id(dsttype)
dst_tensor[key] = out_tensor_tuples[dtid]
if core.is_builtin(reduce_func) and reduce_func.name in ['min', 'max']:
dst_tensor[key] = F.replace_inf_with_zero(dst_tensor[key])
self._node_frames[dtid].update(dst_tensor)
#################################################################
......@@ -4978,6 +4984,7 @@ class DGLHeteroGraph(object):
all_out = defaultdict(list)
merge_order = defaultdict(list)
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
_, dtid = self._graph.metagraph.find_edge(etid)
args = pad_tuple(args, 3)
......@@ -4990,13 +4997,18 @@ class DGLHeteroGraph(object):
merge_order[dtid].append(etid) # use edge type id as merge order hint
for dtid, frames in all_out.items():
# merge by cross_reducer
self._node_frames[dtid].update(
reduce_dict_data(frames, cross_reducer, merge_order[dtid]))
out = reduce_dict_data(frames, cross_reducer, merge_order[dtid])
# Replace infinity with zero for isolated nodes when reducer is min/max
if core.is_builtin(rfunc) and rfunc.name in ['min', 'max']:
key = list(out.keys())[0]
out[key] = F.replace_inf_with_zero(out[key]) if out[key] is not None else None
self._node_frames[dtid].update(out)
# apply
if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid])
#################################################################
# Message propagation
#################################################################
......
......@@ -75,9 +75,6 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
ret = gspmm_internal(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op,
lhs_data, rhs_data)
# Replace infinity with zero for isolated nodes when reducer is min/max
if reduce_op in ['min', 'max']:
ret = F.replace_inf_with_zero(ret)
else:
# lhs_data or rhs_data is None only in unary functions like ``copy-u`` or ``copy_e``
lhs_data = [None] * g._graph.number_of_ntypes() if lhs_data is None else lhs_data
......@@ -87,13 +84,6 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
ret = gspmm_internal_hetero(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op,
len(lhs_data), *lhs_and_rhs_tuple)
# `update_all` on heterogeneous graphs replaces the inf values with zeros on
# the final output (after processing all etypes). `multi_update_all` performs
# this operation after processing each etype. It computes the final output based
# on the output of each etype where inf is already replaced by zero.
if reduce_op in ['min', 'max']:
ret = tuple([F.replace_inf_with_zero(ret[i]) if ret[i] is not None else None
for i in range(len(ret))])
# TODO (Israt): Add support for 'mean' in heterograph
# divide in degrees for mean reducer.
if reduce_op == 'mean':
......
......@@ -38,14 +38,18 @@ def create_test_heterograph(idtype):
def create_test_heterograph_2(idtype):
src = np.random.randint(0, 5, 25)
dst = np.random.randint(0, 5, 25)
src = np.random.randint(0, 50, 25)
dst = np.random.randint(0, 50, 25)
src1 = np.random.randint(0, 25, 10)
dst1 = np.random.randint(0, 25, 10)
src2 = np.random.randint(0, 100, 1000)
dst2 = np.random.randint(0, 100, 1000)
g = dgl.heterograph({
('user', 'becomes', 'player'): (src, dst),
('user', 'follows', 'user'): (src, dst),
('user', 'plays', 'game'): (src, dst),
('user', 'wishes', 'game'): (src, dst),
('developer', 'develops', 'game'): (src, dst),
('user', 'wishes', 'game'): (src1, dst1),
('developer', 'develops', 'game'): (src2, dst2),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
......
......@@ -117,6 +117,8 @@ def test_spmm(idtype, g, shp, msg, reducer):
e = F.attach_grad(F.clone(he))
with F.record_grad():
v = gspmm(g, msg, reducer, u, e)
if reducer in ['max', 'min']:
v = F.replace_inf_with_zero(v)
if g.number_of_edges() > 0:
F.backward(F.reduce_sum(v))
if msg != 'copy_rhs':
......@@ -270,6 +272,8 @@ def test_segment_reduce(reducer):
g = dgl.convert.heterograph({('_U', '_E', '_V'): (u, v)}, num_nodes_dict=num_nodes)
with F.record_grad():
rst1 = gspmm(g, 'copy_lhs', reducer, v1, None)
if reducer in ['max', 'min']:
rst1 = F.replace_inf_with_zero(rst1)
F.backward(F.reduce_sum(rst1))
grad1 = F.grad(v1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment