Unverified Commit 6b02babb authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[doc] Add docstring for segment reduce. (#2375)

parent 35a3ead2
......@@ -239,7 +239,7 @@ Like GSpMM, GSDDMM operators support both homogeneous and bipartite graph.
Edge Softmax module
-------------------
We also provide framework agnostic edge softmax module which was frequently used in
DGL also provide framework agnostic edge softmax module which was frequently used in
GNN-like structures, e.g.
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`_,
`Transformer <https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_,
......@@ -250,6 +250,16 @@ GNN-like structures, e.g.
edge_softmax
Segment Reduce Module
---------------------
DGL provide operators to reduce value tensor along the first dimension by segments.
.. autosummary::
:toctree: ../../generated/
segment_reduce
Relation with Message Passing APIs
----------------------------------
......
......@@ -1512,23 +1512,27 @@ def segment_reduce(op, x, offsets):
"""Segment reduction operator.
It aggregates the value tensor along the first dimension by segments.
The first argument ``seglen`` stores the length of each segment. Its
summation must be equal to the first dimension of the ``value`` tensor.
Zero-length segments are allowed.
The argument ``offsets`` specifies the start offset of each segment (and
the upper bound of the last segment). Zero-length segments are allowed.
.. math::
y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j
where :math:`\Phi` is the reduce operator.
Parameters
----------
op : str
Aggregation method. Can be 'sum', 'max', 'min'.
seglen : Tensor
Segment lengths.
value : Tensor
Aggregation method. Can be ``sum``, ``max``, ``min``.
x : Tensor
Value to aggregate.
offsets : Tensor
The start offsets of segments.
Returns
-------
Tensor
Aggregated tensor of shape ``(len(seglen), value.shape[1:])``.
Aggregated tensor of shape ``(len(offsets) - 1, value.shape[1:])``.
"""
pass
......
......@@ -69,8 +69,6 @@ def segment_softmax(seglen, value):
Segment lengths.
value : Tensor
Value to aggregate.
reducer : str, optional
Aggregation method. Can be 'sum', 'max', 'min', 'mean'.
Returns
-------
......
......@@ -252,18 +252,22 @@ def _segment_reduce(op, feat, offsets):
r"""Segment reduction operator.
It aggregates the value tensor along the first dimension by segments.
The first argument ``seglen`` stores the length of each segment. Its
summation must be equal to the first dimension of the ``value`` tensor.
Zero-length segments are allowed.
The argument ``offsets`` specifies the start offset of each segment (and
the upper bound of the last segment). Zero-length segments are allowed.
.. math::
y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j
where :math:`\Phi` is the reduce operator.
Parameters
----------
op : str
Aggregation method. Can be 'sum', 'max', 'min'.
seglen : Tensor
Segment lengths.
value : Tensor
Aggregation method. Can be ``sum``, ``max``, ``min``.
x : Tensor
Value to aggregate.
offsets : Tensor
The start offsets of segments.
Returns
-------
......
......@@ -12,6 +12,12 @@ namespace dgl {
namespace aten {
namespace cpu {
/*!
* \brief CPU kernel of segment sum.
* \param feat The input tensor.
* \param offsets The offset tensor storing the ranges of segments.
* \param out The output tensor.
*/
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
int n = out->shape[0];
......@@ -31,6 +37,14 @@ void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
}
}
/*!
* \brief CPU kernel of segment min/max.
* \param feat The input tensor.
* \param offsets The offset tensor storing the ranges of segments.
* \param out The output tensor.
* \param arg An auxiliary tensor storing the argmin/max information
* used in backward phase.
*/
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets,
NDArray out, NDArray arg) {
......@@ -58,6 +72,12 @@ void SegmentCmp(NDArray feat, NDArray offsets,
}
}
/*!
* \brief CPU kernel of backward phase of segment min/max.
* \param feat The input tensor.
* \param arg The argmin/argmax tensor.
* \param out The output tensor.
*/
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
int n = feat->shape[0];
......
......@@ -19,6 +19,8 @@ namespace cuda {
/*!
* \brief CUDA kernel of segment reduce.
* \note each blockthread is responsible for aggregation on a row
* in the result tensor.
*/
template <typename IdType, typename DType,
typename ReduceOp>
......@@ -41,7 +43,9 @@ __global__ void SegmentReduceKernel(
}
/*!
* \brief CUDA kernel of segment reduce.
* \brief CUDA kernel of backward phase in segment min/max.
* \note each blockthread is responsible for writing a row in the
* result gradient tensor by lookup the ArgMin/Max for index information.
*/
template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
......@@ -57,6 +61,13 @@ __global__ void BackwardSegmentCmpKernel(
}
}
/*!
* \brief CUDA implementation of forward phase of Segment Reduce.
* \param feat The input tensor.
* \param offsets The offsets tensor.
* \param out The output tensor.
* \param arg An auxiliary tensor storing ArgMax/Min information,
*/
template <typename IdType, typename DType, typename ReduceOp>
void SegmentReduce(
NDArray feat,
......@@ -80,12 +91,19 @@ void SegmentReduce(
const int nty = 1;
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
// TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
CUDA_KERNEL_CALL((SegmentReduceKernel<IdType, DType, ReduceOp>),
nblks, nthrs, 0, thr_entry->stream,
feat_data, offsets_data, out_data, arg_data,
n, dim);
}
/*!
* \brief CUDA implementation of backward phase of Segment Reduce with Min/Max reducer.
* \param feat The input tensor.
* \param arg The ArgMin/Max information, used for indexing.
* \param out The output tensor.
*/
template <typename IdType, typename DType>
void BackwardSegmentCmp(
NDArray feat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment