"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b094075cbc8834d63a9fa8ae08bcad3d72a43321"
Unverified Commit aa419895 authored by Ping Gong's avatar Ping Gong Committed by GitHub
Browse files

[Performance] Leverage hashmap to accelerate CSRSliceMatrix<kDGLCUDA, IdType> (#4924)



* Leverage hashmap to accelerate CSRSliceMatrix

* fix lint check

* use `min` in cuda_runtime.h

* fix hash func

* add some comments and adjust the <grid,block> of the _SegmentMaskColKernel kernel

* set device and stream for thrust::for_each

* use thrust::cuda::par_nosync
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent bf264d00
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
* @brief CSR operator CPU implementation * @brief CSR operator CPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <numeric> #include <numeric>
#include <unordered_set> #include <unordered_set>
...@@ -423,26 +425,148 @@ template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>( ...@@ -423,26 +425,148 @@ template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
///////////////////////////// CSRSliceMatrix ///////////////////////////// ///////////////////////////// CSRSliceMatrix /////////////////////////////
/**
 * @brief Return the smallest power of two strictly greater than `numel`.
 *
 * Used by CSRSliceMatrix to size the NodeQueryHashmap buffer: the caller
 * doubles the result, which keeps the hashmap load factor below 0.5.
 *
 * The previous implementation computed `1 << (std::log2(numel) + 1)`, which
 * (a) shifts a 32-bit `int` literal — undefined behavior once the result
 * needs 32 or more bits, (b) depends on floating-point log2 precision for
 * large inputs, and (c) is undefined for numel == 0 (log2(0) == -inf).
 * The integer loop below is exact and overflow-safe for all valid sizes.
 *
 * @param numel Number of elements; non-positive values yield 1.
 * @return Smallest power of two strictly greater than numel.
 */
int64_t _UpPower(int64_t numel) {
  if (numel <= 0) return 1;
  uint64_t ret = 1;
  while (ret <= static_cast<uint64_t>(numel)) {
    ret <<= 1;
  }
  return static_cast<int64_t>(ret);
}
/**
 * @brief Thomas Wang's 32 bit Mix Function.
 * Source link: https://gist.github.com/badboy/6267743
 */
/**
 * @brief Thomas Wang's 32 bit Mix Function.
 * Source link: https://gist.github.com/badboy/6267743
 */
__device__ inline uint32_t _Hash32Shift(uint32_t key) {
  uint32_t h = key;
  h = (h << 15) + ~h;  // h = ~h + (h << 15)
  h ^= h >> 12;
  h += h << 2;
  h ^= h >> 4;
  h *= 2057u;
  h ^= h >> 16;
  return h;
}
/**
* @brief Thomas Wang's 64 bit Mix Function.
* Source link: https://gist.github.com/badboy/6267743
*/
/**
 * @brief Thomas Wang's 64 bit Mix Function.
 * Source link: https://gist.github.com/badboy/6267743
 */
__device__ inline uint64_t _Hash64Shift(uint64_t key) {
  uint64_t h = key;
  h = (h << 21) + ~h;          // h = ~h + (h << 21)
  h ^= h >> 24;
  h += (h << 3) + (h << 8);    // h *= 265
  h ^= h >> 14;
  h += (h << 2) + (h << 4);    // h *= 21
  h ^= h >> 28;
  h += h << 31;
  return h;
}
/**
* @brief A hashmap designed for CSRSliceMatrix, similar in function to set. For
* performance, it can only be created and called in the cuda kernel.
*/ */
template <typename IdType>
struct NodeQueryHashmap {
  /**
   * @brief Construct a view over a pre-allocated device buffer.
   *
   * The buffer must be pre-filled with kEmptyKey_ (-1) by the caller, and
   * `numel` must be a power of two — Hash() masks with (capacity_ - 1).
   * NOTE(review): capacity_ is uint32_t, so numel must fit in 32 bits;
   * the CSRSliceMatrix caller sizes the buffer from num_cols, which is
   * assumed to satisfy this — confirm if reused elsewhere.
   *
   * @param Kptr Pointer to the key buffer.
   * @param numel Number of slots in the buffer (power of two).
   */
  __device__ inline NodeQueryHashmap(IdType* Kptr, size_t numel)
      : kptr_(Kptr), capacity_(numel) {}

  /**
   * @brief Insert a key. It must be called by cuda threads.
   *
   * Open addressing: on collision, the next slot is obtained by re-hashing
   * (pos + delta) with an increasing delta. Query() follows the identical
   * probe sequence, so any inserted key is findable. The key -1 is reserved
   * as the empty sentinel and must not be inserted.
   *
   * @param key The key to be inserted.
   */
  __device__ inline void Insert(IdType key) {
    uint32_t delta = 1;
    uint32_t pos = Hash(key);
    // AtomicCAS claims an empty slot; prev == key means another thread
    // already inserted the same key, which is fine for set semantics.
    IdType prev = dgl::aten::cuda::AtomicCAS(&kptr_[pos], kEmptyKey_, key);
    while (prev != key && prev != kEmptyKey_) {
      pos = Hash(pos + delta);
      delta += 1;
      prev = dgl::aten::cuda::AtomicCAS(&kptr_[pos], kEmptyKey_, key);
    }
  }

  /**
   * @brief Check whether a key exists within the hashtable. It must be called
   * by cuda threads, and only after all Insert() calls have completed.
   *
   * @param key The key to check for.
   * @return True if the key exists in the hashtable.
   */
  __device__ inline bool Query(IdType key) {
    uint32_t delta = 1;
    uint32_t pos = Hash(key);
    while (true) {
      // Load each probed slot once (the previous version read global
      // memory twice per probe).
      const IdType v = kptr_[pos];
      if (v == key) return true;
      if (v == kEmptyKey_) return false;
      pos = Hash(pos + delta);
      delta += 1;
    }
  }

  // Hash overloads: 32-bit keys use the 32-bit mix, 64-bit keys use the
  // 64-bit mix; the result is folded into [0, capacity_) via the
  // power-of-two mask.
  __device__ inline uint32_t Hash(int32_t key) {
    return _Hash32Shift(key) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(uint32_t key) {
    return _Hash32Shift(key) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(int64_t key) {
    return static_cast<uint32_t>(_Hash64Shift(key)) & (capacity_ - 1);
  }

  __device__ inline uint32_t Hash(uint64_t key) {
    return static_cast<uint32_t>(_Hash64Shift(key)) & (capacity_ - 1);
  }

  // Sentinel marking an empty slot; keys equal to -1 are not supported.
  IdType kEmptyKey_{-1};
  // Key buffer (not owned).
  IdType* kptr_;
  // Number of slots; must be a power of two.
  uint32_t capacity_{0};
};
/**
* @brief Generate a 0-1 mask for each index whose column is in the provided
* hashmap. It also counts the number of masked values per row.
*
* @tparam IdType The ID type used for matrices.
* @tparam WARP_SIZE The number of cuda threads in a cuda warp.
* @tparam BLOCK_WARPS The number of warps in a cuda block.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
*/
template <typename IdType, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void _SegmentMaskColKernel(
    const IdType* indptr, const IdType* indices, int64_t num_rows,
    IdType* hashmap_buffer, int64_t buffer_size, IdType* mask, IdType* count) {
  // Launch contract: blockDim == (WARP_SIZE, BLOCK_WARPS); each block
  // covers TILE_SIZE rows, one warp per row, warps striding by BLOCK_WARPS.
  assert(blockDim.x == WARP_SIZE);
  assert(blockDim.y == BLOCK_WARPS);

  const int warp_id = threadIdx.y;
  const int laneid = threadIdx.x;
  IdType out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  const IdType last_row =
      min(static_cast<IdType>((blockIdx.x + 1) * TILE_SIZE),
          static_cast<IdType>(num_rows));

  NodeQueryHashmap<IdType> hashmap(hashmap_buffer, buffer_size);
  typedef cub::WarpReduce<IdType> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage[BLOCK_WARPS];

  while (out_row < last_row) {
    IdType local_count = 0;
    const IdType in_row_start = indptr[out_row];
    const IdType in_row_end = indptr[out_row + 1];
    // Index must be IdType: an `int` index truncates for matrices with
    // more than INT_MAX nonzeros when IdType is 64-bit.
    for (IdType idx = in_row_start + laneid; idx < in_row_end;
         idx += WARP_SIZE) {
      if (hashmap.Query(indices[idx])) {
        local_count += 1;
        mask[idx] = 1;
      }
    }
    // All lanes of the warp reach this point together (the row loop bound
    // is uniform per warp), as WarpReduce requires.
    IdType reduce_count = WarpReduce(temp_storage[warp_id]).Sum(local_count);
    if (laneid == 0) {
      count[out_row] = reduce_count;
    }
    out_row += BLOCK_WARPS;
  }
}
...@@ -476,26 +600,23 @@ CSRMatrix CSRSliceMatrix( ...@@ -476,26 +600,23 @@ CSRMatrix CSRSliceMatrix(
CUDA_CALL( CUDA_CALL(
cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows))); cudaMemset(count.Ptr<IdType>(), 0, sizeof(IdType) * (csr.num_rows)));
const int64_t nnz_csr = csr.indices->shape[0]; // Generate a NodeQueryHashmap buffer. The key of the hashmap is col.
const int nt = 256; // For performance, the load factor of the hashmap is in (0.25, 0.5);
// Because num_cols is usually less than 1 Million (on GPU), the
// In general ``cols'' array is sorted. But it is not guaranteed. // memory overhead is not significant (less than 31MB) at a low load factor.
// Hence checking and sorting array first. Sorting is not in place. int64_t buffer_size = _UpPower(new_ncols) * 2;
auto device = runtime::DeviceAPI::Get(ctx); IdArray hashmap_buffer = Full(-1, buffer_size, nbits, ctx);
auto cols_size = cols->shape[0];
using it = thrust::counting_iterator<int64_t>;
IdArray sorted_array = NewIdArray(cols->shape[0], ctx, cols->dtype.bits); runtime::CUDAWorkspaceAllocator allocator(ctx);
auto ptr_sorted_cols = sorted_array.Ptr<IdType>(); const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
auto ptr_cols = cols.Ptr<IdType>(); thrust::for_each(
size_t workspace_size = 0; exec_policy, it(0), it(new_ncols),
CUDA_CALL(cub::DeviceRadixSort::SortKeys( [key = cols.Ptr<IdType>(), buffer = hashmap_buffer.Ptr<IdType>(),
nullptr, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], 0, buffer_size] __device__(int64_t i) {
sizeof(IdType) * 8, stream)); NodeQueryHashmap<IdType> hashmap(buffer, buffer_size);
void* workspace = device->AllocWorkspace(ctx, workspace_size); hashmap.Insert(key[i]);
CUDA_CALL(cub::DeviceRadixSort::SortKeys( });
workspace, workspace_size, ptr_cols, ptr_sorted_cols, cols->shape[0], 0,
sizeof(IdType) * 8, stream));
device->FreeWorkspace(ctx, workspace);
const IdType* indptr_data = csr.indptr.Ptr<IdType>(); const IdType* indptr_data = csr.indptr.Ptr<IdType>();
const IdType* indices_data = csr.indices.Ptr<IdType>(); const IdType* indices_data = csr.indices.Ptr<IdType>();
...@@ -507,10 +628,19 @@ CSRMatrix CSRSliceMatrix( ...@@ -507,10 +628,19 @@ CSRMatrix CSRSliceMatrix(
} }
// Execute SegmentMaskColKernel // Execute SegmentMaskColKernel
int nb = (nnz_csr + nt - 1) / nt; const int64_t num_rows = csr.num_rows;
constexpr int WARP_SIZE = 32;
// With a simple fine-tuning, TILE_SIZE=16 gives a good performance.
constexpr int TILE_SIZE = 16;
constexpr int BLOCK_WARPS = CUDA_MAX_NUM_THREADS / WARP_SIZE;
IdType nb =
dgl::cuda::FindNumBlocks<'x'>((num_rows + TILE_SIZE - 1) / TILE_SIZE);
const dim3 nthrs(WARP_SIZE, BLOCK_WARPS);
const dim3 nblks(nb);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_SegmentMaskColKernel, nb, nt, 0, stream, indptr_data, indices_data, (_SegmentMaskColKernel<IdType, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>), nblks,
csr.num_rows, nnz_csr, ptr_sorted_cols, cols_size, mask.Ptr<IdType>(), nthrs, 0, stream, indptr_data, indices_data, num_rows,
hashmap_buffer.Ptr<IdType>(), buffer_size, mask.Ptr<IdType>(),
count.Ptr<IdType>()); count.Ptr<IdType>());
IdArray idx = AsNumBits(NonZero(mask), nbits); IdArray idx = AsNumBits(NonZero(mask), nbits);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment