/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/cumsum.cu
 * @brief Cumsum operators implementation on CUDA.
 */
#include <cub/cub.cuh>

#include "./common.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Computes the exclusive prefix sum (exclusive scan) of an integral
 * tensor using cub::DeviceScan::ExclusiveSum.
 *
 * For an input [a0, a1, a2, ...] the result is [0, a0, a0 + a1, ...].
 *
 * @param input Tensor with an integral dtype. Exactly input.size(0) elements
 * are scanned. NOTE(review): presumably a 1-D tensor that lives on the current
 * CUDA device; for inputs with more than one dimension only the first
 * input.size(0) elements of the result are written -- confirm against callers.
 *
 * @return A newly allocated tensor shaped like the input that holds the
 * exclusive prefix sums.
 */
torch::Tensor ExclusiveCumSum(torch::Tensor input) {
  // CUB operates on raw, densely packed buffers and ignores tensor strides,
  // so the data must be contiguous before its pointer is handed over. This
  // returns the same tensor (no copy) when the input is already contiguous.
  input = input.contiguous();
  auto result = torch::empty_like(input);

  // Instantiate the scan for the runtime integral dtype of the input.
  // CUB_CALL comes from ./common.h (not visible here); presumably it takes
  // care of the temporary storage query/allocation and error checking.
  AT_DISPATCH_INTEGRAL_TYPES(
      input.scalar_type(), "ExclusiveCumSum", ([&] {
        CUB_CALL(
            DeviceScan::ExclusiveSum, input.data_ptr<scalar_t>(),
            result.data_ptr<scalar_t>(), input.size(0));
      }));
  return result;
}

}  // namespace ops
}  // namespace graphbolt