Unverified Commit 63e2ba23 authored by Quan (Andy) Gan, committed by GitHub

[WIP][Kernel] Set the built-in reduce result of zero-degree nodes to 0 in C (#2017)



* test idea

* cuda kernels

* lint and fixes

* lint

* change to another strategy

* use infinity

* fix
Co-authored-by: Zihao Ye <expye@outlook.com>
parent de2e608b
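The gist of the change: instead of patching zero-degree rows on the Python side after the kernel runs, the C/CUDA kernels now use -infinity (for max) and +infinity (for min) as the reduction identity, so destination nodes with no incoming edges come out of the kernel as ±inf, and a new backend helper `replace_inf_with_zero` turns those infinities into zeros. A minimal NumPy sketch of the idea (illustrative only, not DGL code):

```python
import numpy as np

# Messages grouped by destination node; node 2 has no in-edges.
msgs = {0: [3.0, 7.0], 1: [5.0], 2: []}

# Max-reduce with -inf as the identity: zero-degree nodes stay -inf ...
out = np.array([max(m, default=-np.inf) for m in msgs.values()])
# ... and are zeroed afterwards, mirroring F.replace_inf_with_zero.
out[np.isinf(out)] = 0.0
print(out)  # [7. 5. 0.]
```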
@@ -1077,6 +1077,21 @@ def clamp(data, min_val, max_val):
     """
     pass
 
+def replace_inf_with_zero(x):
+    """Returns a new tensor replacing infinity and negative infinity with zeros.
+
+    Parameters
+    ----------
+    x : Tensor
+        The input
+
+    Returns
+    -------
+    Tensor
+        The result
+    """
+    pass
+
 ###############################################################################
 # Tensor functions used *only* on index tensor
 # ----------------
......
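The contract each backend implements below, as a doctest-style illustration (hypothetical values, `tensor` standing in for any backend's tensor type):

```python
>>> x = tensor([1.0, float("inf"), float("-inf")])
>>> replace_inf_with_zero(x)
tensor([1., 0., 0.])
```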
@@ -325,6 +325,9 @@ def clone(input):
 def clamp(data, min_val, max_val):
     return nd.clip(data, min_val, max_val)
 
+def replace_inf_with_zero(x):
+    return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)
+
 def unique(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
......
@@ -268,6 +268,9 @@ def clone(input):
 def clamp(data, min_val, max_val):
     return th.clamp(data, min_val, max_val)
 
+def replace_inf_with_zero(x):
+    return th.masked_fill(x, th.isinf(x), 0)
+
 def unique(input):
     if input.dtype == th.bool:
         input = input.type(th.int8)
......
@@ -397,6 +397,9 @@ def clone(input):
 def clamp(data, min_val, max_val):
     return tf.clip_by_value(data, min_val, max_val)
 
+def replace_inf_with_zero(x):
+    return tf.where(tf.abs(x) == np.inf, 0, x)
+
 def unique(input):
     return tf.unique(input).y
......
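All three implementations share the same semantics: +inf and -inf become 0, finite values pass through unchanged. A quick check with the PyTorch variant, for example:

```python
import torch as th

x = th.tensor([1.0, float("inf"), float("-inf"), -2.5])
print(th.masked_fill(x, th.isinf(x), 0))
# tensor([ 1.0000,  0.0000,  0.0000, -2.5000])
```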
"""dgl spmm operator module."""
import sys
from ..base import dgl_warning
from ..backend import gspmm as gspmm_internal, backend_name
from ..backend import gspmm as gspmm_internal
from .. import backend as F
__all__ = ['gspmm']
@@ -59,28 +58,16 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
         new_rhs_shape = (rhs_shape[0],) + (1,) * rhs_pad_ndims + rhs_shape[1:]
         lhs_data = F.reshape(lhs_data, new_lhs_shape)
         rhs_data = F.reshape(rhs_data, new_rhs_shape)
+    # With max and min reducers infinity will be returned for zero degree nodes
     ret = gspmm_internal(g._graph, op,
                          'sum' if reduce_op == 'mean' else reduce_op,
                          lhs_data, rhs_data)
-    # assign zero features for zero degree nodes.
-    deg = g.in_degrees()
-    min_deg = F.as_scalar(F.min(deg, dim=0))
-    if min_deg == 0:
-        non_zero_nids = F.nonzero_1d(deg == 0)
-        if backend_name == 'pytorch':
-            ret[non_zero_nids] = 0.
-        else:
-            dtype = F.dtype(ret)
-            ctx = F.context(ret)
-            ret = F.scatter_row(ret, non_zero_nids,
-                                F.zeros((len(non_zero_nids),) + F.shape(ret)[1:], dtype, ctx))
+    ret = F.replace_inf_with_zero(ret)
     # divide in degrees for mean reducer.
     if reduce_op == 'mean':
         ret_shape = F.shape(ret)
-        if min_deg == 0:
-            dgl_warning('Zero-degree nodes encountered in mean reducer. Setting the mean to 0.')
+        deg = g.in_degrees()
         deg = F.astype(F.clamp(deg, 1, g.number_of_edges()), F.dtype(ret))
         deg_shape = (ret_shape[0],) + (1,) * (len(ret_shape) - 1)
         return ret / F.reshape(deg, deg_shape)
@@ -88,6 +75,19 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     return ret
 
+def _attach_zerodeg_note(docstring, reducer):
+    note1 = """
+    The {} function will return zero for nodes with no incoming messages.""".format(reducer)
+    note2 = """
+    This is implemented by replacing all {} values with zero.
+    """.format("infinity" if reducer == "min" else "negative infinity")
+
+    docstring = docstring + note1
+    if reducer in ('min', 'max'):
+        docstring = docstring + note2
+    return docstring
+
 def _gen_spmm_func(binary_op, reduce_op):
     name = "u_{}_e_{}".format(binary_op, reduce_op)
     docstring = """Generalized SpMM function.
@@ -120,6 +120,7 @@ def _gen_spmm_func(binary_op, reduce_op):
     https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
     for more details about the NumPy broadcasting semantics.
     """.format(binary_op, reduce_op)
+    docstring = _attach_zerodeg_note(docstring, reduce_op)
 
     def func(g, x, y):
         return gspmm(g, binary_op, reduce_op, x, y)
@@ -139,7 +140,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
         "copy_u": "source node",
         "copy_e": "edge"
     }
-    docstring = lambda binary_op: """Generalized SpMM function. {}
+    docstring = lambda binary_op: _attach_zerodeg_note("""Generalized SpMM function. {}
     Then aggregates the message by {} on destination nodes.
 
     Parameters
@@ -160,7 +161,7 @@ def _gen_copy_reduce_func(binary_op, reduce_op):
     """.format(
         binary_str[binary_op],
         reduce_op,
-        x_str[binary_op])
+        x_str[binary_op]), reduce_op)
 
     def func(g, x):
         if binary_op == 'copy_u':
......
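Note that `mean` needs no infinity handling of its own: it runs the `sum` reducer (whose identity is 0) and then divides by the in-degree clamped to at least 1, so zero-degree rows end up exactly 0 instead of NaN. A standalone sketch of that arithmetic (illustrative values, not the DGL API):

```python
import numpy as np

summed = np.array([[6.0, 2.0],   # node 0: sum over 3 in-edges
                   [0.0, 0.0]])  # node 1: zero-degree, sum is 0
deg = np.array([3, 0])

# Clamp degrees to >= 1 before dividing, as gspmm does, so the
# zero-degree row divides by 1 and stays 0 rather than producing NaN.
deg = np.clip(deg, 1, None).astype(summed.dtype)
print(summed / deg.reshape(-1, 1))
# [[2.         0.66666667]
#  [0.         0.        ]]
```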
@@ -104,8 +104,10 @@ void SpMMSumCoo(
       const DType* lhs_off = Op::use_lhs? X + rid * lhs_dim + lhs_add : nullptr;
       const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
       const DType val = Op::Call(lhs_off, rhs_off);
+      if (val != 0) {
 #pragma omp atomic
-      out_off[k] += val;
+        out_off[k] += val;
+      }
     }
   }
 }
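The new guard skips the atomic update whenever an edge contributes zero, which leaves the sum unchanged but avoids contended atomic writes for sparse inputs. A toy serial mock-up of the COO sum loop (hypothetical Python stand-in for the C++ kernel, scalar features only):

```python
import numpy as np

def spmm_sum_coo(row, col, x, n_out):
    # out[col[e]] += x[row[e]] for every edge e, skipping zero contributions.
    out = np.zeros(n_out, dtype=x.dtype)
    for r, c in zip(row, col):
        val = x[r]
        if val != 0:  # mirrors the kernel's guard around `#pragma omp atomic`
            out[c] += val
    return out

row = np.array([0, 1, 2])
col = np.array([1, 1, 3])
x = np.array([1.0, 0.0, 2.0])
print(spmm_sum_coo(row, col, x, 4))  # [0. 1. 0. 2.]
```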
@@ -123,8 +125,9 @@
  * \param arge Arg-Min/Max on edges, which refers to the source node indices
  *        corresponding to the minimum/maximum values of the reduction result on
  *        destination nodes. It's useful in computing gradients of the Min/Max reducer.
- * \note it uses node parallel strategy, different threads are responsible
+ * \note It uses node parallel strategy, different threads are responsible
  *       for the computation of different nodes.
+ * \note The result will contain infinity for zero-degree nodes.
  */
 template <typename IdType, typename DType, typename Op, typename Cmp>
 void SpMMCmpCsr(
@@ -194,6 +197,7 @@
  * \note It uses node parallel strategy, different threads are responsible
  *       for the computation of different nodes. To avoid possible data hazards,
  *       we use atomic operators in the reduction phase.
+ * \note The result will contain infinity for zero-degree nodes.
  */
 template <typename IdType, typename DType, typename Op, typename Cmp>
 void SpMMCmpCoo(
@@ -315,7 +319,7 @@ template <typename DType> constexpr bool CopyRhs<DType>::use_rhs;
 
 //////////////////////////////// Reduce operators on CPU ////////////////////////////////
 template <typename DType>
 struct Max {
-  static constexpr DType zero = std::numeric_limits<DType>::lowest();
+  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
   // return true if accum should be replaced
   inline static DType Call(DType accum, DType val) {
     return accum < val;
@@ -325,7 +329,7 @@ template <typename DType> constexpr DType Max<DType>::zero;
 
 template <typename DType>
 struct Min {
-  static constexpr DType zero = std::numeric_limits<DType>::max();
+  static constexpr DType zero = std::numeric_limits<DType>::infinity();
   // return true if accum should be replaced
   inline static DType Call(DType accum, DType val) {
     return accum > val;
......
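Switching the identities from `lowest()`/`max()` to -infinity/+infinity is what makes the Python-side cleanup possible: a finite sentinel like `lowest()` is indistinguishable from real data, while an untouched ±inf row is exactly what `replace_inf_with_zero` looks for. A toy serial mock-up of the CSR cmp-reduce (hypothetical Python stand-in, scalar features):

```python
import numpy as np

def spmm_max_csr(indptr, indices, x):
    # Node-parallel in spirit: each destination scans its own in-edges.
    out = np.full(len(indptr) - 1, -np.inf, dtype=x.dtype)  # identity = -inf
    for dst in range(len(indptr) - 1):
        for eid in range(indptr[dst], indptr[dst + 1]):
            out[dst] = max(out[dst], x[indices[eid]])
    return out

indptr = np.array([0, 2, 2, 3])   # node 1 has no in-edges
indices = np.array([0, 2, 1])
x = np.array([5.0, -1.0, 3.0])
print(spmm_max_csr(indptr, indices, x))  # [ 5. -inf -1.]; -inf is zeroed later
```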
@@ -146,7 +146,7 @@ template <typename Idx,
           typename DType,
           bool atomic=false>
 struct Max {
-  static constexpr DType zero = std::numeric_limits<DType>::lowest();
+  static constexpr DType zero = -std::numeric_limits<DType>::infinity();
   static constexpr bool require_arg = true;
   static __device__ __forceinline__ void Call(
       DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
@@ -183,7 +183,7 @@ template <typename Idx,
           typename DType,
           bool atomic=false>
 struct Min {
-  static constexpr DType zero = std::numeric_limits<DType>::max();
+  static constexpr DType zero = std::numeric_limits<DType>::infinity();
   static constexpr bool require_arg = true;
   static __device__ __forceinline__ void Call(
       DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
......