Unverified commit 66a54555 authored by ayasar70, committed by GitHub

[Performance][GPU] Improve csr2coo.cu:_RepeatKernel() for more robust GPU usage (#3537)



* Based on issue #3436. Improving _SegmentCopyKernel's GPU utilization by switching to nonzero-based thread assignment

* fixing lint issues

* Update cub for cuda 11.5 compatibility (#3468)

* fixing type mismatch

* tx is guaranteed to be smaller than nnz, hence removing the last check

* minor: updating comment

* adding three unit tests for csr slice method to cover some corner cases

* working on repeat

* updating repeat kernel

* removing unnecessary parameter

* cleaning commented line

* cleaning time measures

* cleaning time measurement lines
Co-authored-by: Abdurrahman Yasar <ayasar@nvidia.com>
Co-authored-by: nv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 44f0b5fe
@@ -64,17 +64,21 @@ COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
  */
 template <typename DType, typename IdType>
 __global__ void _RepeatKernel(
-    const DType* val, const IdType* repeats, const IdType* pos,
-    DType* out, int64_t length) {
+    const DType* val, const IdType* pos,
+    DType* out, int64_t n_row, int64_t length) {
   int tx = blockIdx.x * blockDim.x + threadIdx.x;
   const int stride_x = gridDim.x * blockDim.x;
   while (tx < length) {
-    IdType off = pos[tx];
-    const IdType rep = repeats[tx];
-    const DType v = val[tx];
-    for (IdType i = 0; i < rep; ++i) {
-      out[off + i] = v;
+    IdType l = 0, r = n_row, m = 0;
+    while (l < r) {
+      m = l + (r-l)/2;
+      if (tx >= pos[m]) {
+        l = m+1;
+      } else {
+        r = m;
+      }
     }
+    out[tx] = val[l-1];
     tx += stride_x;
   }
 }
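The new kernel assigns one thread per output nonzero and binary-searches csr.indptr (passed as pos) to find the row that owns the nonzero, instead of having one thread per row write a variable-length run. A minimal host-side sketch of that per-thread lookup, using hypothetical names that are not part of the patch: for indptr = {0, 2, 2, 5} the five nonzeros map to rows {0, 0, 2, 2, 2}, with the empty row 1 contributing nothing.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical host-side reference (illustration only, not DGL code) of the
// per-thread lookup in _RepeatKernel: binary-search indptr ("pos") for the
// row that owns nonzero tx, i.e. the largest row l-1 with pos[l-1] <= tx.
template <typename IdType>
IdType RowOfNonzero(const IdType* pos, IdType n_row, IdType tx) {
  IdType l = 0, r = n_row;
  while (l < r) {
    IdType m = l + (r - l) / 2;
    if (tx >= pos[m]) {
      l = m + 1;
    } else {
      r = m;
    }
  }
  return l - 1;
}

int main() {
  // CSR indptr of a 3-row matrix whose rows hold 2, 0 and 3 nonzeros.
  std::vector<int64_t> indptr = {0, 2, 2, 5};
  const int64_t n_row = 3, nnz = 5;
  for (int64_t tx = 0; tx < nnz; ++tx) {
    std::printf("nonzero %lld -> row %lld\n",
                static_cast<long long>(tx),
                static_cast<long long>(RowOfNonzero(indptr.data(), n_row, tx)));
  }
  // Expected rows: 0 0 2 2 2.
  return 0;
}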
@@ -86,16 +90,15 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
   const auto nbits = csr.indptr->dtype.bits;
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
   IdArray rowids = Range(0, csr.num_rows, nbits, ctx);
-  IdArray row_nnz = CSRGetRowNNZ(csr, rowids);
   IdArray ret_row = NewIdArray(nnz, ctx, nbits);
-  const int nt = cuda::FindNumThreads(csr.num_rows);
-  const int nb = (csr.num_rows + nt - 1) / nt;
+  const int nt = 256;
+  const int nb = (nnz + nt - 1) / nt;
   CUDA_KERNEL_CALL(_RepeatKernel,
       nb, nt, 0, thr_entry->stream,
-      rowids.Ptr<int64_t>(), row_nnz.Ptr<int64_t>(),
+      rowids.Ptr<int64_t>(),
       csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
-      csr.num_rows);
+      csr.num_rows, nnz);
   return COOMatrix(csr.num_rows, csr.num_cols,
       ret_row, csr.indices, csr.data,
......
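Because work is now distributed per nonzero, the launch in the second hunk is sized from nnz with a fixed 256-thread block instead of from csr.num_rows via FindNumThreads. A self-contained sketch of that launch follows, with the kernel re-declared locally under the hypothetical name RepeatRowsKernel; it is an illustration of the scheme, not DGL's code.

#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Illustration only: one thread per output nonzero, fixed 256-thread blocks,
// and the same binary search over indptr that _RepeatKernel performs above.
template <typename DType, typename IdType>
__global__ void RepeatRowsKernel(
    const DType* val, const IdType* pos,
    DType* out, int64_t n_row, int64_t length) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    IdType l = 0, r = n_row;
    while (l < r) {
      IdType m = l + (r - l) / 2;
      if (tx >= pos[m]) l = m + 1; else r = m;
    }
    out[tx] = val[l - 1];
    tx += stride_x;
  }
}

int main() {
  // Tiny CSR: 3 rows with 2, 0 and 3 nonzeros; rowids are simply 0..n_row-1.
  std::vector<int64_t> indptr = {0, 2, 2, 5};
  std::vector<int64_t> rowids = {0, 1, 2};
  const int64_t n_row = 3, nnz = 5;

  int64_t *d_indptr, *d_rowids, *d_out;
  cudaMalloc(&d_indptr, indptr.size() * sizeof(int64_t));
  cudaMalloc(&d_rowids, rowids.size() * sizeof(int64_t));
  cudaMalloc(&d_out, nnz * sizeof(int64_t));
  cudaMemcpy(d_indptr, indptr.data(), indptr.size() * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_rowids, rowids.data(), rowids.size() * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  // Grid sizing follows the nonzero count, not the row count.
  const int nt = 256;
  const int nb = (nnz + nt - 1) / nt;
  RepeatRowsKernel<<<nb, nt>>>(d_rowids, d_indptr, d_out, n_row, nnz);

  std::vector<int64_t> out(nnz);
  cudaMemcpy(out.data(), d_out, nnz * sizeof(int64_t), cudaMemcpyDeviceToHost);
  for (int64_t i = 0; i < nnz; ++i)
    std::printf("%lld ", static_cast<long long>(out[i]));  // expect: 0 0 2 2 2
  std::printf("\n");

  cudaFree(d_indptr); cudaFree(d_rowids); cudaFree(d_out);
  return 0;
}

The design intent stated in the commit message is better GPU utilization: each thread now does a uniform O(log n_row) binary search instead of a write loop whose length depends on its row's nonzero count.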