Unverified commit 92a3d07d authored by Zihao Ye, committed by GitHub

[Kernel] Use tree reduction for SDDMM-dot (#2335)



* multiple fixes

* fix CI

* fiddle

* revert stubs

* upd

* upd

* unmerge

* unmerge
Co-authored-by: Quan Gan <coin2028@hotmail.com>
parent fe4ada23
@@ -21,6 +21,8 @@ using namespace cuda;
 namespace aten {
 namespace cuda {
+constexpr unsigned int full_mask = 0xffffffff;
+
 /*!
  * \brief CUDA kernel of g-SDDMM on Coo format.
  * \note it uses edge parallel strategy, different threadblocks (on y-axis)
@@ -70,6 +72,51 @@ __global__ void SDDMMCooKernel(
   }
 }
+/*!
+ * \brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree reduction.
+ * \note it uses edge parallel strategy: different threadblocks (on the x-axis)
+ *       are responsible for the computation on different edges, while threadblocks
+ *       on the y-axis are responsible for the computation on different positions
+ *       in the feature dimension.
+ */
+template <typename Idx, typename DType,
+          bool UseBcast = false, bool UseIdx = false,
+          int LhsTarget = 0, int RhsTarget = 2>
+__global__ void SDDMMCooTreeReduceKernel(
+  const DType* __restrict__ lhs,
+  const DType* __restrict__ rhs,
+  DType* __restrict__ out,
+  const Idx* __restrict__ row,
+  const Idx* __restrict__ col,
+  const Idx* __restrict__ edge_map,
+  int64_t N, int64_t M, int64_t E, int64_t reduce_size,
+  const int64_t* __restrict__ lhs_off,
+  const int64_t* __restrict__ rhs_off,
+  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
+  Idx ty = blockIdx.x * blockDim.y + threadIdx.y;  // one edge per (block, threadIdx.y)
+  if (ty < E) {
+    const Idx src = _ldg(row + ty);
+    const Idx dst = _ldg(col + ty);
+    const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
+    const DType* lhsoff = lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
+    const DType* rhsoff = rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
+    DType* outoff = out + eid * out_len;
+    int tx = threadIdx.x;  // tx < 32: lane id within the warp
+    for (int i = blockIdx.y; i < out_len; i += gridDim.y) {  // over output feature dimension
+      const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;
+      const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i;
+      DType val = 0.;
+      for (int j = tx; j < reduce_size; j += 32)  // strided partial dot product per lane
+        val += lhsoff[lhs_add * reduce_size + j] * rhsoff[rhs_add * reduce_size + j];
+#pragma unroll
+      for (int offset = 16; offset > 0; offset /= 2)  // warp tree reduction in 5 steps
+        val += __shfl_down_sync(full_mask, val, offset);
+      if (tx == 0)
+        outoff[i] = val;  // lane 0 holds the reduced sum
+    }
+  }
+}
+
 // Binary search the row_offsets to find the source node of the edge id.
 template <typename Idx>
 __device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx eid) {
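Note on the reduction pattern used above: each of the 32 lanes in a warp first accumulates a strided partial sum of the dot product, then __shfl_down_sync folds the 32 partials in log2(32) = 5 halving steps, leaving the complete sum in lane 0. The sketch below shows the same warp-level tree reduction in isolation; it is illustrative only (WarpDotKernel and all its names are made up for this note, not part of the commit).

#include <cstdio>
#include <cuda_runtime.h>

// One warp computes the dot product of two length-n vectors:
// each lane accumulates a strided partial sum, then the warp
// folds the 32 partials with __shfl_down_sync in 5 steps.
__global__ void WarpDotKernel(const float* a, const float* b, float* out, int n) {
  int lane = threadIdx.x;  // launch with exactly one warp (32 threads)
  float val = 0.f;
  for (int j = lane; j < n; j += 32)
    val += a[j] * b[j];
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);  // all 32 lanes participate
  if (lane == 0)
    *out = val;  // lane 0 now holds the warp-wide sum
}

int main() {
  const int n = 64;
  float ha[n], hb[n];
  for (int i = 0; i < n; ++i) { ha[i] = 1.f; hb[i] = 2.f; }
  float *da, *db, *dout, hout;
  cudaMalloc(&da, sizeof(ha));
  cudaMalloc(&db, sizeof(hb));
  cudaMalloc(&dout, sizeof(float));
  cudaMemcpy(da, ha, sizeof(ha), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, sizeof(hb), cudaMemcpyHostToDevice);
  WarpDotKernel<<<1, 32>>>(da, db, dout, n);
  cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
  printf("dot = %f\n", hout);  // expect 128.0
  cudaFree(da); cudaFree(db); cudaFree(dout);
  return 0;
}

Compared with the generic SDDMMCooKernel, which assigns one thread per output element and loops serially over the reduce dimension, this lets a whole warp cooperate on each dot product; that is why the dispatch below only takes the fast path when reduce_dim >= 32.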
@@ -172,15 +219,31 @@ void SDDMMCoo(
   int64_t reduce_dim = bcast.reduce_size;
   const int64_t nnz = coo.row->shape[0];
+  const bool use_idx = !IsNullArray(coo.data);
+  if (std::is_same<Op, binary::Dot<DType> >::value && reduce_dim >= 32) {
+    const int ntx = 32;  // a full warp strides the reduce (feature) dimension
+    const int nty = 8;   // edges handled per threadblock
+    const int nbx = (nnz + nty - 1) / nty;
+    const int nby = FindNumBlocks<'y'>(len);
+    const dim3 nblks(nbx, nby);
+    const dim3 nthrs(ntx, nty);
+    BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
+      CUDA_KERNEL_CALL((SDDMMCooTreeReduceKernel<Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
+          nblks, nthrs, 0, thr_entry->stream,
+          lhs_data, rhs_data, out_data,
+          row, col, edge_map,
+          coo.num_rows, coo.num_cols, nnz, reduce_dim,
+          lhs_off, rhs_off,
+          lhs_len, rhs_len, len);
+    });
+  } else {
     const int ntx = FindNumThreads(len);
     const int nty = CUDA_MAX_NUM_THREADS / ntx;
     const int nbx = (len + ntx - 1) / ntx;
     const int nby = FindNumBlocks<'y'>((nnz + nty - 1) / nty);
-  //LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
     const dim3 nblks(nbx, nby);
     const dim3 nthrs(ntx, nty);
-  const bool use_idx = !IsNullArray(coo.data);
     BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
       CUDA_KERNEL_CALL((SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
           nblks, nthrs, 0, thr_entry->stream,
@@ -190,6 +253,7 @@ void SDDMMCoo(
           lhs_off, rhs_off,
           lhs_len, rhs_len, len);
     });
+  }
 }
 /*!
...
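A worked example of the launch geometry on the fast path (the numbers are illustrative, not from the commit): for a plain dot product len = 1, so with nnz = 100,000 edges the dispatch launches nbx = ceil(100000 / 8) = 12,500 blocks on the x-axis and nby = FindNumBlocks<'y'>(1) = 1 block on the y-axis, each block holding 32 x 8 = 256 threads, i.e. one warp per edge. The reduce_dim >= 32 guard ensures every lane of the warp has at least one multiply-add to contribute; for shorter feature vectors most lanes would sit idle, so the generic one-thread-per-output SDDMMCooKernel is used instead.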