Unverified Commit 6294677f authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Set reduce results to all zero for nodes with zero in-degrees. (#2011)

parent 53629082
...@@ -277,7 +277,7 @@ def full_1d(length, fill_value, dtype, ctx): ...@@ -277,7 +277,7 @@ def full_1d(length, fill_value, dtype, ctx):
return th.full((length,), fill_value, dtype=dtype, device=ctx) return th.full((length,), fill_value, dtype=dtype, device=ctx)
def nonzero_1d(input): def nonzero_1d(input):
x = th.nonzero(input).squeeze() x = th.nonzero(input, as_tuple=False).squeeze()
return x if x.dim() == 1 else x.view(-1) return x if x.dim() == 1 else x.view(-1)
def sort_1d(input): def sort_1d(input):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import sys import sys
from ..base import dgl_warning from ..base import dgl_warning
from ..backend import gspmm as gspmm_internal from ..backend import gspmm as gspmm_internal, backend_name
from .. import backend as F from .. import backend as F
__all__ = ['gspmm'] __all__ = ['gspmm']
...@@ -59,17 +59,33 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data): ...@@ -59,17 +59,33 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:] new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]
lhs_data = F.reshape(lhs_data, new_lhs_shape) lhs_data = F.reshape(lhs_data, new_lhs_shape)
rhs_data = F.reshape(rhs_data, new_rhs_shape) rhs_data = F.reshape(rhs_data, new_rhs_shape)
ret = gspmm_internal(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op,
lhs_data, rhs_data)
# assign zero features for zero degree nodes.
deg = g.in_degrees()
min_deg = F.as_scalar(F.min(deg, dim=0))
if min_deg == 0:
non_zero_nids = F.nonzero_1d(deg == 0)
if backend_name == 'pytorch':
ret[non_zero_nids] = 0.
else:
dtype = F.dtype(ret)
ctx = F.context(ret)
ret = F.scatter_row(ret, non_zero_nids,
F.zeros((len(non_zero_nids),) + F.shape(ret)[1:], dtype, ctx))
# divide in degrees for mean reducer.
if reduce_op == 'mean': if reduce_op == 'mean':
ret = gspmm_internal(g._graph, op, 'sum', lhs_data, rhs_data)
ret_shape = F.shape(ret) ret_shape = F.shape(ret)
deg = g.in_degrees() if min_deg == 0:
if F.as_scalar(F.min(deg, dim=0)) == 0:
dgl_warning('Zero-degree nodes encountered in mean reducer. Setting the mean to 0.') dgl_warning('Zero-degree nodes encountered in mean reducer. Setting the mean to 0.')
deg = F.astype(F.clamp(deg, 1, g.number_of_edges()), F.dtype(ret)) deg = F.astype(F.clamp(deg, 1, g.number_of_edges()), F.dtype(ret))
deg_shape = (ret_shape[0],) + (1,) * (len(ret_shape) - 1) deg_shape = (ret_shape[0],) + (1,) * (len(ret_shape) - 1)
return ret / F.reshape(deg, deg_shape) return ret / F.reshape(deg, deg_shape)
else: else:
return gspmm_internal(g._graph, op, reduce_op, lhs_data, rhs_data) return ret
def _gen_spmm_func(binary_op, reduce_op): def _gen_spmm_func(binary_op, reduce_op):
......
...@@ -116,9 +116,6 @@ def test_spmm(idtype, g, shp, msg, reducer): ...@@ -116,9 +116,6 @@ def test_spmm(idtype, g, shp, msg, reducer):
e = F.attach_grad(F.clone(he)) e = F.attach_grad(F.clone(he))
with F.record_grad(): with F.record_grad():
v = gspmm(g, msg, reducer, u, e) v = gspmm(g, msg, reducer, u, e)
non_degree_indices = F.tensor(
np.nonzero(F.asnumpy(g.in_degrees()) != 0)[0])
v = F.gather_row(v, non_degree_indices)
if g.number_of_edges() > 0: if g.number_of_edges() > 0:
F.backward(F.reduce_sum(v)) F.backward(F.reduce_sum(v))
if msg != 'copy_rhs': if msg != 'copy_rhs':
...@@ -129,7 +126,7 @@ def test_spmm(idtype, g, shp, msg, reducer): ...@@ -129,7 +126,7 @@ def test_spmm(idtype, g, shp, msg, reducer):
with F.record_grad(): with F.record_grad():
g.update_all(udf_msg[msg], udf_reduce[reducer]) g.update_all(udf_msg[msg], udf_reduce[reducer])
if g.number_of_edges() > 0: if g.number_of_edges() > 0:
v1 = F.gather_row(g.dstdata['v'], non_degree_indices) v1 = g.dstdata['v']
assert F.allclose(v, v1) assert F.allclose(v, v1)
print('forward passed') print('forward passed')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment