// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/expand_indptr.cu
 * @brief ExpandIndptr operator implementation on CUDA.
 */
#include <hip/hip_runtime.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <hipcub/hipcub.hpp>
#include <limits>
#include <hipcub/backend/rocprim/device/device_copy.hpp>
#include "common.h"

namespace graphbolt {
namespace ops {

// Functor mapping a row id `i` to a constant iterator that endlessly repeats
// that row's output value: nodes[i] when a node-id mapping is present,
// otherwise `i` itself. Used as the per-row "input buffer" of the batched
// device copy below. The ternary keeps both branches in one expression so
// that the deduced iterator value type is the common type of nodes_t and
// indices_t.
template <typename indices_t, typename nodes_t>
struct RepeatIndex {
  // Optional device pointer to node ids; nullptr means identity mapping.
  const nodes_t* nodes;
  __host__ __device__ auto operator()(indices_t i) {
    return thrust::make_constant_iterator(nodes ? nodes[i] : i);
  }
};

// Functor mapping a row id `i` to the address inside `buffer` where that
// row's expanded entries begin, as dictated by the indptr offsets. Serves as
// the per-row "output buffer" of the batched device copy below.
template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
  // Device pointer to the CSC offsets array.
  const indptr_t* indptr;
  // Device pointer to the base of the output tensor.
  indices_t* buffer;
  __host__ __device__ auto operator()(int64_t i) {
    const auto row_offset = indptr[i];
    return buffer + row_offset;
  }
};

// Functor computing how many entries row `i` contributes, i.e. the difference
// of consecutive indptr offsets. Serves as the per-row copy size of the
// batched device copy below. Assumes `i + 1` is a valid index into indptr.
template <typename indptr_t>
struct AdjacentDifference {
  // Device pointer to the CSC offsets array.
  const indptr_t* indptr;
  __host__ __device__ auto operator()(int64_t i) {
    const auto row_begin = indptr[i];
    const auto row_end = indptr[i + 1];
    return row_end - row_begin;
  }
};

/**
 * @brief Expands a CSC indptr (offsets) tensor into a tensor of per-edge row
 * ids: output position j holds the id of the row that the j-th nonzero
 * belongs to — or, when `nodes` is provided, the node id that row maps to.
 *
 * @param indptr      Integral offsets tensor with size(0) == num_rows + 1.
 * @param dtype       Integral dtype of the returned tensor.
 * @param nodes       Optional node-id tensor; when present, row i expands to
 *                    nodes[i] instead of i.
 * @param output_size Optional total number of output elements; when absent it
 *                    is read back from indptr[-1] on the device.
 * @return A tensor of size output_size.value() with dtype `dtype` holding the
 *         expanded ids.
 */
torch::Tensor ExpandIndptrImpl(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> nodes,
    torch::optional<int64_t> output_size) {
  if (!output_size.has_value()) {
    // The caller did not supply the output size; fetch indptr[-1] from the
    // device. The static_cast through cuda::CopyScalar triggers the
    // device-to-host read of the scalar.
    output_size = AT_DISPATCH_INTEGRAL_TYPES(
        indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
          auto indptr_ptr = indptr.data_ptr<scalar_t>();
          auto output_size = cuda::CopyScalar{indptr_ptr + indptr.size(0) - 1};
          return static_cast<scalar_t>(output_size);
        }));
  }
  auto csc_rows =
      torch::empty(output_size.value(), indptr.options().dtype(dtype));

  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "ExpandIndptrIndptr", ([&] {
        using indptr_t = scalar_t;
        auto indptr_ptr = indptr.data_ptr<indptr_t>();
        AT_DISPATCH_INTEGRAL_TYPES(
            dtype, "ExpandIndptrIndices", ([&] {
              using indices_t = scalar_t;
              auto csc_rows_ptr = csc_rows.data_ptr<indices_t>();

              // When nodes is absent, dispatch on the output dtype so the
              // identity branch of RepeatIndex still instantiates cleanly.
              auto nodes_dtype = nodes ? nodes.value().scalar_type() : dtype;
              AT_DISPATCH_INTEGRAL_TYPES(
                  nodes_dtype, "ExpandIndptrNodes", ([&] {
                    using nodes_t = scalar_t;
                    auto nodes_ptr =
                        nodes ? nodes.value().data_ptr<nodes_t>() : nullptr;

                    // For each row i, copy (indptr[i + 1] - indptr[i])
                    // repetitions of its id into csc_rows starting at offset
                    // indptr[i], via CUB's batched device copy. The three
                    // transform iterators below supply, per row, the repeated
                    // value, the destination address, and the copy size.
                    thrust::counting_iterator<int64_t> iota(0);
                    auto input_buffer = thrust::make_transform_iterator(
                        iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});
                    auto output_buffer = thrust::make_transform_iterator(
                        iota, OutputBufferIndexer<indptr_t, indices_t>{
                                  indptr_ptr, csc_rows_ptr});
                    auto buffer_sizes = thrust::make_transform_iterator(
                        iota, AdjacentDifference<indptr_t>{indptr_ptr});

                    const auto num_rows = indptr.size(0) - 1;
                    // DeviceCopy::Batched takes a 32-bit buffer count, so
                    // issue the copies in chunks of at most INT32_MAX rows.
                    constexpr int64_t max_copy_at_once =
                        std::numeric_limits<int32_t>::max();
                    for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                      CUB_CALL(
                          DeviceCopy::Batched, input_buffer + i,
                          output_buffer + i, buffer_sizes + i,
                          ::min(num_rows - i, max_copy_at_once));
                    }
                  }));
            }));
      }));
  return csc_rows;
}

}  // namespace ops
}  // namespace graphbolt