Unverified Commit ceef30b4 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Adds an exclusive prefix sum function for Neighbor Sampling. (#6798)

parent 869bfb67
...@@ -12,6 +12,15 @@ namespace ops { ...@@ -12,6 +12,15 @@ namespace ops {
std::pair<torch::Tensor, torch::Tensor> Sort(torch::Tensor input, int num_bits); std::pair<torch::Tensor, torch::Tensor> Sort(torch::Tensor input, int num_bits);
/**
* @brief Computes the exclusive prefix sum of the given input.
*
* @param input The input tensor.
*
* @return The prefix sum result such that r[i] = \sum_{j=0}^{i-1} input[j]
*/
torch::Tensor ExclusiveCumSum(torch::Tensor input);
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl( std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes); torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes);
......
...@@ -67,6 +67,8 @@ struct CUDAWorkspaceAllocator { ...@@ -67,6 +67,8 @@ struct CUDAWorkspaceAllocator {
inline auto GetAllocator() { return CUDAWorkspaceAllocator{}; } inline auto GetAllocator() { return CUDAWorkspaceAllocator{}; }
inline auto GetCurrentStream() { return c10::cuda::getCurrentCUDAStream(); }
template <typename T> template <typename T>
inline bool is_zero(T size) { inline bool is_zero(T size) {
return size == 0; return size == 0;
......
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/cumsum.cu
* @brief Cumsum operators implementation on CUDA.
*/
#include <cub/cub.cuh>
#include "./common.h"
namespace graphbolt {
namespace ops {
torch::Tensor ExclusiveCumSum(torch::Tensor input) {
auto allocator = cuda::GetAllocator();
auto stream = cuda::GetCurrentStream();
auto result = torch::empty_like(input);
AT_DISPATCH_INTEGRAL_TYPES(
input.scalar_type(), "ExclusiveCumSum", ([&] {
size_t tmp_storage_size = 0;
cub::DeviceScan::ExclusiveSum(
nullptr, tmp_storage_size, input.data_ptr<scalar_t>(),
result.data_ptr<scalar_t>(), input.size(0), stream);
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
cub::DeviceScan::ExclusiveSum(
tmp_storage.get(), tmp_storage_size, input.data_ptr<scalar_t>(),
result.data_ptr<scalar_t>(), input.size(0), stream);
}));
return result;
}
} // namespace ops
} // namespace graphbolt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment