Unverified commit 3adbfa18 authored by Zihao Ye, committed by GitHub

[Performance] Use segment operators for graph readout. (#2361)

* upd

* upd

* update

* upd

* upd

* upd

* fix

* lint

* lint

* pylint

* doc
parent 45e3e9af
......@@ -370,6 +370,23 @@ def reduce_sum(input):
"""
pass
def cumsum(input, dim):
"""Return the cumulative sum of the elements along a given axis.
Parameters
----------
input : Tensor
The input tensor.
dim : int
The cumulative dimension.
Returns
-------
Tensor
A framework-specific tensor.
"""
pass
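For intuition, the main consumer of this primitive in the patch turns per-segment lengths into segment offsets. A minimal NumPy sketch (illustrative only; the real backends dispatch to the framework's native cumsum below):

import numpy as np

seglen = np.array([2, 3, 0, 4, 1])                  # lengths of 5 segments
offsets = np.cumsum(np.concatenate([[0], seglen]))  # -> [ 0  2  5  5  9 10]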
def mean(input, dim):
"""Reduce average the input tensor along the given dim.
......@@ -1489,6 +1506,31 @@ def edge_softmax(gidx, logits, eids, norm_by):
Tensor
Softmax value
"""
pass
def segment_reduce(op, x, offsets):
"""Segment reduction operator.
It aggregates the rows of ``x`` along the first dimension by segments.
Segment boundaries are given by ``offsets``, a tensor of length
``num_segments + 1``; segment ``i`` covers rows ``offsets[i]`` to
``offsets[i + 1] - 1``. Zero-length segments are allowed.
Parameters
----------
op : str
Aggregation method. Can be 'sum', 'max', 'min'.
x : Tensor
Value tensor to aggregate.
offsets : Tensor
The start offsets of the segments (monotonically non-decreasing).
Returns
-------
Tensor
Aggregated tensor of shape ``(num_segments,) + x.shape[1:]``.
"""
pass
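A small numeric sketch of the contract above, with 'sum' as the op (the values and the NumPy stand-in are illustrative, not part of the patch):

import numpy as np

# x has 5 rows; offsets = [0, 2, 2, 5] describes 3 segments, the middle one empty.
x = np.array([[1.], [2.], [3.], [4.], [5.]])
offsets = np.array([0, 2, 2, 5])
out = np.stack([x[s:e].sum(0) if e > s else np.zeros(x.shape[1:])
                for s, e in zip(offsets[:-1], offsets[1:])])
# out == [[3.], [0.], [12.]]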
###############################################################################
......
import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
def _scatter_nd(index, src, n_rows):
......@@ -328,3 +328,35 @@ class EdgeSoftmax(mx.autograd.Function):
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
softmax_op = EdgeSoftmax(gidx, eids, norm_by)
return softmax_op(logits)
class SegmentReduce(mx.autograd.Function):
def __init__(self, op, offsets):
super(SegmentReduce, self).__init__()
self.op = op
self.offsets = offsets
def forward(self, x):
y, arg = _segment_reduce(self.op, x, self.offsets)
self.save_for_backward(arg)
return y
def backward(self, dy):
arg, = self.saved_tensors
offsets = self.offsets
m = offsets[-1].asscalar()
if self.op == 'sum':
offsets_np = asnumpy(offsets[1:-1])
indices_np = np.zeros((m,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
indices_np = np.cumsum(indices_np, -1)
indices = zerocopy_from_numpy(indices_np)
dx = dy[indices]
else:
dx = _bwd_segment_cmp(dy, arg, m)
return dx
def segment_reduce(op, x, offsets):
segment_reduce_op = SegmentReduce(op, offsets)
return segment_reduce_op(x)
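The 'sum' branch of the backward above maps every input row to its segment id by marking segment starts and running a cumulative sum, so dx is just dy gathered row-wise. A standalone NumPy sketch of that index trick (illustrative only):

import numpy as np

offsets = np.array([0, 2, 5, 5, 9, 10])   # 5 segments over 10 rows
m = offsets[-1]
indices = np.zeros((m,), dtype=offsets.dtype)
np.add.at(indices, offsets[1:-1], 1)      # bump where each new segment starts
indices = np.cumsum(indices)              # -> [0 0 1 1 1 3 3 3 3 4]
# dy[indices] then broadcasts each segment's gradient to all of its rows.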
......@@ -152,6 +152,9 @@ def sum(input, dim, keepdims=False):
def reduce_sum(input):
return input.sum()
def cumsum(input, dim):
return nd.cumsum(input, axis=dim)
def mean(input, dim):
return nd.mean(input, axis=dim)
......
import torch as th
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
def _reduce_grad(grad, shape):
......@@ -231,6 +231,32 @@ class EdgeSoftmax(th.autograd.Function):
return None, grad_score, None, None
class SegmentReduce(th.autograd.Function):
@staticmethod
def forward(ctx, op, x, offsets):
y, arg = _segment_reduce(op, x, offsets)
ctx.save_for_backward(arg, offsets)
ctx.backward_cache = op
return y
@staticmethod
def backward(ctx, dy):
op = ctx.backward_cache
arg, offsets = ctx.saved_tensors
m = offsets[-1].item()
if op == 'sum':
offsets = offsets[1:-1]
indices = th.zeros(
(m,), device=offsets.device, dtype=offsets.dtype)
indices.scatter_add_(0, offsets, th.ones_like(offsets))
indices = th.cumsum(indices, -1)
dx = dy[indices]
else:
dx = _bwd_segment_cmp(dy, arg, m)
return None, dx, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
......@@ -241,3 +267,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return EdgeSoftmax.apply(gidx, logits, eids, norm_by)
def segment_reduce(op, x, offsets):
return SegmentReduce.apply(op, x, offsets)
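A hedged usage sketch of the PyTorch wrapper defined just above (shapes and dtypes are my own choices, and it assumes DGL is built with the new segment kernels):

import torch as th

x = th.randn(10, 4, requires_grad=True)
offsets = th.tensor([0, 2, 5, 5, 9, 10])   # 5 segments, one of them empty
y = segment_reduce('sum', x, offsets)      # shape (5, 4)
y.sum().backward()                         # grad of x: dy gathered per row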
......@@ -120,6 +120,9 @@ def sum(input, dim, keepdims=False):
def reduce_sum(input):
return input.sum()
def cumsum(input, dim):
return th.cumsum(input, dim=dim)
def mean(input, dim):
return th.mean(input, dim=dim)
......
import tensorflow as tf
import numpy as np
from .tensor import tensor, copy_to, context
from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
def _scatter_nd(index, src, n_rows):
......@@ -254,3 +254,28 @@ def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
return edge_softmax_real(gidx, logits, eids, norm_by)
return _lambda(logits)
def segment_reduce_real(op, x, offsets):
y, arg = _segment_reduce(op, x, offsets)
def segment_reduce_backward(dy):
m = x.shape[0]
if op == 'sum':
offsets_np = asnumpy(offsets[1:-1])
indices_np = np.zeros((m,), dtype=offsets_np.dtype)
np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
indices_np = np.cumsum(indices_np, -1)
indices = zerocopy_from_numpy(indices_np)
dx = tf.gather(dy, indices)
else:
dx = _bwd_segment_cmp(dy, arg, m)
return dx
return y, segment_reduce_backward
def segment_reduce(op, x, offsets):
@tf.custom_gradient
def _lambda(x):
return segment_reduce_real(op, x, offsets)
return _lambda(x)
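An equivalent TensorFlow sketch for the custom-gradient path above (illustrative; assumes the compiled kernels are available):

import tensorflow as tf

x = tf.random.normal((10, 4))
offsets = tf.constant([0, 2, 5, 5, 9, 10], dtype=tf.int64)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = segment_reduce('max', x, offsets)  # (5, 4); an empty segment stays at the reduce identity
    loss = tf.reduce_sum(y)
dx = tape.gradient(loss, x)                # routed through _bwd_segment_cmp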
......@@ -175,6 +175,12 @@ def reduce_sum(input):
return tf.reduce_sum(input)
def cumsum(input, dim):
if input.dtype == tf.bool:
input = tf.cast(input, tf.int32)
return tf.cumsum(input, axis=dim)
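The boolean cast above is needed because tf.cumsum only accepts numeric dtypes; a minimal standalone sketch (my own example values):

import tensorflow as tf

mask = tf.constant([True, False, True, True])
counts = tf.cumsum(tf.cast(mask, tf.int32))   # values [1, 1, 2, 3]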
def mean(input, dim):
return tf.reduce_mean(input, axis=dim)
......
......@@ -2,3 +2,4 @@
from .spmm import *
from .sddmm import *
from .edge_softmax import *
from .segment import *
......@@ -2,8 +2,6 @@
from ..base import DGLError
from .. import backend as F
from .. import convert
from .. import function as fn
def segment_reduce(seglen, value, reducer='sum'):
......@@ -41,20 +39,21 @@ def segment_reduce(seglen, value, reducer='sum'):
[5., 5., 5.],
[4., 4., 4.]])
"""
ctx = F.context(seglen)
# TODO(minjie): a more efficient implementation is to create a graph
# directly from a CSR structure.
u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
v = F.repeat(F.copy_to(F.arange(0, len(seglen), F.int32), ctx),
seglen, dim=0)
if len(u) != len(v):
raise DGLError("Invalid seglen array:", seglen,
". Its summation must be equal to value.shape[0].")
num_nodes = {'_U': len(u), '_V': len(seglen)}
g = convert.heterograph({('_U', '_E', '_V'): (u, v)}, num_nodes_dict=num_nodes)
g.srcdata['h'] = value
g.update_all(fn.copy_u('h', 'm'), getattr(fn, reducer)('m', 'h'))
return g.dstdata['h']
offsets = F.cumsum(
F.cat([F.zeros((1,), F.dtype(seglen), F.context(seglen)), seglen], 0), 0)
if reducer == 'mean':
rst = F.segment_reduce('sum', value, offsets)
rst_shape = F.shape(rst)
z = F.astype(F.clamp(seglen, 1, len(value)), F.dtype(rst))
z_shape = (rst_shape[0],) + (1,) * (len(rst_shape) - 1)
return rst / F.reshape(z, z_shape)
elif reducer in ['min', 'sum', 'max']:
rst = F.segment_reduce(reducer, value, offsets)
if reducer in ['min', 'max']:
rst = F.replace_inf_with_zero(rst)
return rst
else:
raise DGLError("reducer {} not recognized.".format(reducer))
def segment_softmax(seglen, value):
......
......@@ -248,4 +248,81 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
return out
def _segment_reduce(op, feat, offsets):
r"""Segment reduction operator.
It aggregates the rows of ``feat`` along the first dimension by segments.
Segment boundaries are given by ``offsets``, a tensor of length
``num_segments + 1``; segment ``i`` covers rows ``offsets[i]`` to
``offsets[i + 1] - 1``. Zero-length segments are allowed.
Parameters
----------
op : str
Aggregation method. Can be 'sum', 'max', 'min'.
feat : Tensor
Value tensor to aggregate.
offsets : Tensor
The start offsets of the segments (monotonically non-decreasing).
Returns
-------
tuple(Tensor)
The first tensor is the aggregated tensor of shape
``(num_segments,) + feat.shape[1:]``; the second tensor records the
argmin/argmax at each position for computing gradients.
Notes
-----
This function does not handle gradients.
"""
n = F.shape(offsets)[0] - 1
out_shp = (n,) + F.shape(feat)[1:]
ctx = F.context(feat)
dtype = F.dtype(feat)
idtype = F.dtype(offsets)
out = F.zeros(out_shp, dtype, ctx)
arg = None
if op in ['min', 'max']:
arg = F.zeros(out_shp, idtype, ctx)
arg_nd = to_dgl_nd_for_write(arg)
_CAPI_DGLKernelSegmentReduce(op,
to_dgl_nd(feat),
to_dgl_nd(offsets),
to_dgl_nd_for_write(out),
arg_nd)
arg = None if arg is None else F.zerocopy_from_dgl_ndarray(arg_nd)
return out, arg
def _bwd_segment_cmp(feat, arg, m):
r""" Backward phase of segment reduction (for 'min'/'max' reduction).
It computes the gradient of input feature given output gradient of
the segment reduction result.
Parameters
----------
feat : Tensor
The output gradient
arg : Tensor
The ArgMin/Max tensor produced by segment_reduce op.
m : int
The length of input gradients' first dimension.
Returns
-------
Tensor
The input gradient.
"""
out_shp = (m,) + F.shape(feat)[1:]
ctx = F.context(feat)
dtype = F.dtype(feat)
out = F.zeros(out_shp, dtype, ctx)
_CAPI_DGLKernelBwdSegmentCmp(to_dgl_nd(feat),
to_dgl_nd(arg),
to_dgl_nd_for_write(out))
return out
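For intuition, _bwd_segment_cmp scatters each entry of the output gradient back to the input row recorded in ``arg``; entries with ``arg == -1`` (empty segments) receive nothing. A NumPy sketch of the same scatter (illustrative, not the kernel itself):

import numpy as np

dy = np.array([[10., 20.],
               [30., 40.]])        # gradient w.r.t. 2 segment outputs
arg = np.array([[1, 0],
                [-1, 2]])          # argmax rows; -1 marks an empty segment slot
m = 4                              # number of input rows
dx = np.zeros((m,) + dy.shape[1:])
rows, cols = np.nonzero(arg >= 0)
dx[arg[rows, cols], cols] = dy[rows, cols]
# dx == [[0., 20.], [10., 0.], [0., 40.], [0., 0.]]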
_init_api("dgl.sparse")
/*!
* Copyright (c) 2020 by Contributors
* \file kernel/cpu/segment_reduce.cc
* \brief Segment reduce C APIs and definitions.
*/
#include "./segment_reduce.h"
#include <dgl/array.h>
#include <string>
#include "./spmm_binary_ops.h"
namespace dgl {
namespace aten {
/*! \brief Segment Reduce operator. */
template <int XPU, typename IdType, typename DType>
void SegmentReduce(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
if (op == "sum") {
cpu::SegmentSum<IdType, DType>(feat, offsets, out);
} else if (op == "max" || op == "min") {
if (op == "max")
cpu::SegmentCmp<IdType, DType, cpu::op::Max<DType>>(
feat, offsets, out, arg);
else
cpu::SegmentCmp<IdType, DType, cpu::op::Min<DType>>(
feat, offsets, out, arg);
} else {
LOG(FATAL) << "Unsupported reduce function " << op;
}
}
/*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(
NDArray feat,
NDArray arg,
NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
template void SegmentReduce<kDLCPU, int32_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void BackwardSegmentCmp<kDLCPU, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/segment_reduce.h
* \brief Segment reduce kernel function header.
*/
#ifndef DGL_ARRAY_CPU_SEGMENT_REDUCE_H_
#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_
#include <dgl/array.h>
namespace dgl {
namespace aten {
namespace cpu {
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
int n = out->shape[0];
int dim = 1;
for (int i = 1; i < out->ndim; ++i)
dim *= out->shape[i];
const DType* feat_data = feat.Ptr<DType>();
const IdType* offsets_data = offsets.Ptr<IdType>();
DType *out_data = out.Ptr<DType>();
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
for (int k = 0; k < dim; ++k) {
out_data[i * dim + k] += feat_data[j * dim + k];
}
}
}
}
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets,
NDArray out, NDArray arg) {
int n = out->shape[0];
int dim = 1;
for (int i = 1; i < out->ndim; ++i)
dim *= out->shape[i];
const DType* feat_data = feat.Ptr<DType>();
const IdType* offsets_data = offsets.Ptr<IdType>();
DType *out_data = out.Ptr<DType>();
IdType *arg_data = arg.Ptr<IdType>();
std::fill(out_data, out_data + out.NumElements(), Cmp::zero);
std::fill(arg_data, arg_data + arg.NumElements(), -1);
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
for (int k = 0; k < dim; ++k) {
const DType val = feat_data[j * dim + k];
if (Cmp::Call(out_data[i * dim + k], val)) {
out_data[i * dim + k] = val;
arg_data[i * dim + k] = j;
}
}
}
}
}
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
int n = feat->shape[0];
int dim = 1;
for (int i = 1; i < out->ndim; ++i)
dim *= out->shape[i];
const DType* feat_data = feat.Ptr<DType>();
const IdType* arg_data = arg.Ptr<IdType>();
DType* out_data = out.Ptr<DType>();
#pragma omp parallel for
for (int i = 0; i < n; ++i) {
for (int k = 0; k < dim; ++k) {
int write_row = arg_data[i * dim + k];
if (write_row >= 0)
out_data[write_row * dim + k] = feat_data[i * dim + k];
}
}
}
} // namespace cpu
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_SEGMENT_REDUCE_H_
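To make the kernel's contract concrete, here is a hedged pure-Python reference of the 'max' case of SegmentCmp, looping the same way the CPU code does (illustrative only; -inf stands in for Cmp::zero):

import numpy as np

def segment_max_ref(feat, offsets):
    n, dim = len(offsets) - 1, feat.shape[1]
    out = np.full((n, dim), -np.inf)            # reduce identity, like Cmp::zero
    arg = np.full((n, dim), -1, dtype=np.int64)
    for i in range(n):
        for j in range(offsets[i], offsets[i + 1]):
            better = feat[j] > out[i]
            out[i] = np.where(better, feat[j], out[i])
            arg[i] = np.where(better, j, arg[i])
    return out, arg

feat = np.array([[1., 9.], [4., 2.], [7., 3.]])
out, arg = segment_max_ref(feat, np.array([0, 2, 2, 3]))
# out == [[4., 9.], [-inf, -inf], [7., 3.]],  arg == [[1, 0], [-1, -1], [2, 2]]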
......@@ -6,6 +6,8 @@
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_
#include "./atomic.cuh"
namespace dgl {
namespace aten {
namespace cuda {
......@@ -133,6 +135,15 @@ struct Sum {
cuda::AtomicAdd(out_buf, val);
}
}
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf,
DType val, Idx id) {
if (!atomic) {
*out_buf += val;
} else {
cuda::AtomicAdd(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {}
......@@ -161,6 +172,18 @@ struct Max {
cuda::AtomicMax(out_buf, val);
}
}
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf,
DType val, Idx id) {
if (!atomic) {
if (*out_buf < val) {
*out_buf = val;
*arg_buf = id;
}
} else {
cuda::AtomicMax(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {
......@@ -198,6 +221,18 @@ struct Min {
cuda::AtomicMin(out_buf, val);
}
}
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf,
DType val, Idx id) {
if (!atomic) {
if (*out_buf > val) {
*out_buf = val;
*arg_buf = id;
}
} else {
cuda::AtomicMin(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/segment_reduce.cu
* \brief Segment reduce C APIs and definitions.
*/
#include <dgl/array.h>
#include "./segment_reduce.cuh"
#include "./functor.cuh"
namespace dgl {
using namespace cuda;
namespace aten {
template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
if (op == "sum") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
feat, offsets, out, arg);
} else if (op == "max") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(
feat, offsets, out, arg);
} else if (op == "min") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(
feat, offsets, out, arg);
} else {
LOG(FATAL) << "Not implemented";
}
}
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
NDArray out) {
cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
template void SegmentReduce<kDLGPU, int32_t, float>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int32_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void BackwardSegmentCmp<kDLGPU, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int32_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cuda/segment_reduce.cuh
* \brief Segment reduce kernel function header.
*/
#ifndef DGL_ARRAY_SEGMENT_REDUCE_CUH_
#define DGL_ARRAY_SEGMENT_REDUCE_CUH_
#include "../../runtime/cuda/cuda_common.h"
#include "./atomic.cuh"
#include "./utils.h"
namespace dgl {
using namespace cuda;
namespace aten {
namespace cuda {
/*!
* \brief CUDA kernel of segment reduce.
*/
template <typename IdType, typename DType,
typename ReduceOp>
__global__ void SegmentReduceKernel(
const DType* feat, const IdType* offsets,
DType* out, IdType* arg,
int64_t n, int64_t dim){
int row = blockIdx.x;
int col = blockIdx.y * blockDim.x + threadIdx.x;
if (col < dim) {
DType local_accum = ReduceOp::zero;
IdType local_arg = -1;
for (IdType i = offsets[row]; i < offsets[row + 1]; ++i) {
ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);
}
out[row * dim + col] = local_accum;
if (ReduceOp::require_arg)
arg[row * dim + col] = local_arg;
}
}
/*!
* \brief CUDA kernel of segment reduce.
*/
template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
const DType *feat, const IdType *arg, DType *out,
int64_t n, int64_t dim) {
int row = blockIdx.x;
int col = blockIdx.y * blockDim.x + threadIdx.x;
if (col < dim) {
int write_row = arg[row * dim + col];
if (write_row >= 0) {
out[write_row * dim + col] = feat[row * dim + col];
}
}
}
template <typename IdType, typename DType, typename ReduceOp>
void SegmentReduce(
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
const DType* feat_data = feat.Ptr<DType>();
const IdType* offsets_data = offsets.Ptr<IdType>();
DType* out_data = out.Ptr<DType>();
IdType* arg_data = arg.Ptr<IdType>();
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int64_t n = out->shape[0];
int64_t dim = 1;
for (int i = 1; i < out->ndim; ++i)
dim *= out->shape[i];
const int nbx = n;
const int ntx = FindNumThreads(dim);
const int nby = (dim + ntx - 1) / ntx;
const int nty = 1;
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
CUDA_KERNEL_CALL((SegmentReduceKernel<IdType, DType, ReduceOp>),
nblks, nthrs, 0, thr_entry->stream,
feat_data, offsets_data, out_data, arg_data,
n, dim);
}
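The launch geometry above assigns one block row per segment and tiles the feature dimension across blockIdx.y. A small arithmetic sketch of the resulting grid, assuming FindNumThreads returns the next power of two capped at 1024 threads (that cap is my assumption, and the sizes are made up):

# Hypothetical sizes; FindNumThreads behavior is assumed, not taken from the patch.
n, dim = 5, 3000                                # 5 segments, 3000 features per row
ntx = min(1 << (dim - 1).bit_length(), 1024)    # 1024 threads along the feature axis
nby = (dim + ntx - 1) // ntx                    # 3 y-blocks to cover every column
grid, block = (n, nby), (ntx, 1)                # mirrors nblks/nthrs above: (5, 3), (1024, 1)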
template <typename IdType, typename DType>
void BackwardSegmentCmp(
NDArray feat,
NDArray arg,
NDArray out) {
const DType* feat_data = feat.Ptr<DType>();
const IdType* arg_data = arg.Ptr<IdType>();
DType *out_data = out.Ptr<DType>();
auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
int64_t n = feat->shape[0];
int64_t dim = 1;
for (int i = 1; i < out->ndim; ++i)
dim *= out->shape[i];
const int nbx = n;
const int ntx = FindNumThreads(dim);
const int nby = (dim + ntx - 1) / ntx;
const int nty = 1;
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
CUDA_KERNEL_CALL((BackwardSegmentCmpKernel<IdType, DType>),
nblks, nthrs, 0, thr_entry->stream,
feat_data, arg_data, out_data,
n, dim);
}
} // namespace cuda
} // namespace aten
} // namespace dgl
#endif  // DGL_ARRAY_SEGMENT_REDUCE_CUH_
......@@ -125,6 +125,32 @@ void SDDMM(const std::string& op,
});
}
/*! \brief Segment reduce dispatch function. */
void SegmentReduceDispatch(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", {
ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, "Feature data", {
SegmentReduce<XPU, IdType, DType>(op, feat, offsets, out, arg);
});
});
});
}
/*! \brief Backward segment cmp dispatch function.*/
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
ATEN_ID_TYPE_SWITCH(arg->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(feat->dtype, DType, "Feature data", {
BackwardSegmentCmp<XPU, IdType, DType>(feat, arg, out);
});
});
});
}
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
......@@ -174,5 +200,27 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string op = args[0];
NDArray feat = args[1];
NDArray offsets = args[2];
NDArray out = args[3];
NDArray arg = args[4];
CheckCtx(feat->ctx, {feat, offsets, out}, {"feat", "offsets", "out"});
CheckContiguous({feat, offsets, out}, {"feat", "offsets", "out"});
SegmentReduceDispatch(op, feat, offsets, out, arg);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
.set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0];
NDArray arg = args[1];
NDArray out = args[2];
CheckCtx(feat->ctx, {feat, arg, out}, {"feat", "arg", "out"});
CheckContiguous({feat, arg, out}, {"feat", "arg", "out"});
BackwardSegmentCmpDispatch(feat, arg, out);
});
} // namespace aten
} // namespace dgl
......@@ -66,6 +66,24 @@ void SDDMMCoo(const std::string& op,
int lhs_target,
int rhs_target);
/*!
* \brief Segment reduce.
*/
template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
/*!
* \brief Backward function of segment cmp.
*/
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
NDArray out);
} // namespace aten
} // namespace dgl
......
from dgl.ops import gspmm, gsddmm, edge_softmax
from dgl.ops import gspmm, gsddmm, edge_softmax, segment_reduce
from test_utils.graph_cases import get_cases
from utils import parametrize_dtype
import dgl
......@@ -255,5 +255,34 @@ def test_edge_softmax(g, norm_by, shp, idtype):
assert F.allclose(F.grad(e2), grad_edata)
print('backward passed')
@pytest.mark.parametrize('reducer', ['sum', 'max', 'min', 'mean'])
def test_segment_reduce(reducer):
ctx = F.ctx()
value = F.tensor(np.random.rand(10, 5))
v1 = F.attach_grad(F.clone(value))
v2 = F.attach_grad(F.clone(value))
seglen = F.tensor([2, 3, 0, 4, 1])
u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
v = F.repeat(F.copy_to(F.arange(0, len(seglen), F.int32), ctx),
seglen, dim=0)
num_nodes = {'_U': len(u), '_V': len(seglen)}
g = dgl.convert.heterograph({('_U', '_E', '_V'): (u, v)}, num_nodes_dict=num_nodes)
with F.record_grad():
rst1 = gspmm(g, 'copy_lhs', reducer, v1, None)
F.backward(F.reduce_sum(rst1))
grad1 = F.grad(v1)
with F.record_grad():
rst2 = segment_reduce(seglen, v2, reducer=reducer)
F.backward(F.reduce_sum(rst2))
assert F.allclose(rst1, rst2)
print('forward passed')
grad2 = F.grad(v2)
assert F.allclose(grad1, grad2)
print('backward passed')
if __name__ == '__main__':
test_spmm(F.int32, graphs[0], spmm_shapes[5], 'copy_lhs', 'sum')
test_spmm(F.int32, graphs[0], spmm_shapes[0], 'mul', 'sum')