"app/store/testdata/schema.sql" did not exist on "a4770107a6ea6b4f5adc235d37d08417dc3b9184"
expand_indptr.cu 3.51 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/expand_indptr.cu
 * @brief ExpandIndptr operator implementation on CUDA.
 */
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <algorithm>
#include <cub/cub.cuh>
#include <limits>

#include "./common.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Maps a row index to a constant iterator that repeats that row's
 * output value.
 *
 * The produced iterator yields `nodes[i]` when `nodes` is non-null, and the
 * row index `i` itself otherwise. The value is converted to `indices_t`, the
 * element type of the output buffer it is copied into.
 *
 * @tparam indices_t Element type of the output.
 * @tparam nodes_t Element type of the optional `nodes` array.
 */
template <typename indices_t, typename nodes_t>
struct RepeatIndex {
  // Optional per-row values; when null, the row index itself is repeated.
  const nodes_t* nodes;
  // The row index is accepted as int64_t (the value type of the driving
  // counting_iterator) so that it is not truncated to a possibly narrower
  // indices_t before it is used to index `nodes`. Only the produced value is
  // narrowed to indices_t.
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::make_constant_iterator(
        static_cast<indices_t>(nodes ? nodes[i] : i));
  }
};

/**
 * @brief Maps a row index to the first output slot reserved for that row.
 *
 * Row `row` owns the slots of `buffer` that start at offset `indptr[row]`.
 */
template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
  const indptr_t* indptr;  // Per-row starting offsets into `buffer`.
  indices_t* buffer;       // Base pointer of the output array.
  __host__ __device__ auto operator()(int64_t row) {
    const auto offset = indptr[row];
    return &buffer[offset];
  }
};

/**
 * @brief Maps a row index to the number of entries in that row, i.e.
 * `indptr[row + 1] - indptr[row]`.
 */
template <typename indptr_t>
struct AdjacentDifference {
  const indptr_t* indptr;  // Offsets; read at `row` and `row + 1`.
  __host__ __device__ auto operator()(int64_t row) {
    const indptr_t begin = indptr[row];
    const indptr_t end = indptr[row + 1];
    return end - begin;
  }
};

/**
 * @brief Expands an offsets (`indptr`) tensor into one entry per element.
 *
 * For every row `i` in `[0, indptr.size(0) - 1)`, writes
 * `indptr[i + 1] - indptr[i]` copies of a value into the output, starting at
 * offset `indptr[i]`. The value is `nodes[i]` when `nodes` is given and `i`
 * otherwise, converted to `dtype`.
 *
 * @param indptr Non-empty 1-D offsets tensor with `num_rows + 1` entries.
 * @param dtype Integral element type of the returned tensor.
 * @param nodes Optional per-row values; must have at least `num_rows`
 *        entries.
 * @param output_size Optional output length. When omitted it is read from
 *        the last entry of `indptr` via `cuda::CopyScalar` (presumably a
 *        device-to-host copy). When given, it must be at least that last
 *        entry (not checked here); any slots past it are left uninitialized,
 *        since the output comes from `torch::empty`.
 * @return Tensor of length `output_size` and type `dtype`, allocated with
 *         `indptr`'s tensor options (hence on `indptr`'s device).
 */
torch::Tensor ExpandIndptrImpl(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> nodes,
    torch::optional<int64_t> output_size) {
  // The last indptr entry is read below and `indptr.size(0) - 1` is the row
  // count, so an empty indptr would cause an out-of-bounds device read.
  TORCH_CHECK(
      indptr.dim() == 1 && indptr.size(0) >= 1,
      "ExpandIndptr: indptr must be a non-empty 1-D tensor.");
  const auto num_rows = indptr.size(0) - 1;
  // nodes[i] is read for every row i; a shorter tensor would be read out of
  // bounds on the device.
  TORCH_CHECK(
      !nodes.has_value() || nodes->size(0) >= num_rows,
      "ExpandIndptr: nodes must have at least indptr.size(0) - 1 entries.");

  if (!output_size.has_value()) {
    // The total output length is the last offset, indptr[num_rows].
    output_size = AT_DISPATCH_INTEGRAL_TYPES(
        indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
          auto indptr_ptr = indptr.data_ptr<scalar_t>();
          auto last_offset = cuda::CopyScalar{indptr_ptr + num_rows};
          return static_cast<scalar_t>(last_offset);
        }));
  }
  auto csc_rows =
      torch::empty(output_size.value(), indptr.options().dtype(dtype));

  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "ExpandIndptrIndptr", ([&] {
        using indptr_t = scalar_t;
        auto indptr_ptr = indptr.data_ptr<indptr_t>();
        AT_DISPATCH_INTEGRAL_TYPES(
            dtype, "ExpandIndptrIndices", ([&] {
              using indices_t = scalar_t;
              auto csc_rows_ptr = csc_rows.data_ptr<indices_t>();

              // Without nodes, nodes_ptr is null and nodes_t is unused, so
              // the output dtype serves as a placeholder for the dispatch.
              auto nodes_dtype = nodes ? nodes.value().scalar_type() : dtype;
              AT_DISPATCH_INTEGRAL_TYPES(
                  nodes_dtype, "ExpandIndptrNodes", ([&] {
                    using nodes_t = scalar_t;
                    auto nodes_ptr =
                        nodes ? nodes.value().data_ptr<nodes_t>() : nullptr;

                    // One copy buffer per row i: the source repeats the row's
                    // value, the destination starts at csc_rows + indptr[i],
                    // and the length is indptr[i + 1] - indptr[i].
                    thrust::counting_iterator<int64_t> iota(0);
                    auto input_buffer = thrust::make_transform_iterator(
                        iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});
                    auto output_buffer = thrust::make_transform_iterator(
                        iota, OutputBufferIndexer<indptr_t, indices_t>{
                                  indptr_ptr, csc_rows_ptr});
                    auto buffer_sizes = thrust::make_transform_iterator(
                        iota, AdjacentDifference<indptr_t>{indptr_ptr});

                    // Each DeviceCopy::Batched call handles at most INT32_MAX
                    // buffers, so rows are processed in chunks of that size.
                    // NOTE(review): presumably because CUB takes the buffer
                    // count as a 32-bit integer — confirm against CUB docs.
                    constexpr int64_t max_copy_at_once =
                        std::numeric_limits<int32_t>::max();
                    for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                      CUB_CALL(
                          DeviceCopy::Batched, input_buffer + i,
                          output_buffer + i, buffer_sizes + i,
                          std::min(num_rows - i, max_copy_at_once));
                    }
                  }));
            }));
      }));
  return csc_rows;
}

}  // namespace ops
}  // namespace graphbolt