Unverified Commit 3795a006 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Refactor codebase with `CUB_CALL` macro (#6870)

parent f86212ed
@@ -10,6 +10,7 @@
 #include <ATen/cuda/CUDAEvent.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAStream.h>
 #include <cuda_runtime.h>
 #include <torch/script.h>
@@ -82,15 +83,34 @@ inline bool is_zero<dim3>(dim3 size) {
 #define CUDA_CALL(func) C10_CUDA_CHECK((func))
 
-#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...)      \
+#define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, ...)              \
   {                                                                     \
     if (!graphbolt::cuda::is_zero((nblks)) &&                           \
         !graphbolt::cuda::is_zero((nthrs))) {                           \
-      (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__);   \
+      auto stream = graphbolt::cuda::GetCurrentStream();                \
+      (kernel)<<<(nblks), (nthrs), (shmem), stream>>>(__VA_ARGS__);     \
       C10_CUDA_KERNEL_LAUNCH_CHECK();                                   \
     }                                                                   \
   }
 
+#define CUB_CALL(fn, ...)                                                     \
+  {                                                                           \
+    auto allocator = graphbolt::cuda::GetAllocator();                         \
+    auto stream = graphbolt::cuda::GetCurrentStream();                        \
+    size_t workspace_size = 0;                                                \
+    CUDA_CALL(cub::fn(nullptr, workspace_size, __VA_ARGS__, stream));         \
+    auto workspace = allocator.AllocateStorage<char>(workspace_size);         \
+    CUDA_CALL(cub::fn(workspace.get(), workspace_size, __VA_ARGS__, stream)); \
+  }
+
+#define THRUST_CALL(fn, ...)                                                  \
+  [&] {                                                                       \
+    auto allocator = graphbolt::cuda::GetAllocator();                         \
+    auto stream = graphbolt::cuda::GetCurrentStream();                        \
+    const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);  \
+    return thrust::fn(exec_policy, __VA_ARGS__);                              \
+  }()
+
 /**
  * @brief This class is designed to handle the copy operation of a single
  * scalar_t item from a given CUDA device pointer. Later, if the object is cast
...
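The new `CUB_CALL` macro folds the usual two-phase CUB workspace pattern into a single call site. For context, a minimal, self-contained sketch of that pattern is shown below; it uses plain `cudaMalloc` for the temporary storage, whereas the macro instead draws the workspace from GraphBolt's caching allocator and runs on the current PyTorch CUDA stream.

```cpp
// Illustrative sketch only (not part of the commit) of the two-phase CUB
// workspace pattern that CUB_CALL wraps.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

cudaError_t ExclusiveSumExample(
    const int64_t* d_in, int64_t* d_out, int num_items, cudaStream_t stream) {
  // Phase 1: query the required temporary storage size.
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
  // Phase 2: allocate the workspace and run the actual scan.
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  auto status = cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
  cudaFree(d_temp_storage);
  return status;
  // With the macro above, the same call site shrinks to:
  //   CUB_CALL(DeviceScan::ExclusiveSum, d_in, d_out, num_items);
}
```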
@@ -39,8 +39,6 @@ struct AdjacentDifference {
 };
 
 torch::Tensor CSRToCOO(torch::Tensor indptr, torch::ScalarType output_dtype) {
-  auto allocator = cuda::GetAllocator();
-  auto stream = cuda::GetCurrentStream();
   const auto num_rows = indptr.size(0) - 1;
   thrust::counting_iterator<int64_t> iota(0);
@@ -69,19 +67,9 @@ torch::Tensor CSRToCOO(torch::Tensor indptr, torch::ScalarType output_dtype) {
         constexpr int64_t max_copy_at_once =
             std::numeric_limits<int32_t>::max();
         for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
-          std::size_t tmp_storage_size = 0;
-          CUDA_CALL(cub::DeviceCopy::Batched(
-              nullptr, tmp_storage_size, input_buffer + i,
-              output_buffer + i, buffer_sizes + i,
-              std::min(num_rows - i, max_copy_at_once), stream));
-          auto tmp_storage =
-              allocator.AllocateStorage<char>(tmp_storage_size);
-          CUDA_CALL(cub::DeviceCopy::Batched(
-              tmp_storage.get(), tmp_storage_size, input_buffer + i,
-              output_buffer + i, buffer_sizes + i,
-              std::min(num_rows - i, max_copy_at_once), stream));
+          CUB_CALL(
+              DeviceCopy::Batched, input_buffer + i, output_buffer + i,
+              buffer_sizes + i, std::min(num_rows - i, max_copy_at_once));
         }
       }));
   return csr_rows;
...
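A minimal CPU sketch of what this file computes, assuming the standard CSR-to-COO row-expansion semantics implied by the batched copies above (the function name below is hypothetical, for illustration only):

```cpp
// CPU reference sketch: each row id r is repeated indptr[r + 1] - indptr[r]
// times. On the GPU this is expressed as one batched copy per row.
#include <cstdint>
#include <vector>

std::vector<int64_t> CSRToCOOReference(const std::vector<int64_t>& indptr) {
  std::vector<int64_t> coo_rows;
  coo_rows.reserve(indptr.back());
  for (int64_t r = 0; r + 1 < static_cast<int64_t>(indptr.size()); ++r) {
    for (int64_t j = indptr[r]; j < indptr[r + 1]; ++j) coo_rows.push_back(r);
  }
  return coo_rows;  // e.g. indptr {0, 2, 2, 5} -> {0, 0, 2, 2, 2}
}
```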
@@ -12,21 +12,14 @@ namespace graphbolt {
 namespace ops {
 
 torch::Tensor ExclusiveCumSum(torch::Tensor input) {
-  auto allocator = cuda::GetAllocator();
-  auto stream = cuda::GetCurrentStream();
   auto result = torch::empty_like(input);
-  AT_DISPATCH_INTEGRAL_TYPES(
-      input.scalar_type(), "ExclusiveCumSum", ([&] {
-        size_t tmp_storage_size = 0;
-        cub::DeviceScan::ExclusiveSum(
-            nullptr, tmp_storage_size, input.data_ptr<scalar_t>(),
-            result.data_ptr<scalar_t>(), input.size(0), stream);
-        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-        cub::DeviceScan::ExclusiveSum(
-            tmp_storage.get(), tmp_storage_size, input.data_ptr<scalar_t>(),
-            result.data_ptr<scalar_t>(), input.size(0), stream);
-      }));
+  AT_DISPATCH_INTEGRAL_TYPES(input.scalar_type(), "ExclusiveCumSum", ([&] {
+                               CUB_CALL(
+                                   DeviceScan::ExclusiveSum,
+                                   input.data_ptr<scalar_t>(),
+                                   result.data_ptr<scalar_t>(), input.size(0));
+                             }));
   return result;
 }
...
@@ -5,11 +5,10 @@
  * @brief Index select csc operator implementation on CUDA.
  */
 #include <c10/core/ScalarType.h>
-#include <c10/cuda/CUDAStream.h>
 #include <graphbolt/cuda_ops.h>
-#include <thrust/execution_policy.h>
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/iterator/transform_iterator.h>
+#include <thrust/iterator/zip_iterator.h>
 
 #include <cub/cub.cuh>
 #include <numeric>
@@ -88,7 +87,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
     torch::Tensor indices, const int64_t num_nodes,
     const indptr_t* const in_degree, const indptr_t* const sliced_indptr,
     const int64_t* const perm, torch::TensorOptions nodes_options,
-    torch::ScalarType indptr_scalar_type, cudaStream_t stream) {
+    torch::ScalarType indptr_scalar_type) {
   auto allocator = cuda::GetAllocator();
   thrust::counting_iterator<int64_t> iota(0);
@@ -109,14 +108,9 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
         output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get());
     thrust::tuple<indptr_t, indptr_t> zero_value{};
     // Compute the prefix sum over actual and modified indegrees.
-    size_t tmp_storage_size = 0;
-    CUDA_CALL(cub::DeviceScan::ExclusiveScan(
-        nullptr, tmp_storage_size, modified_in_degree, output_indptr_pair,
-        PairSum{}, zero_value, num_nodes + 1, stream));
-    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-    CUDA_CALL(cub::DeviceScan::ExclusiveScan(
-        tmp_storage.get(), tmp_storage_size, modified_in_degree,
-        output_indptr_pair, PairSum{}, zero_value, num_nodes + 1, stream));
+    CUB_CALL(
+        DeviceScan::ExclusiveScan, modified_in_degree, output_indptr_pair,
+        PairSum{}, zero_value, num_nodes + 1);
   }
   // Copy the actual total number of edges.
@@ -138,7 +132,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
   // Perform the actual copying, of the indices array into
   // output_indices in an aligned manner.
   CUDA_KERNEL_CALL(
-      _CopyIndicesAlignedKernel, grid, block, 0, stream,
+      _CopyIndicesAlignedKernel, grid, block, 0,
       static_cast<indptr_t>(edge_count_aligned), num_nodes, sliced_indptr,
       output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get(),
      reinterpret_cast<indices_t*>(indices.data_ptr()),
@@ -151,7 +145,6 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
   // Sorting nodes so that accesses over PCI-e are more regular.
   const auto sorted_idx =
       Sort(nodes, cuda::NumberOfBits(indptr.size(0) - 1)).second;
-  auto stream = cuda::GetCurrentStream();
   const int64_t num_nodes = nodes.size(0);
   auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
@@ -167,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
             return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
                 indices, num_nodes, in_degree, sliced_indptr,
                 sorted_idx.data_ptr<int64_t>(), nodes.options(),
-                indptr.scalar_type(), stream);
+                indptr.scalar_type());
           }));
       }));
 }
@@ -191,9 +184,7 @@ template <typename indptr_t, typename indices_t>
 void IndexSelectCSCCopyIndices(
     const int64_t num_nodes, indices_t* const indices,
     indptr_t* const sliced_indptr, const indptr_t* const in_degree,
-    indptr_t* const output_indptr, indices_t* const output_indices,
-    cudaStream_t stream) {
-  auto allocator = cuda::GetAllocator();
+    indptr_t* const output_indptr, indices_t* const output_indices) {
   thrust::counting_iterator<int64_t> iota(0);
   auto input_buffer_it = thrust::make_transform_iterator(
@@ -206,21 +197,14 @@ void IndexSelectCSCCopyIndices(
   // Performs the copy from indices into output_indices.
   for (int64_t i = 0; i < num_nodes; i += max_copy_at_once) {
-    size_t tmp_storage_size = 0;
-    CUDA_CALL(cub::DeviceMemcpy::Batched(
-        nullptr, tmp_storage_size, input_buffer_it + i, output_buffer_it + i,
-        buffer_sizes + i, std::min(num_nodes - i, max_copy_at_once), stream));
-    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-    CUDA_CALL(cub::DeviceMemcpy::Batched(
-        tmp_storage.get(), tmp_storage_size, input_buffer_it + i,
-        output_buffer_it + i, buffer_sizes + i,
-        std::min(num_nodes - i, max_copy_at_once), stream));
+    CUB_CALL(
+        DeviceMemcpy::Batched, input_buffer_it + i, output_buffer_it + i,
+        buffer_sizes + i, std::min(num_nodes - i, max_copy_at_once));
   }
 }
 
 std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
     torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
-  auto stream = cuda::GetCurrentStream();
   const int64_t num_nodes = nodes.size(0);
   auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
   return AT_DISPATCH_INTEGRAL_TYPES(
@@ -234,17 +218,10 @@ std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
         torch::Tensor output_indptr = torch::empty(
             num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
-        {  // Compute the output indptr, output_indptr.
-          size_t tmp_storage_size = 0;
-          CUDA_CALL(cub::DeviceScan::ExclusiveSum(
-              nullptr, tmp_storage_size, in_degree,
-              output_indptr.data_ptr<indptr_t>(), num_nodes + 1, stream));
-          auto allocator = cuda::GetAllocator();
-          auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-          CUDA_CALL(cub::DeviceScan::ExclusiveSum(
-              tmp_storage.get(), tmp_storage_size, in_degree,
-              output_indptr.data_ptr<indptr_t>(), num_nodes + 1, stream));
-        }
+        // Compute the output indptr, output_indptr.
+        CUB_CALL(
+            DeviceScan::ExclusiveSum, in_degree,
+            output_indptr.data_ptr<indptr_t>(), num_nodes + 1);
         // Number of edges being copied.
         auto edge_count =
@@ -259,8 +236,7 @@ std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
           IndexSelectCSCCopyIndices<indptr_t, indices_t>(
               num_nodes, reinterpret_cast<indices_t*>(indices.data_ptr()),
              sliced_indptr, in_degree, output_indptr.data_ptr<indptr_t>(),
-              reinterpret_cast<indices_t*>(output_indices.data_ptr()),
-              stream);
+              reinterpret_cast<indices_t*>(output_indices.data_ptr()));
         }));
         return std::make_tuple(output_indptr, output_indices);
       }));
...
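For orientation, a CPU sketch of the CSC index-select contract this file implements: the output indptr is the exclusive prefix sum of the selected in-degrees, and each selected node's indices are copied as one batch (function name below is hypothetical, illustration only):

```cpp
// CPU reference sketch of IndexSelectCSC on a CSC matrix (indptr, indices).
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

std::pair<std::vector<int64_t>, std::vector<int64_t>> IndexSelectCSCReference(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<int64_t>& nodes) {
  std::vector<int64_t> out_indptr(nodes.size() + 1, 0);
  for (size_t i = 0; i < nodes.size(); ++i) {
    const auto in_degree = indptr[nodes[i] + 1] - indptr[nodes[i]];
    out_indptr[i + 1] = out_indptr[i] + in_degree;  // exclusive prefix sum
  }
  std::vector<int64_t> out_indices(out_indptr.back());
  for (size_t i = 0; i < nodes.size(); ++i) {  // one batched copy per node
    std::copy(
        indices.begin() + indptr[nodes[i]],
        indices.begin() + indptr[nodes[i] + 1],
        out_indices.begin() + out_indptr[i]);
  }
  return {std::move(out_indptr), std::move(out_indices)};
}
```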
@@ -5,13 +5,8 @@
  * @brief Index select operator implementation on CUDA.
  */
 #include <c10/core/ScalarType.h>
-#include <c10/cuda/CUDAStream.h>
 #include <graphbolt/cuda_ops.h>
-#include <thrust/execution_policy.h>
-#include <thrust/iterator/counting_iterator.h>
-#include <thrust/iterator/transform_iterator.h>
 
-#include <cub/cub.cuh>
 #include <numeric>
 
 #include "./common.h"
@@ -124,14 +119,12 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
   const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();
-  auto stream = cuda::GetCurrentStream();
 
   if (aligned_feature_size == 1) {
     // Use a single thread to process each output row to avoid wasting threads.
     const int num_threads = cuda::FindNumThreads(return_len);
     const int num_blocks = (return_len + num_threads - 1) / num_threads;
     CUDA_KERNEL_CALL(
-        IndexSelectSingleKernel, num_blocks, num_threads, 0, stream, input_ptr,
+        IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,
         input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
   } else {
     dim3 block(512, 1);
@@ -144,15 +137,15 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
       // When feature size is smaller than GPU cache line size, use unaligned
       // version for less SM usage, which is more resource efficient.
       CUDA_KERNEL_CALL(
-          IndexSelectMultiKernel, grid, block, 0, stream, input_ptr, input_len,
-          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,
-          permutation_ptr);
+          IndexSelectMultiKernel, grid, block, 0, input_ptr, input_len,
+          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,
+          permutation_ptr);
     } else {
       // Use aligned version to improve the memory access pattern.
       CUDA_KERNEL_CALL(
-          IndexSelectMultiKernelAligned, grid, block, 0, stream, input_ptr,
-          input_len, aligned_feature_size, index_sorted_ptr, return_len,
-          ret_ptr, permutation_ptr);
+          IndexSelectMultiKernelAligned, grid, block, 0, input_ptr, input_len,
+          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,
+          permutation_ptr);
     }
   }
...
@@ -8,8 +8,6 @@
 #include <graphbolt/cuda_ops.h>
 #include <graphbolt/cuda_sampling_ops.h>
 
-#include <cub/cub.cuh>
-
 #include "./common.h"
 
 namespace graphbolt {
...
@@ -7,8 +7,6 @@
 #include <graphbolt/cuda_ops.h>
 #include <thrust/binary_search.h>
 
-#include <cub/cub.cuh>
-
 #include "./common.h"
 
 namespace graphbolt {
@@ -16,15 +14,12 @@ namespace ops {
 
 torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
   auto sorted_test_elements = Sort<false>(test_elements);
-  auto allocator = cuda::GetAllocator();
-  auto stream = cuda::GetCurrentStream();
-  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
   auto result = torch::empty_like(elements, torch::kBool);
   AT_DISPATCH_INTEGRAL_TYPES(
       elements.scalar_type(), "IsInOperation", ([&] {
-        thrust::binary_search(
-            exec_policy, sorted_test_elements.data_ptr<scalar_t>(),
+        THRUST_CALL(
+            binary_search, sorted_test_elements.data_ptr<scalar_t>(),
             sorted_test_elements.data_ptr<scalar_t>() +
                 sorted_test_elements.size(0),
             elements.data_ptr<scalar_t>(),
...
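A minimal CPU sketch of the IsIn pipeline above: sort `test_elements` once, then answer membership for every element with a binary search (function name below is hypothetical):

```cpp
// CPU reference sketch of IsIn: Sort + vectorized binary_search.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<bool> IsInReference(
    const std::vector<int64_t>& elements, std::vector<int64_t> test_elements) {
  std::sort(test_elements.begin(), test_elements.end());
  std::vector<bool> result(elements.size());
  for (size_t i = 0; i < elements.size(); ++i) {
    result[i] = std::binary_search(
        test_elements.begin(), test_elements.end(), elements[i]);
  }
  return result;
}
```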
@@ -5,12 +5,10 @@
  * @brief Index select operator implementation on CUDA.
  */
 #include <c10/core/ScalarType.h>
-#include <c10/cuda/CUDAStream.h>
 #include <curand_kernel.h>
 #include <graphbolt/cuda_ops.h>
 #include <graphbolt/cuda_sampling_ops.h>
 #include <thrust/gather.h>
-#include <thrust/iterator/constant_iterator.h>
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/iterator/transform_iterator.h>
 #include <thrust/iterator/transform_output_iterator.h>
@@ -18,7 +16,6 @@
 #include <algorithm>
 #include <array>
 #include <cub/cub.cuh>
-#include <cuda/std/tuple>
 #include <limits>
 #include <numeric>
 #include <type_traits>
@@ -142,7 +139,6 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   // are all resident on the GPU. If not, it is better to first extract them
   // before calling this function.
   auto allocator = cuda::GetAllocator();
-  const auto stream = cuda::GetCurrentStream();
   auto num_rows = nodes.size(0);
   auto fanouts_pinned = torch::empty(
       fanouts.size(),
@@ -156,7 +152,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
   CUDA_CALL(cudaMemcpyAsync(
       fanouts_device.get(), fanouts_pinned_ptr,
-      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice, stream));
+      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
+      cuda::GetCurrentStream()));
   auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
   auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
   auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
@@ -185,14 +182,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
   AT_DISPATCH_INDEX_TYPES(
       indptr.scalar_type(), "SampleNeighborsInDegree", ([&] {
-        size_t tmp_storage_size = 0;
-        cub::DeviceReduce::Max(
-            nullptr, tmp_storage_size, in_degree.data_ptr<index_t>(),
-            max_in_degree.data_ptr<index_t>(), num_rows, stream);
-        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-        cub::DeviceReduce::Max(
-            tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<index_t>(),
-            max_in_degree.data_ptr<index_t>(), num_rows, stream);
+        CUB_CALL(
+            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
+            max_in_degree.data_ptr<index_t>(), num_rows);
       }));
   auto coo_rows = CSRToCOO(sub_indptr, indices.scalar_type());
   const auto num_edges = coo_rows.size(0);
@@ -214,19 +206,11 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
               auto is_nonzero = thrust::make_transform_iterator(
                   sliced_probs_or_mask.value().data_ptr<probs_t>(),
                   IsPositive{});
-              size_t tmp_storage_size = 0;
-              cub::DeviceSegmentedReduce::Sum(
-                  nullptr, tmp_storage_size, is_nonzero,
-                  in_degree.data_ptr<indptr_t>(), num_rows,
-                  sub_indptr.data_ptr<indptr_t>(),
-                  sub_indptr.data_ptr<indptr_t>() + 1, stream);
-              auto tmp_storage =
-                  allocator.AllocateStorage<char>(tmp_storage_size);
-              cub::DeviceSegmentedReduce::Sum(
-                  tmp_storage.get(), tmp_storage_size, is_nonzero,
-                  in_degree.data_ptr<indptr_t>(), num_rows,
-                  sub_indptr.data_ptr<indptr_t>(),
-                  sub_indptr.data_ptr<indptr_t>() + 1, stream);
+              CUB_CALL(
+                  DeviceSegmentedReduce::Sum, is_nonzero,
+                  in_degree.data_ptr<indptr_t>(), num_rows,
+                  sub_indptr.data_ptr<indptr_t>(),
+                  sub_indptr.data_ptr<indptr_t>() + 1);
             }));
       }
       thrust::counting_iterator<int64_t> iota(0);
@@ -235,16 +219,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
           in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
           fanouts.size()});
-      {  // Compute output_indptr.
-        size_t tmp_storage_size = 0;
-        cub::DeviceScan::ExclusiveSum(
-            nullptr, tmp_storage_size, sampled_degree,
-            output_indptr.data_ptr<indptr_t>(), num_rows + 1, stream);
-        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-        cub::DeviceScan::ExclusiveSum(
-            tmp_storage.get(), tmp_storage_size, sampled_degree,
-            output_indptr.data_ptr<indptr_t>(), num_rows + 1, stream);
-      }
+      // Compute output_indptr.
+      CUB_CALL(
+          DeviceScan::ExclusiveSum, sampled_degree,
+          output_indptr.data_ptr<indptr_t>(), num_rows + 1);
 
       auto num_sampled_edges =
           cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
@@ -300,8 +278,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
               (num_edges + BLOCK_SIZE - 1) / BLOCK_SIZE);
           // Compute row and random number pairs.
           CUDA_KERNEL_CALL(
-              _ComputeRandoms, grid, block, 0, stream,
-              num_edges, sliced_indptr.data_ptr<indptr_t>(),
+              _ComputeRandoms, grid, block, 0, num_edges,
+              sliced_indptr.data_ptr<indptr_t>(),
               sub_indptr.data_ptr<indptr_t>(),
               coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
               indices_ptr, random_seed, randoms.get(),
@@ -312,21 +290,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
           // Sort the random numbers along with edge ids, after
           // sorting the first fanout elements of each row will
           // give us the sampled edges.
-          size_t tmp_storage_size = 0;
-          CUDA_CALL(cub::DeviceSegmentedSort::SortPairs(
-              nullptr, tmp_storage_size, randoms.get(),
-              randoms_sorted.get(), edge_id_segments.get(),
-              sorted_edge_id_segments.get(), num_edges, num_rows,
-              sub_indptr.data_ptr<indptr_t>(),
-              sub_indptr.data_ptr<indptr_t>() + 1, stream));
-          auto tmp_storage =
-              allocator.AllocateStorage<char>(tmp_storage_size);
-          CUDA_CALL(cub::DeviceSegmentedSort::SortPairs(
-              tmp_storage.get(), tmp_storage_size, randoms.get(),
-              randoms_sorted.get(), edge_id_segments.get(),
-              sorted_edge_id_segments.get(), num_edges, num_rows,
-              sub_indptr.data_ptr<indptr_t>(),
-              sub_indptr.data_ptr<indptr_t>() + 1, stream));
+          CUB_CALL(
+              DeviceSegmentedSort::SortPairs, randoms.get(),
+              randoms_sorted.get(), edge_id_segments.get(),
+              sorted_edge_id_segments.get(), num_edges, num_rows,
+              sub_indptr.data_ptr<indptr_t>(),
+              sub_indptr.data_ptr<indptr_t>() + 1);
 
           picked_eids = torch::empty(
               static_cast<indptr_t>(num_sampled_edges),
@@ -341,19 +310,11 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             auto sampled_segment_end_it = thrust::make_transform_iterator(
                 iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                           sub_indptr.data_ptr<indptr_t>(), sampled_degree});
-            size_t tmp_storage_size = 0;
-            CUDA_CALL(cub::DeviceSegmentedSort::SortKeys(
-                nullptr, tmp_storage_size, edge_id_segments.get(),
-                sorted_edge_id_segments.get(), picked_eids.size(0),
-                num_rows, sub_indptr.data_ptr<indptr_t>(),
-                sampled_segment_end_it, stream));
-            auto tmp_storage =
-                allocator.AllocateStorage<char>(tmp_storage_size);
-            CUDA_CALL(cub::DeviceSegmentedSort::SortKeys(
-                tmp_storage.get(), tmp_storage_size, edge_id_segments.get(),
-                sorted_edge_id_segments.get(), picked_eids.size(0),
-                num_rows, sub_indptr.data_ptr<indptr_t>(),
-                sampled_segment_end_it, stream));
+            CUB_CALL(
+                DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
+                sorted_edge_id_segments.get(), picked_eids.size(0),
+                num_rows, sub_indptr.data_ptr<indptr_t>(),
+                sampled_segment_end_it);
           }
 
           auto input_buffer_it = thrust::make_transform_iterator(
@@ -370,17 +331,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             // Copy the sampled edge ids into picked_eids tensor.
             for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
-              size_t tmp_storage_size = 0;
-              CUDA_CALL(cub::DeviceCopy::Batched(
-                  nullptr, tmp_storage_size, input_buffer_it + i,
-                  output_buffer_it + i, sampled_degree + i,
-                  std::min(num_rows - i, max_copy_at_once), stream));
-              auto tmp_storage =
-                  allocator.AllocateStorage<char>(tmp_storage_size);
-              CUDA_CALL(cub::DeviceCopy::Batched(
-                  tmp_storage.get(), tmp_storage_size, input_buffer_it + i,
-                  output_buffer_it + i, sampled_degree + i,
-                  std::min(num_rows - i, max_copy_at_once), stream));
+              CUB_CALL(
+                  DeviceCopy::Batched, input_buffer_it + i,
+                  output_buffer_it + i, sampled_degree + i,
+                  std::min(num_rows - i, max_copy_at_once));
             }
           }));
@@ -392,10 +346,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       AT_DISPATCH_INDEX_TYPES(
           indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
             using indices_t = index_t;
-            const auto exec_policy =
-                thrust::cuda::par_nosync(allocator).on(stream);
-            thrust::gather(
-                exec_policy, picked_eids.data_ptr<indptr_t>(),
+            THRUST_CALL(
+                gather, picked_eids.data_ptr<indptr_t>(),
                 picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                 indices.data_ptr<indices_t>(),
                 output_indices.data_ptr<indices_t>());
@@ -412,10 +364,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
           picked_eids.options().dtype(types.scalar_type()));
       AT_DISPATCH_INTEGRAL_TYPES(
           types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
-            const auto exec_policy =
-                thrust::cuda::par_nosync(allocator).on(stream);
-            thrust::gather(
-                exec_policy, picked_eids.data_ptr<indptr_t>(),
+            THRUST_CALL(
+                gather, picked_eids.data_ptr<indptr_t>(),
                 picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                 types.data_ptr<scalar_t>(),
                 output_type_per_edge.value().data_ptr<scalar_t>());
...
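The comments in this file describe the sampling strategy: assign a random key to every candidate edge, sort keys and edge ids within each row segment, and keep the first fanout entries of every segment. A CPU sketch of that idea follows; it ignores probabilities/masks and other details handled by the real kernels, and the function name is hypothetical:

```cpp
// CPU sketch of per-row sampling via "random key + segmented sort + take first
// min(fanout, degree) entries", mirroring the comments above.
#include <algorithm>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>

std::vector<int64_t> SampleSegmentsReference(
    const std::vector<int64_t>& sub_indptr, int64_t fanout, uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<float> uniform(0.f, 1.f);
  std::vector<int64_t> picked_eids;
  for (size_t row = 0; row + 1 < sub_indptr.size(); ++row) {
    std::vector<std::pair<float, int64_t>> keyed_edges;
    for (int64_t eid = sub_indptr[row]; eid < sub_indptr[row + 1]; ++eid) {
      keyed_edges.emplace_back(uniform(rng), eid);  // random key per edge
    }
    std::sort(keyed_edges.begin(), keyed_edges.end());  // segmented sort
    const int64_t take =
        std::min<int64_t>(fanout, static_cast<int64_t>(keyed_edges.size()));
    for (int64_t k = 0; k < take; ++k)
      picked_eids.push_back(keyed_edges[k].second);
  }
  return picked_eids;
}
```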
@@ -4,7 +4,7 @@
  * @file cuda/sampling_utils.cu
  * @brief Sampling utility function implementations on CUDA.
  */
-#include <thrust/execution_policy.h>
+#include <thrust/for_each.h>
 #include <thrust/iterator/counting_iterator.h>
 
 #include <cub/cub.cuh>
@@ -36,9 +36,6 @@ struct SliceFunc {
 // Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
 std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
     torch::Tensor indptr, torch::Tensor nodes) {
-  auto allocator = cuda::GetAllocator();
-  const auto exec_policy =
-      thrust::cuda::par_nosync(allocator).on(cuda::GetCurrentStream());
   const int64_t num_nodes = nodes.size(0);
   // Read indptr only once in case it is pinned and access is slow.
   auto sliced_indptr =
@@ -53,8 +50,8 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
   AT_DISPATCH_INDEX_TYPES(
       nodes.scalar_type(), "IndexSelectCSCNodes", ([&] {
         using nodes_t = index_t;
-        thrust::for_each(
-            exec_policy, iota, iota + num_nodes,
+        THRUST_CALL(
+            for_each, iota, iota + num_nodes,
             SliceFunc<indptr_t, nodes_t>{
                 nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
                 in_degree.data_ptr<indptr_t>(),
@@ -92,9 +89,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
   auto new_sub_indptr = torch::empty(num_rows + 1, sub_indptr.options());
   auto new_indegree = torch::empty(num_rows + 2, sub_indptr.options());
   auto new_sliced_indptr = torch::empty(num_rows, sliced_indptr.options());
-  auto allocator = cuda::GetAllocator();
-  auto stream = cuda::GetCurrentStream();
-  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
   thrust::counting_iterator<int64_t> iota(0);
   AT_DISPATCH_INTEGRAL_TYPES(
       sub_indptr.scalar_type(), "SliceCSCIndptrHeteroIndptr", ([&] {
@@ -102,8 +96,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
         AT_DISPATCH_INTEGRAL_TYPES(
             etypes.scalar_type(), "SliceCSCIndptrHeteroTypePerEdge", ([&] {
               using etype_t = scalar_t;
-              thrust::for_each(
-                  exec_policy, iota, iota + num_rows,
+              THRUST_CALL(
+                  for_each, iota, iota + num_rows,
                   EdgeTypeSearch<indptr_t, etype_t>{
                       sub_indptr.data_ptr<indptr_t>(),
                       sliced_indptr.data_ptr<indptr_t>(),
@@ -111,17 +105,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
                       new_sub_indptr.data_ptr<indptr_t>(),
                       new_sliced_indptr.data_ptr<indptr_t>()});
             }));
-        size_t tmp_storage_size = 0;
-        cub::DeviceAdjacentDifference::SubtractLeftCopy(
-            nullptr, tmp_storage_size, new_sub_indptr.data_ptr<indptr_t>(),
-            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{},
-            stream);
-        auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-        cub::DeviceAdjacentDifference::SubtractLeftCopy(
-            tmp_storage.get(), tmp_storage_size,
-            new_sub_indptr.data_ptr<indptr_t>(),
-            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{},
-            stream);
+        CUB_CALL(
+            DeviceAdjacentDifference::SubtractLeftCopy,
+            new_sub_indptr.data_ptr<indptr_t>(),
+            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{});
       }));
   // Discard the first element of the SubtractLeftCopy result and ensure that
   // new_indegree tensor has size num_rows + 1 so that its ExclusiveCumSum is
...
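For reference, `cub::DeviceAdjacentDifference::SubtractLeftCopy` writes `out[0] = in[0]` and `out[i] = in[i] - in[i - 1]`, so applying it to an indptr-like array recovers per-row counts once the first entry is discarded, as the comment above notes. A CPU equivalent (hypothetical function name, illustration only):

```cpp
// CPU sketch of SubtractLeftCopy semantics used to derive in-degrees.
#include <cstdint>
#include <vector>

std::vector<int64_t> SubtractLeftCopyReference(const std::vector<int64_t>& in) {
  std::vector<int64_t> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = (i == 0) ? in[0] : in[i] - in[i - 1];
  }
  return out;  // e.g. {0, 2, 2, 5} -> {0, 2, 0, 3}
}
```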
@@ -5,7 +5,6 @@
  * @brief Sort implementation on CUDA.
  */
 #include <c10/core/ScalarType.h>
-#include <c10/cuda/CUDAStream.h>
 
 #include <cub/cub.cuh>
@@ -21,8 +20,6 @@ std::conditional_t<
     torch::Tensor>
 Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
   const auto options = torch::TensorOptions().device(c10::DeviceType::CUDA);
-  auto allocator = cuda::GetAllocator();
-  auto stream = cuda::GetCurrentStream();
   constexpr c10::ScalarType dtype = c10::CppTypeToScalarType<scalar_t>::value;
   auto sorted_array = torch::empty(num_items, options.dtype(dtype));
   auto sorted_keys = sorted_array.data_ptr<scalar_t>();
@@ -36,24 +33,14 @@ Sort(const scalar_t* input_keys, int64_t num_items, int num_bits) {
     auto sorted_idx = torch::empty_like(original_idx);
     const int64_t* input_values = original_idx.data_ptr<int64_t>();
     int64_t* sorted_values = sorted_idx.data_ptr<int64_t>();
-    size_t tmp_storage_size = 0;
-    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
-        nullptr, tmp_storage_size, input_keys, sorted_keys, input_values,
-        sorted_values, num_items, 0, num_bits, stream));
-    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-    CUDA_CALL(cub::DeviceRadixSort::SortPairs(
-        tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys,
-        input_values, sorted_values, num_items, 0, num_bits, stream));
+    CUB_CALL(
+        DeviceRadixSort::SortPairs, input_keys, sorted_keys, input_values,
+        sorted_values, num_items, 0, num_bits);
     return std::make_pair(sorted_array, sorted_idx);
   } else {
-    size_t tmp_storage_size = 0;
-    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-        nullptr, tmp_storage_size, input_keys, sorted_keys, num_items, 0,
-        num_bits, stream));
-    auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
-    CUDA_CALL(cub::DeviceRadixSort::SortKeys(
-        tmp_storage.get(), tmp_storage_size, input_keys, sorted_keys, num_items,
-        0, num_bits, stream));
+    CUB_CALL(
+        DeviceRadixSort::SortKeys, input_keys, sorted_keys, num_items, 0,
+        num_bits);
     return sorted_array;
   }
 }
...
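The `num_bits` argument lets the radix sort process only the low bits that the keys actually occupy; for example the caller earlier in this commit passes `cuda::NumberOfBits(indptr.size(0) - 1)` when sorting node ids. A host-side sketch of such a helper, assuming it returns the bit width needed for values below its argument (the name below is hypothetical):

```cpp
// CPU sketch: smallest b such that every key in [0, range) fits into b bits.
#include <cstdint>

int NumberOfBitsReference(int64_t range) {
  int bits = 0;
  while ((int64_t{1} << bits) < range) ++bits;
  return bits;  // e.g. range 1000 -> 10 bits instead of the full 64
}
```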
@@ -4,15 +4,11 @@
  * @file cuda/unique_and_compact_impl.cu
  * @brief Unique and compact operator implementation on CUDA.
  */
-#include <c10/cuda/CUDAStream.h>
 #include <graphbolt/cuda_ops.h>
 #include <thrust/binary_search.h>
 #include <thrust/functional.h>
 #include <thrust/gather.h>
-#include <thrust/iterator/discard_iterator.h>
 #include <thrust/logical.h>
-#include <thrust/reduce.h>
-#include <thrust/remove.h>
 
 #include <cub/cub.cuh>
 #include <type_traits>
@@ -33,23 +29,17 @@ struct EqualityFunc {
   }
 };
 
-#define DefineReductionFunction(reduce_fn, name)                               \
-  template <typename scalar_iterator_t>                                        \
-  auto name(const scalar_iterator_t input, int64_t size) {                     \
-    auto allocator = cuda::GetAllocator();                                      \
-    auto stream = cuda::GetCurrentStream();                                     \
-    using scalar_t = std::remove_reference_t<decltype(input[0])>;               \
-    cuda::CopyScalar<scalar_t> result;                                          \
-    size_t workspace_size = 0;                                                  \
-    reduce_fn(nullptr, workspace_size, input, result.get(), size, stream);      \
-    auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);         \
-    reduce_fn(                                                                  \
-        tmp_storage.get(), workspace_size, input, result.get(), size, stream);  \
-    return result;                                                              \
+#define DefineCubReductionFunction(cub_reduce_fn, name)                        \
+  template <typename scalar_iterator_t>                                        \
+  auto name(const scalar_iterator_t input, int64_t size) {                     \
+    using scalar_t = std::remove_reference_t<decltype(input[0])>;               \
+    cuda::CopyScalar<scalar_t> result;                                          \
+    CUB_CALL(cub_reduce_fn, input, result.get(), size);                         \
+    return result;                                                              \
   }
 
-DefineReductionFunction(cub::DeviceReduce::Max, Max);
-DefineReductionFunction(cub::DeviceReduce::Min, Min);
+DefineCubReductionFunction(DeviceReduce::Max, Max);
+DefineCubReductionFunction(DeviceReduce::Min, Min);
 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
     const torch::Tensor src_ids, const torch::Tensor dst_ids,
@@ -60,7 +50,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
       "Dtypes of tensors passed to UniqueAndCompact need to be identical.");
   auto allocator = cuda::GetAllocator();
   auto stream = cuda::GetCurrentStream();
-  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
   return AT_DISPATCH_INTEGRAL_TYPES(
       src_ids.scalar_type(), "unique_and_compact", ([&] {
         auto src_ids_ptr = src_ids.data_ptr<scalar_t>();
@@ -84,8 +73,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         // Mark dst nodes in the src_ids tensor.
         auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
-        thrust::binary_search(
-            exec_policy, sorted_unique_dst_ids_ptr,
+        THRUST_CALL(
+            binary_search, sorted_unique_dst_ids_ptr,
             sorted_unique_dst_ids_ptr + unique_dst_ids.size(0), src_ids_ptr,
             src_ids_ptr + src_ids.size(0), is_dst.get());
@@ -96,16 +85,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
           auto is_src = thrust::make_transform_iterator(
              is_dst.get(), thrust::logical_not<bool>{});
           cuda::CopyScalar<int64_t> only_src_size;
-          size_t workspace_size = 0;
-          cub::DeviceSelect::Flagged(
-              nullptr, workspace_size, src_ids_ptr, is_src,
-              only_src.data_ptr<scalar_t>(), only_src_size.get(),
-              src_ids.size(0), stream);
-          auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
-          cub::DeviceSelect::Flagged(
-              tmp_storage.get(), workspace_size, src_ids_ptr, is_src,
-              only_src.data_ptr<scalar_t>(), only_src_size.get(),
-              src_ids.size(0), stream);
+          CUB_CALL(
+              DeviceSelect::Flagged, src_ids_ptr, is_src,
+              only_src.data_ptr<scalar_t>(), only_src_size.get(),
+              src_ids.size(0));
           stream.synchronize();
           only_src = only_src.slice(0, 0, static_cast<int64_t>(only_src_size));
         }
@@ -129,16 +112,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         {  // Compute the unique operation on the only_src tensor.
           cuda::CopyScalar<int64_t> unique_only_src_size;
-          size_t workspace_size = 0;
-          CUDA_CALL(cub::DeviceSelect::Unique(
-              nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
-              unique_only_src_ptr, unique_only_src_size.get(), only_src.size(0),
-              stream));
-          auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
-          CUDA_CALL(cub::DeviceSelect::Unique(
-              tmp_storage.get(), workspace_size,
-              sorted_only_src.data_ptr<scalar_t>(), unique_only_src_ptr,
-              unique_only_src_size.get(), only_src.size(0), stream));
+          CUB_CALL(
+              DeviceSelect::Unique, sorted_only_src.data_ptr<scalar_t>(),
+              unique_only_src_ptr, unique_only_src_size.get(),
+              only_src.size(0));
           stream.synchronize();
           unique_only_src = unique_only_src.slice(
               0, 0, static_cast<int64_t>(unique_only_src_size));
@@ -146,7 +123,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         auto real_order = torch::cat({unique_dst_ids, unique_only_src});
         // Sort here so that binary search can be used to lookup new_ids.
-        auto [sorted_order, new_ids] = Sort(real_order, num_bits);
+        torch::Tensor sorted_order, new_ids;
+        std::tie(sorted_order, new_ids) = Sort(real_order, num_bits);
         auto sorted_order_ptr = sorted_order.data_ptr<scalar_t>();
         auto new_ids_ptr = new_ids.data_ptr<int64_t>();
         // Holds the found locations of the src and dst ids in the sorted_order.
@@ -154,8 +132,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         // tensors.
         auto new_dst_ids_loc =
             allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
-        thrust::lower_bound(
-            exec_policy, sorted_order_ptr,
+        THRUST_CALL(
+            lower_bound, sorted_order_ptr,
             sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
             dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());
@@ -172,16 +150,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
         auto new_src_ids_loc =
             allocator.AllocateStorage<scalar_t>(src_ids.size(0));
-        thrust::lower_bound(
-            exec_policy, sorted_order_ptr,
+        THRUST_CALL(
+            lower_bound, sorted_order_ptr,
             sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
             src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());
 
         // Finally, lookup the new compact ids of the src and dst tensors via
         // gather operations.
         auto new_src_ids = torch::empty_like(src_ids);
-        thrust::gather(
-            exec_policy, new_src_ids_loc.get(),
+        THRUST_CALL(
+            gather, new_src_ids_loc.get(),
             new_src_ids_loc.get() + src_ids.size(0),
             new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
         // Perform check before we gather for the dst indices.
@@ -189,8 +167,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
           throw std::out_of_range("Some ids not found.");
         }
         auto new_dst_ids = torch::empty_like(dst_ids);
-        thrust::gather(
-            exec_policy, new_dst_ids_loc.get(),
+        THRUST_CALL(
+            gather, new_dst_ids_loc.get(),
             new_dst_ids_loc.get() + dst_ids.size(0),
             new_ids.data_ptr<int64_t>(), new_dst_ids.data_ptr<scalar_t>());
         return std::make_tuple(real_order, new_src_ids, new_dst_ids);
...
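A CPU sketch of the UniqueAndCompact contract implemented above with sort, binary/lower-bound searches, and DeviceSelect: build the compact id space as `unique_dst_ids` followed by the src ids that are not dst ids, then relabel both inputs by their position in that combined list. Ordering of the appended src-only ids differs from the GPU version (which sorts them), and the function name is hypothetical:

```cpp
// CPU reference sketch of UniqueAndCompact's input/output relationship.
#include <cstdint>
#include <tuple>
#include <unordered_map>
#include <vector>

std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
UniqueAndCompactReference(
    const std::vector<int64_t>& src_ids, const std::vector<int64_t>& dst_ids,
    const std::vector<int64_t>& unique_dst_ids) {
  std::vector<int64_t> real_order = unique_dst_ids;
  std::unordered_map<int64_t, int64_t> new_id;
  for (auto id : unique_dst_ids) new_id.emplace(id, new_id.size());
  for (auto id : src_ids) {  // append each src-only id exactly once
    if (new_id.emplace(id, new_id.size()).second) real_order.push_back(id);
  }
  std::vector<int64_t> new_src, new_dst;
  for (auto id : src_ids) new_src.push_back(new_id.at(id));
  for (auto id : dst_ids) new_dst.push_back(new_id.at(id));
  return {real_order, new_src, new_dst};
}
```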