Unverified commit f5f7e08e authored by Muhammed Fatih BALIN, committed by GitHub

[Performance][CUDA] Sorting for indices for UVM code path. (#5882)

parent a64ff482
@@ -14,12 +14,13 @@ namespace impl {
 template <typename DType, typename IdType>
 __global__ void IndexSelectSingleKernel(
     const DType* array, const IdType* index, const int64_t length,
-    const int64_t arr_len, DType* out) {
-  int tx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t arr_len, DType* out, const int64_t* perm = nullptr) {
+  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
   int stride_x = gridDim.x * blockDim.x;
   while (tx < length) {
     assert(index[tx] >= 0 && index[tx] < arr_len);
-    out[tx] = array[index[tx]];
+    const auto out_row = perm ? perm[tx] : tx;
+    out[out_row] = array[index[tx]];
     tx += stride_x;
   }
 }
@@ -27,20 +28,22 @@ __global__ void IndexSelectSingleKernel(
 template <typename DType, typename IdType>
 __global__ void IndexSelectMultiKernel(
     const DType* const array, const int64_t num_feat, const IdType* const index,
-    const int64_t length, const int64_t arr_len, DType* const out) {
-  int64_t out_row = blockIdx.x * blockDim.y + threadIdx.y;
+    const int64_t length, const int64_t arr_len, DType* const out,
+    const int64_t* perm = nullptr) {
+  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;
   const int64_t stride = blockDim.y * gridDim.x;
-  while (out_row < length) {
+  while (out_row_index < length) {
     int64_t col = threadIdx.x;
-    const int64_t in_row = index[out_row];
+    const int64_t in_row = index[out_row_index];
     assert(in_row >= 0 && in_row < arr_len);
+    const auto out_row = perm ? perm[out_row_index] : out_row_index;
     while (col < num_feat) {
       out[out_row * num_feat + col] = array[in_row * num_feat + col];
       col += blockDim.x;
     }
-    out_row += stride;
+    out_row_index += stride;
   }
 }
...
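The new perm argument lets the kernels read their indices in sorted order (better UVM page locality) while still writing each gathered row back to its original position in the output. Below is a minimal host-side sketch of that gather-with-permutation contract, using hypothetical plain vectors rather than DGL arrays:

// Sketch only: shows the invariant the kernels rely on, not DGL code.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> array = {10.f, 11.f, 12.f, 13.f};  // feature table
  std::vector<int64_t> index = {3, 0, 2, 0};            // rows to gather

  // Sort the indices and remember the original position of each sorted entry.
  std::vector<int64_t> perm(index.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(),
                   [&](int64_t a, int64_t b) { return index[a] < index[b]; });
  std::vector<int64_t> sorted(index.size());
  for (size_t i = 0; i < index.size(); ++i) sorted[i] = index[perm[i]];

  // What the kernels do: read 'sorted' sequentially, scatter through 'perm'.
  std::vector<float> out(index.size());
  for (size_t i = 0; i < index.size(); ++i) out[perm[i]] = array[sorted[i]];

  // Equivalent to the plain, unsorted gather out[i] = array[index[i]].
  for (size_t i = 0; i < index.size(); ++i) assert(out[i] == array[index[i]]);
  return 0;
}

Because perm defaults to nullptr, call sites that do not sort their indices keep the old behaviour unchanged.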
@@ -63,32 +63,14 @@ __global__ void _COODecodeEdgesKernel(
   }
 }
 
-template <typename T>
-int _NumberOfBits(const T& range) {
-  if (range <= 1) {
-    // ranges of 0 or 1 require no bits to store
-    return 0;
-  }
-
-  int bits = 1;
-  while (bits < static_cast<int>(sizeof(T) * 8) && (1 << bits) < range) {
-    ++bits;
-  }
-
-  CHECK_EQ((range - 1) >> bits, 0);
-  CHECK_NE((range - 1) >> (bits - 1), 0);
-
-  return bits;
-}
-
 template <DGLDeviceType XPU, typename IdType>
 void COOSort_(COOMatrix* coo, bool sort_column) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
-  const int row_bits = _NumberOfBits(coo->num_rows);
+  const int row_bits = cuda::_NumberOfBits(coo->num_rows);
   const int64_t nnz = coo->row->shape[0];
 
   if (sort_column) {
-    const int col_bits = _NumberOfBits(coo->num_cols);
+    const int col_bits = cuda::_NumberOfBits(coo->num_cols);
     const int num_bits = row_bits + col_bits;
 
     const int nt = 256;
...
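COOSort_ now reuses the shared cuda::_NumberOfBits helper to bound its radix sort to row_bits + col_bits bits, which suggests that rows and columns are packed into a single key (row in the high bits, column in the low bits) so one sort orders the matrix by (row, col); the decode kernel named in the hunk header undoes that packing. A small illustration of such a packing under hypothetical sizes, not the actual DGL encode/decode code:

// Hypothetical example of packing a (row, col) pair into one sort key.
#include <cassert>
#include <cstdint>

int main() {
  // Say num_rows = 6 and num_cols = 10: 3 row bits and 4 column bits.
  const int col_bits = 4;
  const int64_t row = 5, col = 9;
  const int64_t key = (row << col_bits) | col;  // row occupies the high bits
  assert(key == 0x59);
  // Decoding recovers the pair; sorting keys sorts by row, then by column.
  assert((key >> col_bits) == row);
  assert((key & ((int64_t{1} << col_bits) - 1)) == col);
  return 0;
}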
@@ -38,6 +38,24 @@ inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
   return ret;
 }
 
+template <typename T>
+int _NumberOfBits(const T& range) {
+  if (range <= 1) {
+    // ranges of 0 or 1 require no bits to store
+    return 0;
+  }
+
+  int bits = 1;
+  while (bits < static_cast<int>(sizeof(T) * 8) && (1 << bits) < range) {
+    ++bits;
+  }
+
+  CHECK_EQ((range - 1) >> bits, 0);
+  CHECK_NE((range - 1) >> (bits - 1), 0);
+
+  return bits;
+}
+
 /**
  * @brief Find number of blocks is smaller than nblks and max_nblks
  * on the given axis ('x', 'y' or 'z').
...
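_NumberOfBits(range) answers how many bits are needed to store values in [0, range); hoisting it into the shared CUDA utilities lets both COOSort_ and the UVM index-select path cap their radix sorts at that many bits. A few worked values, using a standalone copy of the helper (the CHECK_* sanity checks from the original are omitted here):

// Standalone illustration of the helper's contract; not the DGL source.
#include <cassert>
#include <cstdint>

template <typename T>
int NumberOfBits(const T& range) {
  if (range <= 1) return 0;  // only the value 0 is possible, no bits needed
  int bits = 1;
  while (bits < static_cast<int>(sizeof(T) * 8) && (1 << bits) < range) ++bits;
  return bits;
}

int main() {
  assert(NumberOfBits<int64_t>(1) == 0);      // values in [0, 1): just {0}
  assert(NumberOfBits<int64_t>(2) == 1);      // {0, 1}
  assert(NumberOfBits<int64_t>(5) == 3);      // 0..4 fit in 3 bits
  assert(NumberOfBits<int64_t>(1024) == 10);  // 0..1023 fit in 10 bits
  return 0;
}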
@@ -18,7 +18,6 @@ namespace impl {
 template <typename DType, typename IdType>
 NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
-  const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
   const int64_t len = index->shape[0];
   int64_t num_feat = 1;
@@ -37,12 +36,16 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   if (len == 0 || arr_len * num_feat == 0) return ret;
   DType* ret_data = static_cast<DType*>(ret->data);
 
+  auto res = Sort(index, cuda::_NumberOfBits(arr_len));
+  const IdType* idx_data = static_cast<IdType*>(res.first->data);
+  const int64_t* perm_data = static_cast<int64_t*>(res.second->data);
+
   if (num_feat == 1) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
     CUDA_KERNEL_CALL(
         IndexSelectSingleKernel, nb, nt, 0, stream, array_data, idx_data, len,
-        arr_len, ret_data);
+        arr_len, ret_data, perm_data);
   } else {
     dim3 block(256, 1);
     while (static_cast<int64_t>(block.x) >= 2 * num_feat) {
@@ -53,11 +56,11 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
     if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
       CUDA_KERNEL_CALL(
           IndexSelectMultiKernel, grid, block, 0, stream, array_data, num_feat,
-          idx_data, len, arr_len, ret_data);
+          idx_data, len, arr_len, ret_data, perm_data);
     } else {
       CUDA_KERNEL_CALL(
           IndexSelectMultiKernelAligned, grid, block, 0, stream, array_data,
-          num_feat, idx_data, len, arr_len, ret_data);
+          num_feat, idx_data, len, arr_len, ret_data, perm_data);
     }
   }
   return ret;
...
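The host-side change is what drives the speedup for UVM: the raw index array is radix-sorted on only cuda::_NumberOfBits(arr_len) bits before the launch, so neighbouring threads fetch neighbouring rows of the CPU-resident array and touch far fewer pages, while the permutation returned alongside the sorted keys keeps the output in the caller's original order. DGL's Sort helper is assumed here to wrap a device key-value radix sort; the sketch below writes that step directly against CUB, with illustrative names and memory management rather than the actual DGL internals:

// Sketch: sort (index, position) pairs with CUB, touching only num_bits bits.
#include <cstdint>
#include <cub/cub.cuh>

void SortIndices(const int64_t* d_keys_in, const int64_t* d_vals_in,
                 int64_t* d_keys_out, int64_t* d_vals_out, int64_t n,
                 int num_bits, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call only queries how much temporary storage the sort needs.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_vals_in, d_vals_out, n, 0, num_bits,
                                  stream);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_vals_in, d_vals_out, n, 0, num_bits,
                                  stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_temp);
}

Restricting the sort to _NumberOfBits(arr_len) bits matters because radix-sort cost scales with the number of key bits processed, and indices into an array of length arr_len never use the upper bits of an int64_t key.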
@@ -21,25 +21,27 @@ namespace impl {
 template <typename DType, typename IdType>
 __global__ void IndexSelectMultiKernelAligned(
     const DType* const array, const int64_t num_feat, const IdType* const index,
-    const int64_t length, const int64_t arr_len, DType* const out) {
-  int64_t out_row = blockIdx.x * blockDim.y + threadIdx.y;
+    const int64_t length, const int64_t arr_len, DType* const out,
+    const int64_t* perm = nullptr) {
+  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;
   const int64_t stride = blockDim.y * gridDim.x;
-  while (out_row < length) {
+  while (out_row_index < length) {
     int64_t col = threadIdx.x;
-    const int64_t in_row = index[out_row];
+    const int64_t in_row = index[out_row_index];
     assert(in_row >= 0 && in_row < arr_len);
     const int64_t idx_offset =
         ((uint64_t)(&array[in_row * num_feat]) % CACHE_LINE_SIZE) /
         sizeof(DType);
     col = col - idx_offset;
+    const auto out_row = perm ? perm[out_row_index] : out_row_index;
     while (col < num_feat) {
       if (col >= 0)
         out[out_row * num_feat + col] = array[in_row * num_feat + col];
       col += blockDim.x;
     }
-    out_row += stride;
+    out_row_index += stride;
  }
 }
...
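IndexSelectMultiKernelAligned keeps its existing trick of shifting the starting column by the row's misalignment within a cache line, so feature loads begin on a cache-line boundary and the col >= 0 guard simply skips the negative columns that shift introduces; only the perm scatter is new here. A worked example of the idx_offset arithmetic, assuming a 128-byte line and 4-byte features (the real CACHE_LINE_SIZE constant is defined elsewhere in the source):

// Illustration of the offset computation only; the address is made up.
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t kCacheLine = 128;  // assumed cache-line size in bytes
  const uint64_t kDTypeSize = 4;    // e.g. float features
  // Suppose the selected row starts 48 bytes past a cache-line boundary.
  const uint64_t row_start_addr = 0x10030;  // 0x10030 % 128 == 48
  const int64_t idx_offset =
      (row_start_addr % kCacheLine) / kDTypeSize;  // 48 / 4 == 12 elements
  // Thread 0 starts at col = 0 - 12 = -12 (discarded by the col >= 0 check),
  // so later chunks of blockDim.x loads fall on cache-line boundaries
  // whenever blockDim.x * sizeof(DType) is a multiple of the line size.
  std::printf("idx_offset = %lld\n", static_cast<long long>(idx_offset));
  return 0;
}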