Unverified commit 7b9afbfa, authored by ayasar70, committed by GitHub

[Performance][GPU] Improving _SegmentMaskColKernel (#3745)



* Based on issue #3436. Improving _SegmentCopyKernel's GPU utilization by switching to nonzero-based thread assignment

* fixing lint issues

* Update cub for CUDA 11.5 compatibility (#3468)

* fixing type mismatch

* tx is guaranteed to be smaller than nnz, hence removing the last check

* minor: updating comment

* adding three unit tests for the CSR slice method to cover some corner cases

* timing repeatkernel

* clean

* clean

* clean

* updating _SegmentMaskColKernel

* Addressing review requests: removing the sorted-array check and adding comments to the utility functions

* fixing lint issue
Co-authored-by: Abdurrahman Yasar <ayasar@nvidia.com>
Co-authored-by: nv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent f247d29f
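The common thread in the kernels below is switching from one thread per row to one thread per nonzero, with each thread recovering its owning row by a binary search over the prefix-summed indptr array. Per-row assignment leaves most threads idle when a few rows are very long; per-nonzero assignment balances the work at the cost of an O(log num_rows) lookup per element, which is the utilization gain the commit message describes. A minimal standalone sketch of that assignment pattern, with illustrative names rather than the actual DGL kernels:

```cuda
// Sketch only: one thread per nonzero, row recovered by an upper bound on indptr.
// upper_bound_dev and PerNonzeroKernel are illustrative names, not the DGL API.
#include <cstdint>

template <typename IdType>
__device__ IdType upper_bound_dev(const IdType* A, int64_t n, IdType x) {
  IdType lo = 0, hi = n;
  while (lo < hi) {
    IdType mid = lo + (hi - lo) / 2;
    if (x >= A[mid]) lo = mid + 1; else hi = mid;
  }
  return lo;  // first index i with A[i] > x
}

template <typename IdType>
__global__ void PerNonzeroKernel(const IdType* indptr, int64_t num_rows,
                                 IdType* row_of_nnz, int64_t nnz) {
  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  while (tx < nnz) {
    // indptr is a prefix sum; searching its first num_rows entries maps
    // nonzero tx to the row whose segment contains it.
    row_of_nnz[tx] = upper_bound_dev(indptr, num_rows, tx) - 1;
    tx += stride;
  }
}
```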
@@ -66,19 +66,11 @@ template <typename DType, typename IdType>
 __global__ void _RepeatKernel(
     const DType* val, const IdType* pos,
     DType* out, int64_t n_row, int64_t length) {
-  int tx = blockIdx.x * blockDim.x + threadIdx.x;
+  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int stride_x = gridDim.x * blockDim.x;
   while (tx < length) {
-    IdType l = 0, r = n_row, m = 0;
-    while (l < r) {
-      m = l + (r-l)/2;
-      if (tx >= pos[m]) {
-        l = m+1;
-      } else {
-        r = m;
-      }
-    }
-    out[tx] = val[l-1];
+    IdType i = dgl::cuda::_UpperBound(pos, n_row, tx) - 1;
+    out[tx] = val[i];
     tx += stride_x;
   }
 }
@@ -9,6 +9,8 @@
 #include <numeric>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./utils.h"
+#include "./atomic.cuh"
+#include "./dgl_cub.cuh"
 
 namespace dgl {
@@ -30,7 +32,7 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
   IdArray out = aten::NewIdArray(1, ctx, sizeof(IdType) * 8);
   const IdType* data = nullptr;
   // TODO(minjie): use binary search for sorted csr
-  CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
+  CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
       1, 1, 0, thr_entry->stream,
       csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
       rows.Ptr<IdType>(), cols.Ptr<IdType>(),
@@ -54,11 +56,11 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
   const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
   const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
-  const int nt = cuda::FindNumThreads(rstlen);
+  const int nt = dgl::cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   const IdType* data = nullptr;
   // TODO(minjie): use binary search for sorted csr
-  CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
+  CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
       nb, nt, 0, thr_entry->stream,
       csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
       row.Ptr<IdType>(), col.Ptr<IdType>(),
@@ -103,13 +105,13 @@ bool CSRHasDuplicate(CSRMatrix csr) {
   // We allocate a workspace of num_rows bytes. It wastes a little bit memory but should
   // be fine.
   int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
-  const int nt = cuda::FindNumThreads(csr.num_rows);
+  const int nt = dgl::cuda::FindNumThreads(csr.num_rows);
   const int nb = (csr.num_rows + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentHasNoDuplicate,
       nb, nt, 0, thr_entry->stream,
       csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
       csr.num_rows, flags);
-  bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
+  bool ret = dgl::cuda::AllTrue(flags, csr.num_rows, ctx);
   device->FreeWorkspace(ctx, flags);
   return !ret;
 }
@@ -152,7 +154,7 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
   const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
   NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
   IdType* rst_data = static_cast<IdType*>(rst->data);
-  const int nt = cuda::FindNumThreads(len);
+  const int nt = dgl::cuda::FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL(_CSRGetRowNNZKernel,
       nb, nt, 0, thr_entry->stream,
@@ -233,19 +235,7 @@ __global__ void _SegmentCopyKernel(
   IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int stride_x = gridDim.x * blockDim.x;
   while (tx < length) {
-    // find upper bound for tx using binary search.
-    // out_indptr has already a prefix sum. n_row = size(out_indptr)-1
-    IdType l = 0, r = n_row, m = 0;
-    while (l < r) {
-      m = l + (r-l)/2;
-      if (tx >= out_indptr[m]) {
-        l = m+1;
-      } else {
-        r = m;
-      }
-    }
-    IdType rpos = l-1;
+    IdType rpos = dgl::cuda::_UpperBound(out_indptr, n_row, tx) - 1;
     IdType rofs = tx - out_indptr[rpos];
     const IdType u = row[rpos];
     out_data[tx] = data? data[indptr[u]+rofs] : indptr[u]+rofs;
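For a concrete reading of the new lookup (illustrative values): with out_indptr = [0, 2, 5, 9], n_row = 3, and tx = 4, _UpperBound returns 2, so rpos = 1 and rofs = 4 - out_indptr[1] = 2; nonzero 4 is therefore the third entry of output row 1.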
@@ -372,7 +362,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
   // Generate a 0-1 mask for matched (row, col) positions.
   IdArray mask = Full(0, nnz, nbits, ctx);
-  const int nt = cuda::FindNumThreads(len);
+  const int nt = dgl::cuda::FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentMaskKernel,
       nb, nt, 0, thr_entry->stream,
@@ -388,7 +378,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
   // Search for row index
   IdArray ret_row = NewIdArray(idx->shape[0], ctx, nbits);
-  const int nt2 = cuda::FindNumThreads(idx->shape[0]);
+  const int nt2 = dgl::cuda::FindNumThreads(idx->shape[0]);
   const int nb2 = (idx->shape[0] + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SortedSearchKernel,
       nb2, nt2, 0, thr_entry->stream,
@@ -415,25 +405,19 @@ template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int64_t>(
  */
 template <typename IdType>
 __global__ void _SegmentMaskColKernel(
-    const IdType* indptr, const IdType* indices, int64_t num_rows,
+    const IdType* indptr, const IdType* indices, int64_t num_rows, int64_t num_nnz,
     const IdType* col, int64_t col_len,
     IdType* mask, IdType* count) {
-  int tx = blockIdx.x * blockDim.x + threadIdx.x;
+  IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
   const int stride_x = gridDim.x * blockDim.x;
-  // TODO(minjie): consider putting the col array in shared memory.
-  while (tx < num_rows) {
-    IdType cnt = 0;
-    for (IdType i = indptr[tx]; i < indptr[tx + 1]; ++i) {
-      const IdType cur_c = indices[i];
-      for (int64_t j = 0; j < col_len; ++j) {
-        if (cur_c == col[j]) {
-          mask[i] = 1;
-          ++cnt;
-          break;
-        }
-      }
+  while (tx < num_nnz) {
+    IdType rpos = dgl::cuda::_UpperBound(indptr, num_rows, tx) - 1;
+    IdType cur_c = indices[tx];
+    IdType i = dgl::cuda::_BinarySearch(col, col_len, cur_c);
+    if (i < col_len) {
+      mask[tx] = 1;
+      cuda::AtomicAdd(count+rpos, IdType(1));
     }
-    count[tx] = cnt;
     tx += stride_x;
   }
 }
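Read serially, the rewritten kernel computes the following (a host-side sketch with hypothetical container types, not DGL code): every nonzero whose column id occurs in the sorted col array gets its mask entry set to 1, and count[r] ends up holding the number of kept nonzeros in row r, which is why the kernel needs cuda::AtomicAdd once a row's nonzeros are spread across threads.

```cuda
// Serial reference for the new _SegmentMaskColKernel semantics (sketch only).
#include <algorithm>
#include <cstdint>
#include <vector>

template <typename IdType>
void SegmentMaskColSerial(const std::vector<IdType>& indptr,
                          const std::vector<IdType>& indices,
                          const std::vector<IdType>& sorted_cols,
                          std::vector<IdType>* mask,
                          std::vector<IdType>* count) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  for (int64_t r = 0; r < num_rows; ++r) {
    for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
      // Membership test against the sorted column list, as the kernel's
      // _BinarySearch does in O(log col_len).
      if (std::binary_search(sorted_cols.begin(), sorted_cols.end(), indices[i])) {
        (*mask)[i] = 1;
        ++(*count)[r];  // the kernel uses cuda::AtomicAdd here because threads race per row
      }
    }
  }
}
```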
@@ -464,12 +448,33 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
   IdArray mask = Full(0, csr.indices->shape[0], nbits, ctx);
   // A count for how many masked values per row.
   IdArray count = NewIdArray(csr.num_rows, ctx, nbits);
-  const int nt = cuda::FindNumThreads(csr.num_rows);
-  const int nb = (csr.num_rows + nt - 1) / nt;
+  CUDA_CALL(cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows)));
+
+  const int64_t nnz_csr = csr.indices->shape[0];
+  const int nt = 256;
+  // In general ``cols'' array is sorted. But it is not guaranteed.
+  // Hence checking and sorting array first. Sorting is not in place.
+  auto device = runtime::DeviceAPI::Get(ctx);
+  auto cols_size = cols->shape[0];
+  IdArray sorted_array = NewIdArray(cols->shape[0], ctx, cols->dtype.bits);
+  auto ptr_sorted_cols = sorted_array.Ptr<IdType>();
+  auto ptr_cols = cols.Ptr<IdType>();
+  size_t workspace_size = 0;
+  cub::DeviceRadixSort::SortKeys(
+      nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]);
+  void *workspace = device->AllocWorkspace(ctx, workspace_size);
+  cub::DeviceRadixSort::SortKeys(
+      workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0]);
+  device->FreeWorkspace(ctx, workspace);
+
+  // Execute SegmentMaskColKernel
+  int nb = (nnz_csr + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentMaskColKernel,
       nb, nt, 0, thr_entry->stream,
-      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows,
-      cols.Ptr<IdType>(), cols->shape[0],
+      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, nnz_csr,
+      ptr_sorted_cols, cols_size,
       mask.Ptr<IdType>(), count.Ptr<IdType>());
   IdArray idx = AsNumBits(NonZero(mask), nbits);
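The cols sort above follows CUB's usual two-pass protocol: the first SortKeys call, made with a null temporary-storage pointer, only writes the required workspace size; the second call performs the ascending sort into the output buffer. A self-contained sketch of the same pattern outside DGL (plain cudaMalloc instead of the DeviceAPI workspace, error handling omitted):

```cuda
// Two-phase cub::DeviceRadixSort::SortKeys usage (sketch; error checks omitted).
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>

void SortColumnIds(const int64_t* d_keys_in, int64_t* d_keys_out, int num_items) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Phase 1: with d_temp == nullptr, only temp_bytes is written.
  cub::DeviceRadixSort::SortKeys(d_temp, temp_bytes, d_keys_in, d_keys_out, num_items);
  cudaMalloc(&d_temp, temp_bytes);
  // Phase 2: the actual ascending sort into d_keys_out.
  cub::DeviceRadixSort::SortKeys(d_temp, temp_bytes, d_keys_in, d_keys_out, num_items);
  cudaFree(d_temp);
}
```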
@@ -192,6 +192,58 @@ inline DType GetCUDAScalar(
   return result;
 }
 
+/*!
+ * \brief Given a sorted array and a value, returns the index of the first
+ * element which compares greater than the value.
+ *
+ * This function assumes 0-based indexing.
+ * @param A: ascending sorted array
+ * @param n: size of A
+ * @param x: value to search for in A
+ * @return index i of the first element s.t. A[i] > x. If x >= A[n-1], returns n;
+ * if x < A[0], returns 0.
+ */
+template <typename IdType>
+__device__ IdType _UpperBound(const IdType *A, int64_t n, IdType x) {
+  IdType l = 0, r = n, m = 0;
+  while (l < r) {
+    m = l + (r-l)/2;
+    if (x >= A[m]) {
+      l = m+1;
+    } else {
+      r = m;
+    }
+  }
+  return l;
+}
+
+/*!
+ * \brief Given a sorted array and a value, returns an index of an element
+ * equal to the value, or n if no such element exists.
+ *
+ * This function assumes 0-based indexing.
+ * @param A: ascending sorted array
+ * @param n: size of A
+ * @param x: value to search for in A
+ * @return index i s.t. A[i] == x. If no such index exists, returns n.
+ */
+template <typename IdType>
+__device__ IdType _BinarySearch(const IdType *A, int64_t n, IdType x) {
+  IdType l = 0, r = n-1, m = 0;
+  while (l <= r) {
+    m = l + (r-l)/2;
+    if (A[m] == x) {
+      return m;
+    }
+    if (A[m] < x) {
+      l = m+1;
+    } else {
+      r = m-1;
+    }
+  }
+  return n;  // not found
+}
+
 } // namespace cuda
 } // namespace dgl
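For reference, the two new device helpers mirror the standard library's host-side semantics, which gives an easy way to sanity-check them in a unit test. A sketch of that correspondence (an assumed equivalence for illustration, not part of the commit):

```cuda
// Host-side analogues of _UpperBound and _BinarySearch (sketch, for checking semantics).
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> A = {1, 3, 3, 7};
  // _UpperBound(A, n, x): index of the first element greater than x,
  // i.e. std::upper_bound; returns n when x >= A[n-1].
  assert(std::upper_bound(A.begin(), A.end(), int64_t{3}) - A.begin() == 3);
  assert(std::upper_bound(A.begin(), A.end(), int64_t{9}) - A.begin() == 4);
  // _BinarySearch(A, n, x): some index i with A[i] == x, or n if absent;
  // std::binary_search only reports presence, which is all the mask kernel needs.
  assert(std::binary_search(A.begin(), A.end(), int64_t{7}));
  assert(!std::binary_search(A.begin(), A.end(), int64_t{5}));
  return 0;
}
```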