Unverified Commit cf5c1930 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix 5873 (#5884)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent c334662b
...@@ -5263,12 +5263,12 @@ class DGLGraph(object): ...@@ -5263,12 +5263,12 @@ class DGLGraph(object):
out = reduce_dict_data(frames, cross_reducer, merge_order[dtid]) out = reduce_dict_data(frames, cross_reducer, merge_order[dtid])
# Replace infinity with zero for isolated nodes when reducer is min/max # Replace infinity with zero for isolated nodes when reducer is min/max
if core.is_builtin(rfunc) and rfunc.name in ["min", "max"]: if core.is_builtin(rfunc) and rfunc.name in ["min", "max"]:
key = list(out.keys())[0] for key in out.keys():
out[key] = ( out[key] = (
F.replace_inf_with_zero(out[key]) F.replace_inf_with_zero(out[key])
if out[key] is not None if out[key] is not None
else None else None
) )
self._node_frames[dtid].update(out) self._node_frames[dtid].update(out)
# apply # apply
if apply_node_func is not None: if apply_node_func is not None:
......
...@@ -334,6 +334,38 @@ def test_binary_op(idtype): ...@@ -334,6 +334,38 @@ def test_binary_op(idtype):
_test(lhs, rhs, binary_op, reducer) _test(lhs, rhs, binary_op, reducer)
# Issue #5873
def test_multi_update_all_minmax_reduce_with_isolated_nodes():
g = dgl.heterograph(
{
("A", "AB", "B"): ([0, 1, 2, 3], [0, 0, 1, 1]),
("C", "CB", "B"): ([0, 1, 2, 3], [2, 2, 3, 3]),
},
device=F.ctx(),
)
g.nodes["A"].data["x"] = F.randn((4, 16))
g.nodes["C"].data["x"] = F.randn((4, 16))
g.multi_update_all(
{
"AB": (dgl.function.copy_u("x", "m"), dgl.function.min("m", "a1")),
"CB": (dgl.function.copy_u("x", "m"), dgl.function.min("m", "a2")),
},
cross_reducer="min",
)
assert not np.isinf(F.asnumpy(g.nodes["B"].data["a1"])).any()
assert not np.isinf(F.asnumpy(g.nodes["B"].data["a2"])).any()
g.multi_update_all(
{
"AB": (dgl.function.copy_u("x", "m"), dgl.function.max("m", "a1")),
"CB": (dgl.function.copy_u("x", "m"), dgl.function.max("m", "a2")),
},
cross_reducer="max",
)
assert not np.isinf(F.asnumpy(g.nodes["B"].data["a1"])).any()
assert not np.isinf(F.asnumpy(g.nodes["B"].data["a2"])).any()
if __name__ == "__main__": if __name__ == "__main__":
test_unary_copy_u() test_unary_copy_u()
test_unary_copy_e() test_unary_copy_e()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment