Unverified commit 83115794, authored by Muhammed Fatih BALIN, committed by GitHub
Browse files

[Performance][CUDA] Faster CSRToCOO (#5648)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent dc06060b
......@@ -50,6 +50,7 @@ if(USE_CUDA)
message(STATUS "Use external CUB/Thrust library for a consistent API and performance.")
cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/thrust")
cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/thrust/dependencies/cub")
cuda_include_directories(BEFORE "${CMAKE_SOURCE_DIR}/third_party/thrust/dependencies/libcudacxx/include")
endif(USE_CUDA)
# initial variables
......
......@@ -4,8 +4,12 @@
* @brief CSR2COO
*/
#include <dgl/array.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl {
......@@ -45,33 +49,27 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
csr.num_rows, csr.num_cols, row, indices, data, true, csr.sorted);
}
/**
* @brief Repeat elements
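* @param val Values to repeat.
* @param pos For each value, the position in the output buffer where its
*        run starts. The per-value repeat counts (`repeats` in the example
*        below) are not passed in; they are implied by consecutive entries
*        of pos.
* @param out Output buffer.
* @param n_row Number of entries of pos that are searched (one per value).
* @param length Number of elements to write to the output buffer.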
*
* For example:
* val = [3, 0, 1]
* repeats = [1, 0, 2]
* pos = [0, 1, 1] # write to output buffer position 0, 1, 1
* then,
* out = [3, 1, 1]
*/
template <typename DType, typename IdType>
__global__ void _RepeatKernel(
const DType* val, const IdType* pos, DType* out, int64_t n_row,
int64_t length) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
IdType i = dgl::cuda::_UpperBound(pos, n_row, tx) - 1;
out[tx] = val[i];
tx += stride_x;
struct RepeatIndex {
template <typename IdType>
__host__ __device__ auto operator()(IdType i) {
return thrust::make_constant_iterator(i);
}
}
};
/**
 * @brief Functor that maps an index to a destination pointer inside an
 *        output buffer.
 *
 * For index i the call operator returns `buffer + indptr[i]`, i.e. the
 * address in `buffer` that lies `indptr[i]` elements past its start.
 *
 * @tparam IdType Integer type of the offsets and of the buffer elements.
 */
template <typename IdType>
struct OutputBufferIndexer {
  // Offsets into `buffer`, read at position i.
  // NOTE(review): presumably a CSR row-pointer array — confirm at call sites.
  const IdType* indptr;
  // Base address of the buffer the returned pointers point into.
  IdType* buffer;

  /** @brief Return the address of element `indptr[i]` of `buffer`. */
  __host__ __device__ auto operator()(IdType i) {
    IdType* destination = buffer + indptr[i];
    return destination;
  }
};
/**
 * @brief Functor that returns the difference between two neighbouring
 *        entries of an offset array.
 *
 * For index i the call operator returns `indptr[i + 1] - indptr[i]`, so
 * `indptr` must have a readable entry at position i + 1.
 *
 * @tparam IdType Integer type of the offsets.
 */
template <typename IdType>
struct AdjacentDifference {
  // Offset array; entries i and i + 1 are read for index i.
  const IdType* indptr;

  /** @brief Return `indptr[i + 1] - indptr[i]`. */
  __host__ __device__ auto operator()(IdType i) {
    const IdType row_begin = indptr[i];
    const IdType row_end = indptr[i + 1];
    return row_end - row_begin;
  }
};
template <>
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
......@@ -80,14 +78,33 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const int64_t nnz = csr.indices->shape[0];
const auto nbits = csr.indptr->dtype.bits;
IdArray rowids = Range(0, csr.num_rows, nbits, ctx);
IdArray ret_row = NewIdArray(nnz, ctx, nbits);
const int nt = 256;
const int nb = (nnz + nt - 1) / nt;
CUDA_KERNEL_CALL(
_RepeatKernel, nb, nt, 0, stream, rowids.Ptr<int64_t>(),
csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(), csr.num_rows, nnz);
runtime::CUDAWorkspaceAllocator allocator(csr.indptr->ctx);
thrust::counting_iterator<int64_t> iota(0);
auto input_buffer = thrust::make_transform_iterator(iota, RepeatIndex{});
auto output_buffer = thrust::make_transform_iterator(
iota, OutputBufferIndexer<int64_t>{
csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>()});
auto buffer_sizes = thrust::make_transform_iterator(
iota, AdjacentDifference<int64_t>{csr.indptr.Ptr<int64_t>()});
constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();
for (int64_t i = 0; i < csr.num_rows; i += max_copy_at_once) {
std::size_t temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceCopy::Batched(
nullptr, temp_storage_bytes, input_buffer + i, output_buffer + i,
buffer_sizes + i, std::min(csr.num_rows - i, max_copy_at_once),
stream));
auto temp = allocator.alloc_unique<char>(temp_storage_bytes);
CUDA_CALL(cub::DeviceCopy::Batched(
temp.get(), temp_storage_bytes, input_buffer + i, output_buffer + i,
buffer_sizes + i, std::min(csr.num_rows - i, max_copy_at_once),
stream));
}
return COOMatrix(
csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,
......
Subproject commit 6a3078c64cab0e2f276340fa5dcafa0d758ed890
Subproject commit 02931a309bee769853088b79b4e3ab1c0bd2336c
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment