Unverified commit 63e2ba23, authored by Quan (Andy) Gan, committed by GitHub

[WIP][Kernel] Set the built-in reduce result of zero-degree nodes to 0 in C (#2017)



* test idea

* cuda kernels

* lint and fixes

* lint

* change to another strategy

* use infinity

* fix
Co-authored-by: Zihao Ye <expye@outlook.com>
parent de2e608b
@@ -1077,6 +1077,21 @@ def clamp(data, min_val, max_val):
     """
     pass

+def replace_inf_with_zero(x):
+    """Returns a new tensor with infinity and negative infinity replaced by zeros.
+
+    Parameters
+    ----------
+    x : Tensor
+        The input
+
+    Returns
+    -------
+    Tensor
+        The result
+    """
+    pass
+
 ###############################################################################
 # Tensor functions used *only* on index tensor
 # ----------------
...
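The contract is the same across backends: map both signs of infinity to zero and leave everything else alone. A minimal NumPy sketch of that semantics (an illustration only, not DGL code; `np.isinf` is true for both +inf and -inf, mirroring the three backend implementations below):

```python
import numpy as np

def replace_inf_with_zero(x):
    # Both +inf and -inf become 0; finite values pass through unchanged.
    return np.where(np.isinf(x), 0, x)

print(replace_inf_with_zero(np.array([1.5, np.inf, -np.inf, 0.0])))
# [1.5 0.  0.  0. ]
```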
@@ -325,6 +325,9 @@ def clone(input):
 def clamp(data, min_val, max_val):
     return nd.clip(data, min_val, max_val)

+def replace_inf_with_zero(x):
+    return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)
+
 def unique(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
...
@@ -268,6 +268,9 @@ def clone(input):
 def clamp(data, min_val, max_val):
     return th.clamp(data, min_val, max_val)

+def replace_inf_with_zero(x):
+    return th.masked_fill(x, th.isinf(x), 0)
+
 def unique(input):
     if input.dtype == th.bool:
         input = input.type(th.int8)
...
@@ -397,6 +397,9 @@ def clone(input):
 def clamp(data, min_val, max_val):
     return tf.clip_by_value(data, min_val, max_val)

+def replace_inf_with_zero(x):
+    return tf.where(tf.abs(x) == np.inf, 0, x)
+
 def unique(input):
     return tf.unique(input).y
...
"""dgl spmm operator module.""" """dgl spmm operator module."""
import sys import sys
from ..base import dgl_warning from ..backend import gspmm as gspmm_internal
from ..backend import gspmm as gspmm_internal, backend_name
from .. import backend as F from .. import backend as F
__all__ = ['gspmm'] __all__ = ['gspmm']
@@ -59,28 +58,16 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
         new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]
         lhs_data = F.reshape(lhs_data, new_lhs_shape)
         rhs_data = F.reshape(rhs_data, new_rhs_shape)
+    # With max and min reducers, infinity will be returned for zero-degree nodes.
     ret = gspmm_internal(g._graph, op,
                          'sum' if reduce_op == 'mean' else reduce_op,
                          lhs_data, rhs_data)
+    ret = F.replace_inf_with_zero(ret)
-    # assign zero features for zero degree nodes.
-    deg = g.in_degrees()
-    min_deg = F.as_scalar(F.min(deg, dim=0))
-    if min_deg == 0:
-        non_zero_nids = F.nonzero_1d(deg == 0)
-        if backend_name == 'pytorch':
-            ret[non_zero_nids] = 0.
-        else:
-            dtype = F.dtype(ret)
-            ctx = F.context(ret)
-            ret = F.scatter_row(ret, non_zero_nids,
-                                F.zeros((len(non_zero_nids),) + F.shape(ret)[1:], dtype, ctx))
     # divide in degrees for mean reducer.
     if reduce_op == 'mean':
         ret_shape = F.shape(ret)
-        if min_deg == 0:
-            dgl_warning('Zero-degree nodes encountered in mean reducer. Setting the mean to 0.')
+        deg = g.in_degrees()
         deg = F.astype(F.clamp(deg, 1, g.number_of_edges()), F.dtype(ret))
         deg_shape = (ret_shape[0],) + (1,) * (len(ret_shape) - 1)
         return ret / F.reshape(deg, deg_shape)
@@ -88,6 +75,19 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     return ret

+def _attach_zerodeg_note(docstring, reducer):
+    note1 = """
+    The {} function will return zero for nodes with no incoming messages.""".format(reducer)
+    note2 = """
+    This is implemented by replacing all {} values with zero.
+    """.format("infinity" if reducer == "min" else "negative infinity")
+    docstring = docstring + note1
+    if reducer in ('min', 'max'):
+        docstring = docstring + note2
+    return docstring
+
 def _gen_spmm_func(binary_op, reduce_op):
     name = "u_{}_e_{}".format(binary_op, reduce_op)
     docstring = """Generalized SpMM function.
@@ -120,6 +120,7 @@ def _gen_spmm_func(binary_op, reduce_op):
     https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
     for more details about the NumPy broadcasting semantics.
     """.format(binary_op, reduce_op)
+    docstring = _attach_zerodeg_note(docstring, reduce_op)

     def func(g, x, y):
         return gspmm(g, binary_op, reduce_op, x, y)
@@ -139,7 +140,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
         "copy_u": "source node",
         "copy_e": "edge"
     }
-    docstring = lambda binary_op: """Generalized SpMM function. {}
+    docstring = lambda binary_op: _attach_zerodeg_note("""Generalized SpMM function. {}
     Then aggregates the message by {} on destination nodes.

     Parameters
@@ -160,7 +161,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
     """.format(
         binary_str[binary_op],
         reduce_op,
-        x_str[binary_op])
+        x_str[binary_op]), reduce_op)

     def func(g, x):
         if binary_op == 'copy_u':
...
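Combined with the kernel changes below, the `F.replace_inf_with_zero` call means zero-degree destination nodes now receive zeros from the min/max reducers, while the mean reducer divides by `clamp(in_degree, 1, ...)` so those nodes stay at zero instead of triggering a warning. A hedged usage sketch, assuming the generated `dgl.ops` functions of the DGL 0.5-era API:

```python
import dgl
import torch

# A single edge 0 -> 1 leaves nodes 0 and 2 with in-degree zero.
g = dgl.graph(([0], [1]), num_nodes=3)
x = torch.tensor([[5.0], [7.0], [9.0]])

# The kernel writes -inf into the rows of nodes 0 and 2; gspmm then
# replaces those infinities with 0.
print(dgl.ops.copy_u_max(g, x))   # tensor([[0.], [5.], [0.]])

# mean = sum / clamp(in_degree, 1, num_edges): the zero sum over an
# empty neighborhood divided by 1 stays 0 rather than producing 0/0.
print(dgl.ops.copy_u_mean(g, x))  # tensor([[0.], [5.], [0.]])
```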
@@ -104,8 +104,10 @@ void SpMMSumCoo(
       const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim + lhs_add : nullptr;
       const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
       const DType val = Op::Call(lhs_off, rhs_off);
+      if (val != 0) {
 #pragma omp atomic
       out_off[k] += val;
+      }
     }
   }
 }
@@ -123,8 +125,9 @@ void SpMMSumCoo(
  * \param arge Arg-Min/Max on edges. which refers the source node indices
  *        correspond to the minimum/maximum values of reduction result on
  *        destination nodes. It's useful in computing gradients of Min/Max reducer.
- * \note it uses node parallel strategy, different threads are responsible
+ * \note It uses node parallel strategy, different threads are responsible
  *       for the computation of different nodes.
+ * \note The result will contain infinity for zero-degree nodes.
  */
 template <typename IdType, typename DType, typename Op, typename Cmp>
 void SpMMCmpCsr(
@@ -194,6 +197,7 @@ void SpMMCmpCsr(
  * \note it uses node parallel strategy, different threads are responsible
  *       for the computation of different nodes. To avoid possible data hazard,
  *       we use atomic operators in the reduction phase.
+ * \note The result will contain infinity for zero-degree nodes.
  */
 template <typename IdType, typename DType, typename Op, typename Cmp>
 void SpMMCmpCoo(
@@ -315,7 +319,7 @@ template <typename DType> constexpr bool CopyRhs<DType>::use_rhs;
 //////////////////////////////// Reduce operators on CPU ////////////////////////////////
 template <typename DType>
 struct Max {
-  static constexpr DType zero = std::numeric_limits<DType>::lowest();
+  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
   // return true if accum should be replaced
   inline static DType Call(DType accum, DType val) {
     return accum < val;
@@ -325,7 +329,7 @@ template <typename DType> constexpr DType Max<DType>::zero;
 template <typename DType>
 struct Min {
-  static constexpr DType zero = std::numeric_limits<DType>::max();
+  static constexpr DType zero = std::numeric_limits<DType>::infinity();
   // return true if accum should be replaced
   inline static DType Call(DType accum, DType val) {
     return accum > val;
...
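To see why the identity change matters on the CPU path, here is a simplified NumPy model of what `SpMMCmpCsr` computes for the copy-u/max case (a sketch under the assumption that row `v` of the CSR structure lists `v`'s incoming edges):

```python
import numpy as np

def spmm_copy_u_max(indptr, indices, X):
    # Initialize with the new Max identity (-inf) rather than
    # numeric_limits::lowest(), so rows that no edge ever touches
    # remain recognizable afterwards.
    out = np.full((len(indptr) - 1, X.shape[1]), -np.inf)
    for v in range(len(indptr) - 1):              # destination node
        for e in range(indptr[v], indptr[v + 1]):
            u = indices[e]                        # source of edge e
            out[v] = np.maximum(out[v], X[u])
    return out  # zero-degree rows stay -inf; the Python side zeroes them

indptr = np.array([0, 0, 1, 1])  # node 1 has one in-edge; nodes 0 and 2 none
indices = np.array([0])          # that edge originates at node 0
X = np.array([[5.0], [7.0], [9.0]])
print(spmm_copy_u_max(indptr, indices, X))  # [[-inf] [ 5.] [-inf]]
```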
@@ -146,7 +146,7 @@ template <typename Idx,
           typename DType,
           bool atomic=false>
 struct Max {
-  static constexpr DType zero = std::numeric_limits<DType>::lowest();
+  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
   static constexpr bool require_arg = true;
   static __device__ __forceinline__ void Call(
       DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
@@ -183,7 +183,7 @@ template <typename Idx,
           typename DType,
           bool atomic=false>
 struct Min {
-  static constexpr DType zero = std::numeric_limits<DType>::max();
+  static constexpr DType zero = std::numeric_limits<DType>::infinity();
   static constexpr bool require_arg = true;
   static __device__ __forceinline__ void Call(
       DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
...
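The same substitution on the CUDA side rests on a simple property: -infinity and `numeric_limits::lowest()` are both identities for max (neither ever wins a comparison against a real value), but only the infinity can be told apart from a genuine result afterwards. A quick NumPy check of that reasoning:

```python
import numpy as np

v = np.float32(-3.5)
lowest = np.finfo(np.float32).min   # the old Max identity

# Either identity leaves any finite input unchanged under max ...
assert max(-np.inf, v) == v
assert max(lowest, v) == v

# ... but `lowest` is a legal finite float, so replace_inf_with_zero
# could not distinguish a zero-degree row from a node whose true
# maximum happens to equal float32's minimum. -inf is detectable.
assert np.isinf(-np.inf) and not np.isinf(lowest)
```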