Unverified commit e6bf54cd authored by Zihao Ye, committed by GitHub

[Feature] Support aggregating multiple edge features in to_simple (#2623)

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* fix

* refactor

* upd test

* large feat_len or n in segment reduce

* lint
parent 8a07ab77
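For context, the new ``aggregator`` option of ``dgl.to_simple`` can be used as in the minimal sketch below (the graph and feature values here are illustrative, not taken from this commit):

import dgl
import torch as th

# A graph with three parallel edges from node 0 to node 1.
g = dgl.graph(([0, 0, 0, 1], [1, 1, 1, 2]))
g.edata['h'] = th.tensor([[1.], [2.], [3.], [4.]])

# Coalesce parallel edges and aggregate their features.
sg_sum = dgl.to_simple(g, copy_edata=True, aggregator='sum')    # edge (0, 1) gets 1 + 2 + 3 = 6
sg_mean = dgl.to_simple(g, copy_edata=True, aggregator='mean')  # edge (0, 1) gets (1 + 2 + 3) / 3 = 2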
@@ -1552,6 +1552,28 @@ def segment_reduce(op, x, offsets):
    """
    pass

def scatter_add(x, idx, m):
    """Scatter add (on first dimension) operator.

    Math: y[idx[i], *] += x[i, *]

    Parameters
    ----------
    x : Tensor
        The input feature.
    idx : Tensor
        The indices array.
    m : int
        The length of output.

    Returns
    -------
    Tensor
        The output tensor.
    """
    pass
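For reference, the intended semantics can be reproduced with plain NumPy (a hypothetical illustration, not part of the patch):

import numpy as np

# y[idx[i], *] += x[i, *], accumulated into m output rows.
x = np.array([[1., 1.], [2., 2.], [3., 3.]])
idx = np.array([0, 1, 1])
m = 2

y = np.zeros((m,) + x.shape[1:], dtype=x.dtype)
np.add.at(y, idx, x)   # rows 1 and 2 of x both accumulate into y[1]
# y == [[1., 1.], [5., 5.]]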
###############################################################################
# Other interfaces
...
import mxnet as mx
import numpy as np
from mxnet import nd
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...base import dgl_warning, is_all, ALL
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add']

def _scatter_nd(index, src, n_rows):
@@ -360,3 +360,22 @@ class SegmentReduce(mx.autograd.Function):
def segment_reduce(op, x, offsets):
    segment_reduce_op = SegmentReduce(op, offsets)
    return segment_reduce_op(x)
class ScatterAdd(mx.autograd.Function):
    def __init__(self, idx, m):
        super(ScatterAdd, self).__init__()
        self.idx = idx
        self.m = m

    def forward(self, x):
        y = _scatter_add(x, self.idx, self.m)
        return y

    def backward(self, dy):
        return dy[self.idx]


def scatter_add(x, idx, m):
    scatter_add_op = ScatterAdd(idx, m)
    return scatter_add_op(x)
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add

if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import custom_fwd, custom_bwd
@@ -24,7 +24,7 @@ else:
            return bwd(*args, **kwargs)
        return decorate_bwd

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add']

def _reduce_grad(grad, shape):
@@ -288,6 +288,21 @@ class SegmentReduce(th.autograd.Function):
        return None, dx, None
class ScatterAdd(th.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=th.float16)
    def forward(ctx, x, idx, m):
        y = _scatter_add(x, idx, m)
        ctx.save_for_backward(idx)
        return y

    @staticmethod
    @custom_bwd
    def backward(ctx, dy):
        idx, = ctx.saved_tensors
        return dy[idx], None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    return GSpMM.apply(gidx, op, reduce_op, lhs_data, rhs_data)
@@ -302,3 +317,6 @@ def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
def segment_reduce(op, x, offsets):
    return SegmentReduce.apply(op, x, offsets)

def scatter_add(x, idx, m):
    return ScatterAdd.apply(x, idx, m)
@@ -2,9 +2,9 @@ import tensorflow as tf
import numpy as np
from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
from ...base import is_all, ALL
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _scatter_add

__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce', 'scatter_add']

def _scatter_nd(index, src, n_rows):
@@ -279,3 +279,19 @@ def segment_reduce(op, x, offsets):
    def _lambda(x):
        return segment_reduce_real(op, x, offsets)
    return _lambda(x)
def scatter_add_real(x, idx, m):
    y = _scatter_add(x, idx, m)

    def scatter_add_backward(dy):
        return tf.gather(dy, idx)
    return y, scatter_add_backward


def scatter_add(x, idx, m):
    @tf.custom_gradient
    def _lambda(x):
        return scatter_add_real(x, idx, m)
    return _lambda(x)
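All three backends use the same backward rule (``dy[idx]`` / ``tf.gather(dy, idx)``): scatter_add is linear in ``x`` and row ``i`` contributes only to output row ``idx[i]``, so the gradient with respect to ``x`` is a gather of ``dy`` by ``idx``. A hedged NumPy sanity check of that identity (illustrative, not part of the patch):

import numpy as np

x = np.random.rand(4, 3)
idx = np.array([0, 2, 2, 1])
dy = np.random.rand(3, 3)            # upstream gradient for m = 3 output rows

def loss(x_):
    y = np.zeros((3, 3))
    np.add.at(y, idx, x_)            # forward scatter-add
    return (y * dy).sum()

# Finite-difference gradient at x[1, 2] matches gather(dy, idx)[1, 2].
eps = 1e-6
x_pert = x.copy()
x_pert[1, 2] += eps
num_grad = (loss(x_pert) - loss(x)) / eps
assert abs(num_grad - dy[idx][1, 2]) < 1e-4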
"""Module for sparse matrix operators.""" """Module for sparse matrix operators."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from __future__ import absolute_import from __future__ import absolute_import
import dgl.ndarray as nd import dgl.ndarray as nd
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
...@@ -309,6 +308,35 @@ def _segment_reduce(op, feat, offsets): ...@@ -309,6 +308,35 @@ def _segment_reduce(op, feat, offsets):
return out, arg return out, arg
def _scatter_add(x, idx, m):
    r"""Scatter add operator (on first dimension) implementation.

    Math: y[idx[i], *] += x[i, *]

    Parameters
    ----------
    x : Tensor
        The input feature.
    idx : Tensor
        The indices array.
    m : int
        The length of output.

    Returns
    -------
    Tensor
        The output tensor.
    """
    out_shp = (m,) + F.shape(x)[1:]
    ctx = F.context(x)
    dtype = F.dtype(x)
    out = F.zeros(out_shp, dtype, ctx)
    _CAPI_DGLKernelScatterAdd(to_dgl_nd(x),
                              to_dgl_nd(idx),
                              to_dgl_nd_for_write(out))
    return out
def _bwd_segment_cmp(feat, arg, m):
    r""" Backward phase of segment reduction (for 'min'/'max' reduction).
...
@@ -9,6 +9,7 @@ from ._ffi.function import _init_api
from .base import dgl_warning, DGLError
from . import convert
from .heterograph import DGLHeteroGraph, DGLBlock
from .frame import Frame
from . import ndarray as nd
from . import backend as F
from . import utils, batch
@@ -1841,11 +1842,63 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
    return new_graph
def _coalesce_edge_frame(g, edge_maps, counts, aggregator):
    r"""Coalesce the edge features of duplicate edges in g via the given aggregator.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    edge_maps : List[Tensor]
        The edge mapping corresponding to each edge type in g.
    counts : List[Tensor]
        The number of duplicate edges from the original graph for each edge type.
    aggregator : str
        Indicates how to coalesce edge features; one of ``arbitrary``, ``sum``
        or ``mean``.

    Returns
    -------
    List[Frame]
        The frames corresponding to each edge type.
    """
    if aggregator == 'arbitrary':
        eids = []
        for i in range(len(g.canonical_etypes)):
            feat_idx = F.asnumpy(edge_maps[i])
            _, indices = np.unique(feat_idx, return_index=True)
            eids.append(F.zerocopy_from_numpy(indices))
        edge_frames = utils.extract_edge_subframes(g, eids)
    elif aggregator in ['sum', 'mean']:
        edge_frames = []
        for i in range(len(g.canonical_etypes)):
            feat_idx = edge_maps[i]
            _, indices = np.unique(F.asnumpy(feat_idx), return_index=True)
            _num_rows = len(indices)
            _data = {}
            for key, col in g._edge_frames[i]._columns.items():
                data = col.data
                new_data = F.scatter_add(data, feat_idx, _num_rows)
                if aggregator == 'mean':
                    norm = F.astype(counts[i], F.dtype(data))
                    norm = F.reshape(norm, (F.shape(norm)[0],) + (1,) * (F.ndim(data) - 1))
                    new_data /= norm
                _data[key] = new_data
            newf = Frame(data=_data, num_rows=_num_rows)
            edge_frames.append(newf)
    else:
        raise DGLError("Aggregator {} not recognized, cannot coalesce edge features in the "
                       "specified way".format(aggregator))
    return edge_frames
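To make the sum/mean branch concrete, here is a small NumPy sketch of the same coalescing logic for a single feature column (values and names are illustrative):

import numpy as np

feat = np.array([[1.], [2.], [3.], [4.]])   # one feature row per original edge
edge_map = np.array([0, 1, 1, 1])           # original edge -> edge id in the simple graph
counts = np.array([1, 3])                   # how many originals collapse into each simple edge

summed = np.zeros((2, 1))
np.add.at(summed, edge_map, feat)           # what F.scatter_add computes: [[1.], [9.]]
norm = counts.reshape((-1,) + (1,) * (feat.ndim - 1))
mean = summed / norm                        # [[1.], [3.]]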
def to_simple(g,
              return_counts='count',
              writeback_mapping=False,
              copy_ndata=True,
              copy_edata=False,
              aggregator='arbitrary'):
    r"""Convert a graph to a simple graph without parallel edges and return.

    For a heterogeneous graph with multiple edge types, DGL treats edges with the same
@@ -1886,12 +1939,19 @@ def to_simple(g,
    copy_edata: bool, optional
        If True, the edge features of the simple graph are copied
        from the original graph. If there exist duplicate edges between
        two nodes (u, v), the feature of the edge is the aggregation
        of the features of the duplicate edges.
        If False, the simple graph will not have any edge features.
        (Default: False)
    aggregator: str, optional
        Indicates how to coalesce the edge features of duplicate edges.
        If ``arbitrary``, select the feature of one of the duplicate edges arbitrarily.
        If ``sum``, compute the sum of the duplicate edges' features.
        If ``mean``, compute the mean of the duplicate edges' features.
        (Default: ``arbitrary``)

    Returns
    -------
@@ -1991,14 +2051,8 @@ def to_simple(g,
        node_frames = utils.extract_node_subframes(g, None)
        utils.set_new_frames(simple_graph, node_frames=node_frames)
    if copy_edata:
        new_edge_frames = _coalesce_edge_frame(g, edge_maps, counts, aggregator)
        utils.set_new_frames(simple_graph, edge_frames=new_edge_frames)
    if return_counts is not None:
        for count, canonical_etype in zip(counts, g.canonical_etypes):
...
@@ -40,6 +40,16 @@ void SegmentReduce(
  }
}

/*! \brief Scatter Add.*/
template <int XPU, typename IdType, int bits>
void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out) {
  SWITCH_BITS(bits, DType, {
    cpu::ScatterAdd<IdType, DType>(feat, idx, out);
  });
}
/*! \brief Backward function of segment cmp.*/
template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp(
@@ -87,6 +97,30 @@ template void SegmentReduce<kDLCPU, int64_t, 64>(
    NDArray offsets,
    NDArray out,
    NDArray arg);
template void ScatterAdd<kDLCPU, int32_t, 16>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 16>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 32>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 32>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 64>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 64>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
    NDArray feat,
    NDArray arg,
...
@@ -72,8 +72,35 @@ void SegmentCmp(NDArray feat, NDArray offsets,
  }
}
/*!
 * \brief CPU kernel of Scatter Add (on first dimension) operator.
 * \note math equation: out[idx[i], *] += feat[i, *]
 * \param feat The input tensor.
 * \param idx The indices tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
#pragma omp parallel for
  for (int i = 0; i < n; ++i) {
    const int write_row = idx_data[i];
    for (int k = 0; k < dim; ++k) {
#pragma omp atomic
      out_data[write_row * dim + k] += feat_data[i * dim + k];
    }
  }
}
/*!
 * \brief CPU kernel of backward phase of segment min/max.
 * \note math equation: out[arg[i, k], k] = feat[i, k]
 * \param feat The input tensor.
 * \param arg The argmin/argmax tensor.
 * \param out The output tensor.
...
@@ -38,6 +38,16 @@ void SegmentReduce(const std::string& op,
}
template <int XPU, typename IdType, int bits>
void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out) {
  SWITCH_BITS(bits, DType, {
    cuda::ScatterAdd<IdType, DType>(feat, idx, out);
  });
}
template <int XPU, typename IdType, int bits>
void BackwardSegmentCmp(NDArray feat,
                        NDArray arg,
@@ -84,6 +94,30 @@ template void SegmentReduce<kDLGPU, int64_t, 64>(
    NDArray offsets,
    NDArray out,
    NDArray arg);
template void ScatterAdd<kDLGPU, int32_t, 16>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 16>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLGPU, int32_t, 32>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 32>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLGPU, int32_t, 64>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 64>(
    NDArray feat,
    NDArray idx,
    NDArray out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 16>(
    NDArray feat,
    NDArray arg,
...
@@ -8,6 +8,7 @@
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./atomic.cuh"

namespace dgl {
@@ -27,17 +28,38 @@ __global__ void SegmentReduceKernel(
    const DType* feat, const IdType* offsets,
    DType* out, IdType* arg,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      DType local_accum = ReduceOp::zero();
      IdType local_arg = -1;
      for (IdType i = offsets[row]; i < offsets[row + 1]; ++i) {
        ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);
      }
      out[row * dim + col] = local_accum;
      if (ReduceOp::require_arg)
        arg[row * dim + col] = local_arg;
      col += gridDim.y * blockDim.x;
    }
  }
}

/*!
 * \brief CUDA kernel of scatter add.
 * \note Each thread block processes rows of the feature tensor and atomically
 *       adds them to the corresponding target rows of the output tensor.
 */
template <typename IdType, typename DType>
__global__ void ScatterAddKernel(
    const DType *feat, const IdType *idx, DType *out,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    const int write_row = idx[row];
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);
      col += gridDim.y * blockDim.x;
    }
  }
}
@@ -50,12 +72,14 @@ template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
    const DType *feat, const IdType *arg, DType *out,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      int write_row = arg[row * dim + col];
      if (write_row >= 0) {
        out[write_row * dim + col] = feat[row * dim + col];
      }
      col += gridDim.y * blockDim.x;
    }
  }
}
@@ -84,9 +108,9 @@ void SegmentReduce(
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
@@ -97,8 +121,43 @@
      n, dim);
}
/*!
 * \brief CUDA implementation of Scatter Add (on first dimension).
 * \note math equation: out[idx[i], *] += feat[i, *]
 * \param feat The input tensor.
 * \param idx The indices tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(
    NDArray feat,
    NDArray idx,
    NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();
  auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL((ScatterAddKernel<IdType, DType>),
                   nblks, nthrs, 0, thr_entry->stream,
                   feat_data, idx_data, out_data,
                   n, dim);
}
/*!
 * \brief CUDA implementation of backward phase of Segment Reduce with Min/Max reducer.
 * \note math equation: out[arg[i, k], k] = feat[i, k]
 * \param feat The input tensor.
 * \param arg The ArgMin/Max information, used for indexing.
 * \param out The output tensor.
@@ -118,9 +177,9 @@ void BackwardSegmentCmp(
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];
  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
...
@@ -153,6 +153,17 @@ void SegmentReduceDispatch(const std::string& op,
  });
}
/*! \brief Scatter Add (on first dimension) dispatch function. */
void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "ScatterAdd", {
    ATEN_ID_TYPE_SWITCH(idx->dtype, IdType, {
      ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", {
        ScatterAdd<XPU, IdType, bits>(feat, idx, out);
      });
    });
  });
}
/*! \brief Backward segment cmp dispatch function.*/
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
@@ -225,6 +236,16 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce")
    SegmentReduceDispatch(op, feat, offsets, out, arg);
  });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd")
.set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0];
NDArray idx = args[1];
NDArray out = args[2];
CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
ScatterAddDispatch(feat, idx, out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0]; NDArray feat = args[0];
......
@@ -76,6 +76,14 @@ void SegmentReduce(const std::string& op,
                   NDArray out,
                   NDArray arg);
/*!
 * \brief Scatter Add on first dimension.
 */
template <int XPU, typename IdType, int bits>
void ScatterAdd(NDArray feat,
                NDArray idx,
                NDArray out);
/*!
 * \brief Backward function of segment cmp.
 */
...
@@ -772,6 +772,14 @@ def test_to_simple(idtype):
    assert 'h' not in sg.ndata
    assert 'h' not in sg.edata
    # test coalesce edge feature
    sg = dgl.to_simple(g, copy_edata=True, aggregator='arbitrary')
    assert F.allclose(sg.edata['h'][1], F.tensor([4.]))
    sg = dgl.to_simple(g, copy_edata=True, aggregator='sum')
    assert F.allclose(sg.edata['h'][1], F.tensor([10.]))
    sg = dgl.to_simple(g, copy_edata=True, aggregator='mean')
    assert F.allclose(sg.edata['h'][1], F.tensor([5.]))
    # heterogeneous graph
    g = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 1, 2, 1, 1, 1],
...