Unverified Commit a28f1f9a authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Graphbolt][CUDA] Migrate utils and cuda common from dgl (#6631)

parent 8df64670
/**
* Copyright (c) 2017-2023 by Contributors
* @file cuda/common.h
* @brief Common utilities for CUDA
*/
#ifndef GRAPHBOLT_CUDA_COMMON_H_
#define GRAPHBOLT_CUDA_COMMON_H_
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
namespace graphbolt {
namespace cuda {
/**
 * @brief Returns true iff the given extent equals zero.
 *
 * Primary template for scalar (integral) extents; specialized below for dim3
 * so launch configurations can be checked uniformly.
 */
template <typename T>
inline bool is_zero(T extent) {
  return !(extent != 0);
}
/**
 * @brief dim3 specialization: a launch extent is considered zero when any of
 * its three dimensions is zero (the grid/block then covers no threads).
 */
template <>
inline bool is_zero<dim3>(dim3 extent) {
  // dim3 components are unsigned, so !v is equivalent to v == 0.
  return !(extent.x && extent.y && extent.z);
}
/**
 * @brief Launches @p kernel on @p stream with the given grid/block/shared-mem
 * configuration and forwards the remaining arguments to it. The launch is
 * skipped entirely when any grid or block dimension is zero (there is no work
 * to do, and an empty configuration would otherwise surface as a launch
 * error). After a launch, C10_CUDA_KERNEL_LAUNCH_CHECK() is invoked to catch
 * configuration/launch failures.
 *
 * Wrapped in do { } while (0) so the macro expands to a single statement and
 * is safe to use with a trailing semicolon inside unbraced if/else chains —
 * the original bare { } block made `if (c) CUDA_KERNEL_CALL(...); else ...`
 * ill-formed.
 */
#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)    \
  do {                                                                \
    if (!graphbolt::cuda::is_zero((nblks)) &&                         \
        !graphbolt::cuda::is_zero((nthrs))) {                         \
      (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
      C10_CUDA_KERNEL_LAUNCH_CHECK();                                 \
    }                                                                 \
  } while (0)
} // namespace cuda
} // namespace graphbolt
#endif // GRAPHBOLT_CUDA_COMMON_H_
...@@ -3,12 +3,13 @@ ...@@ -3,12 +3,13 @@
* @file cuda/index_select_impl.cu * @file cuda/index_select_impl.cu
* @brief Index select operator implementation on CUDA. * @brief Index select operator implementation on CUDA.
*/ */
#include <c10/cuda/CUDAException.h> #include <c10/cuda/CUDAStream.h>
#include <torch/script.h> #include <torch/script.h>
#include <numeric> #include <numeric>
#include "../index_select.h" #include "../index_select.h"
#include "./common.h"
#include "./utils.h" #include "./utils.h"
namespace graphbolt { namespace graphbolt {
...@@ -115,15 +116,15 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) { ...@@ -115,15 +116,15 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>(); const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
const int64_t* permutation_ptr = permutation.data_ptr<int64_t>(); const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();
cudaStream_t stream = 0; cudaStream_t stream = torch::cuda::getDefaultCUDAStream();
if (feature_size == 1) { if (feature_size == 1) {
// Use a single thread to process each output row to avoid wasting threads. // Use a single thread to process each output row to avoid wasting threads.
const int num_threads = cuda::FindNumThreads(return_len); const int num_threads = cuda::FindNumThreads(return_len);
const int num_blocks = (return_len + num_threads - 1) / num_threads; const int num_blocks = (return_len + num_threads - 1) / num_threads;
IndexSelectSingleKernel<<<num_blocks, num_threads, 0, stream>>>( CUDA_KERNEL_CALL(
input_ptr, input_len, index_sorted_ptr, return_len, ret_ptr, IndexSelectSingleKernel, num_blocks, num_threads, 0, stream, input_ptr,
permutation_ptr); input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
} else { } else {
dim3 block(512, 1); dim3 block(512, 1);
while (static_cast<int64_t>(block.x) >= 2 * feature_size) { while (static_cast<int64_t>(block.x) >= 2 * feature_size) {
...@@ -134,17 +135,17 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) { ...@@ -134,17 +135,17 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
if (feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) { if (feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
// When feature size is smaller than GPU cache line size, use unaligned // When feature size is smaller than GPU cache line size, use unaligned
// version for less SM usage, which is more resource efficient. // version for less SM usage, which is more resource efficient.
IndexSelectMultiKernel<<<grid, block, 0, stream>>>( CUDA_KERNEL_CALL(
input_ptr, input_len, feature_size, index_sorted_ptr, return_len, IndexSelectMultiKernel, grid, block, 0, stream, input_ptr, input_len,
ret_ptr, permutation_ptr); feature_size, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
} else { } else {
// Use aligned version to improve the memory access pattern. // Use aligned version to improve the memory access pattern.
IndexSelectMultiKernelAligned<<<grid, block, 0, stream>>>( CUDA_KERNEL_CALL(
input_ptr, input_len, feature_size, index_sorted_ptr, return_len, IndexSelectMultiKernelAligned, grid, block, 0, stream, input_ptr,
ret_ptr, permutation_ptr); input_len, feature_size, index_sorted_ptr, return_len, ret_ptr,
permutation_ptr);
} }
} }
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto return_shape = std::vector<int64_t>({return_len}); auto return_shape = std::vector<int64_t>({return_len});
return_shape.insert( return_shape.insert(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment