/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/index_select_csc_impl.cu
 * @brief Index select csc operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>

#include <cub/cub.cuh>
#include <algorithm>
#include <limits>
#include <numeric>

#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

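// Thread block size used by the aligned index copy kernel below.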
constexpr int BLOCK_SIZE = 128;

// Given the in_degree array and a permutation, returns the in_degree of the
// output and the permuted, padded in_degree of the input. The padding adds
// enough slack per row so that the copies can be aligned as needed.
template <typename indptr_t, typename indices_t>
struct AlignmentFunc {
  static_assert(GPU_CACHE_LINE_SIZE % sizeof(indices_t) == 0);
  const indptr_t* in_degree;
  const int64_t* perm;
  int64_t num_nodes;
  __host__ __device__ auto operator()(int64_t row) {
    constexpr int num_elements = GPU_CACHE_LINE_SIZE / sizeof(indices_t);
    return thrust::make_tuple(
        in_degree[row],
        // A single cache line has num_elements items, we add num_elements - 1
        // to ensure there is enough slack to move forward or backward by
        // num_elements - 1 items if the performed access is not aligned.
        (indptr_t)(in_degree[perm ? perm[row % num_nodes] : row] + num_elements - 1));
  }
};

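// Copies the indices of the selected rows into output_indices. Each thread
// walks the aligned output index space with a grid stride; a binary search
// (UpperBound) over output_indptr_aligned locates the row a position belongs
// to, and the per-row slack is used so that reads from the indices array start
// at cache-line-aligned addresses, while writes go to the compact positions
// given by output_indptr.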
template <typename indptr_t, typename indices_t>
__global__ void _CopyIndicesAlignedKernel(
    const indptr_t edge_count, const int64_t num_nodes,
    const indptr_t* const indptr, const indptr_t* const output_indptr,
    const indptr_t* const output_indptr_aligned, const indices_t* const indices,
    indices_t* const output_indices, const int64_t* const perm) {
  indptr_t idx = static_cast<indptr_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;

  while (idx < edge_count) {
    const auto permuted_row_pos =
        cuda::UpperBound(output_indptr_aligned, num_nodes, idx) - 1;
    const auto row_pos = perm ? perm[permuted_row_pos] : permuted_row_pos;
    const auto out_row = output_indptr[row_pos];
    const auto d = output_indptr[row_pos + 1] - out_row;
    const int offset =
        ((size_t)(indices + indptr[row_pos] - output_indptr_aligned[permuted_row_pos]) %
         GPU_CACHE_LINE_SIZE) /
        sizeof(indices_t);
    const auto rofs = idx - output_indptr_aligned[permuted_row_pos] - offset;
    if (rofs >= 0 && rofs < d) {
      const auto in_idx = indptr[row_pos] + rofs;
      assert((size_t)(indices + in_idx - idx) % GPU_CACHE_LINE_SIZE == 0);
      const auto u = indices[in_idx];
      output_indices[out_row + rofs] = u;
    }
    idx += stride_x;
  }
}

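// Sums two (indptr_t, indptr_t) tuples element-wise; used as the scan operator
// when computing the actual and aligned prefix sums in a single pass.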
struct PairSum {
  template <typename indptr_t>
  __host__ __device__ auto operator()(
      const thrust::tuple<indptr_t, indptr_t> a,
      const thrust::tuple<indptr_t, indptr_t> b) {
    return thrust::make_tuple(
        thrust::get<0>(a) + thrust::get<0>(b),
        thrust::get<1>(a) + thrust::get<1>(b));
  };
};

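// Builds the exact and alignment-padded output indptr with a single zipped
// exclusive scan, then copies the selected indices into output_indices with
// _CopyIndicesAlignedKernel.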
template <typename indptr_t, typename indices_t>
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
    torch::Tensor indices, const int64_t num_nodes,
    const indptr_t* const in_degree, const indptr_t* const sliced_indptr,
    const int64_t* const perm, torch::TensorOptions options,
    torch::ScalarType indptr_scalar_type,
    torch::optional<int64_t> output_size) {
  auto allocator = cuda::GetAllocator();
  thrust::counting_iterator<int64_t> iota(0);

  // Output indptr for the slice indexed by nodes.
  auto output_indptr =
      torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));

  auto output_indptr_aligned =
      allocator.AllocateStorage<indptr_t>(num_nodes + 1);

  {
    // Returns the actual and the modified in_degree as a pair; the latter
    // overestimates the actual in_degree to leave slack for alignment.
    auto modified_in_degree = thrust::make_transform_iterator(
        iota, AlignmentFunc<indptr_t, indices_t>{in_degree, perm, num_nodes});
    auto output_indptr_pair = thrust::make_zip_iterator(
        output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get());
    thrust::tuple<indptr_t, indptr_t> zero_value{};
    // Compute the prefix sum over actual and modified indegrees.
    CUB_CALL(
        DeviceScan::ExclusiveScan, modified_in_degree, output_indptr_pair,
        PairSum{}, zero_value, num_nodes + 1);
  }

  // Copy the actual total number of edges to the host if output_size was not
  // provided.
  if (!output_size.has_value()) {
    auto edge_count =
        cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
    output_size = static_cast<indptr_t>(edge_count);
  }
  // Copy the modified number of edges.
  auto edge_count_aligned =
      cuda::CopyScalar{output_indptr_aligned.get() + num_nodes};

  // Allocate output array with actual number of edges.
  torch::Tensor output_indices =
      torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
  const dim3 block(BLOCK_SIZE);
  const dim3 grid(
      (static_cast<indptr_t>(edge_count_aligned) + BLOCK_SIZE - 1) /
      BLOCK_SIZE);

  // Perform the actual copy of the indices array into output_indices in an
  // aligned manner.
  CUDA_KERNEL_CALL(
      _CopyIndicesAlignedKernel, grid, block, 0,
      static_cast<indptr_t>(edge_count_aligned), num_nodes, sliced_indptr,
      output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get(),
      reinterpret_cast<indices_t*>(indices.data_ptr()),
      reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
  return {output_indptr, output_indices};
}

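// UVA path: the indices tensor resides in pinned host memory, so the nodes are
// sorted first to make the accesses over PCI-e more regular before the aligned
// copy is performed.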
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
    torch::Tensor nodes, int num_bits, torch::optional<int64_t> output_size) {
  // Sorting nodes so that accesses over PCI-e are more regular.
  const auto sorted_idx = Sort(nodes, num_bits).second;
  const int64_t num_nodes = nodes.size(0);

  return AT_DISPATCH_INTEGRAL_TYPES(
      sliced_indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] {
        using indptr_t = scalar_t;
        return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
            indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
              return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
                  indices, num_nodes, in_degree.data_ptr<indptr_t>(),
                  sliced_indptr.data_ptr<indptr_t>(),
                  sorted_idx.data_ptr<int64_t>(), nodes.options(),
                  sliced_indptr.scalar_type(), output_size);
            }));
      }));
}

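// Given a row index, returns a pointer to the beginning of that row's segment
// in the indices (or output_indices) array; used to build the source and
// destination buffer lists for the batched memcpy.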
template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

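// Converts a row's in_degree from an element count to a byte count, as
// required by cub::DeviceMemcpy::Batched.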
template <typename indptr_t, typename indices_t>
struct ConvertToBytes {
  const indptr_t* in_degree;
  __host__ __device__ indptr_t operator()(int64_t i) {
    return in_degree[i] * sizeof(indices_t);
  }
};

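// Copies each selected row's indices with cub::DeviceMemcpy::Batched, issuing
// the copies in chunks of at most INT32_MAX rows per call.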
template <typename indptr_t, typename indices_t>
void IndexSelectCSCCopyIndices(
    const int64_t num_nodes, indices_t* const indices,
    indptr_t* const sliced_indptr, const indptr_t* const in_degree,
    indptr_t* const output_indptr, indices_t* const output_indices) {
  thrust::counting_iterator<int64_t> iota(0);

  auto input_buffer_it = thrust::make_transform_iterator(
      iota, IteratorFunc<indptr_t, indices_t>{sliced_indptr, indices});
  auto output_buffer_it = thrust::make_transform_iterator(
      iota, IteratorFunc<indptr_t, indices_t>{output_indptr, output_indices});
  auto buffer_sizes = thrust::make_transform_iterator(
      iota, ConvertToBytes<indptr_t, indices_t>{in_degree});
  constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();

  // Performs the copy from indices into output_indices.
  for (int64_t i = 0; i < num_nodes; i += max_copy_at_once) {
    CUB_CALL(
        DeviceMemcpy::Batched, input_buffer_it + i, output_buffer_it + i,
        buffer_sizes + i, std::min(num_nodes - i, max_copy_at_once));
  }
}

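// Device path: indices already reside in GPU memory, so the output indptr is a
// plain exclusive sum of in_degree and the rows are gathered with a batched
// device-to-device memcpy.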
std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
    torch::TensorOptions options, torch::optional<int64_t> output_size) {
  const int64_t num_nodes = sliced_indptr.size(0);
  return AT_DISPATCH_INTEGRAL_TYPES(
      sliced_indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
        using indptr_t = scalar_t;
        auto in_degree_ptr = in_degree.data_ptr<indptr_t>();
        auto sliced_indptr_ptr = sliced_indptr.data_ptr<indptr_t>();
        // Output indptr for the slice indexed by nodes.
        torch::Tensor output_indptr = torch::empty(
            num_nodes + 1, options.dtype(sliced_indptr.scalar_type()));

        // Compute the output indptr, output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, in_degree_ptr,
            output_indptr.data_ptr<indptr_t>(), num_nodes + 1);

        // Number of edges being copied.
        if (!output_size.has_value()) {
          auto edge_count =
              cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
          output_size = static_cast<indptr_t>(edge_count);
        }
        // Allocate output array of size number of copied edges.
        torch::Tensor output_indices = torch::empty(
            output_size.value(), options.dtype(indices.scalar_type()));
        GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
            indices.element_size(), "IndexSelectCSCCopyIndices", ([&] {
              using indices_t = element_size_t;
              IndexSelectCSCCopyIndices<indptr_t, indices_t>(
                  num_nodes, reinterpret_cast<indices_t*>(indices.data_ptr()),
                  sliced_indptr_ptr, in_degree_ptr,
                  output_indptr.data_ptr<indptr_t>(),
                  reinterpret_cast<indices_t*>(output_indices.data_ptr()));
            }));
        return std::make_tuple(output_indptr, output_indices);
      }));
}

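// Dispatches to the UVA path when indices live in pinned host memory and to
// the device path otherwise.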
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
    torch::Tensor nodes, int64_t nodes_max,
    torch::optional<int64_t> output_size) {
  if (indices.is_pinned()) {
    int num_bits = cuda::NumberOfBits(nodes_max + 1);
    return UVAIndexSelectCSCImpl(
        in_degree, sliced_indptr, indices, nodes, num_bits, output_size);
  } else {
    return DeviceIndexSelectCSCImpl(
        in_degree, sliced_indptr, indices, nodes.options(), output_size);
  }
}

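// Convenience overload that slices indptr itself before dispatching.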
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    torch::optional<int64_t> output_size) {
  auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
  return IndexSelectCSCImpl(
      in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2,
      output_size);
}

}  //  namespace ops
}  //  namespace graphbolt