/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/index_select_csc_impl.cu
 * @brief Index select csc operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>

#include <cub/cub.cuh>
#include <numeric>

#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

// Given the in_degree array and a permutation, returns the in_degree of the
// output and the permuted, padded in_degree of the input. The padding adds
// enough slack so that each row's copy can be shifted to a cache-line-aligned
// position when needed.
template <typename indptr_t, typename indices_t>
struct AlignmentFunc {
  static_assert(GPU_CACHE_LINE_SIZE % sizeof(indices_t) == 0);
  const indptr_t* in_degree;
  const int64_t* perm;
  int64_t num_nodes;
  __host__ __device__ auto operator()(int64_t row) {
    constexpr int num_elements = GPU_CACHE_LINE_SIZE / sizeof(indices_t);
    return thrust::make_tuple(
        in_degree[row],
        // A single cache line has num_elements items, we add num_elements - 1
        // to ensure there is enough slack to move forward or backward by
        // num_elements - 1 items if the performed access is not aligned.
        (indptr_t)(in_degree[perm ? perm[row % num_nodes] : row] + num_elements - 1));
  }
};

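// Copies the selected indices into output_indices while keeping the reads
// from the indices array cache-line aligned. Each thread owns one slot of the
// padded (aligned) output layout: it looks up its row via coo_aligned_rows,
// computes how far that row's source range is from a cache-line boundary, and
// copies one element only if its slot maps to a valid position of the row.
// The slack added by AlignmentFunc guarantees every row has room for this
// shift.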
template <typename indptr_t, typename indices_t, typename coo_rows_t>
__global__ void _CopyIndicesAlignedKernel(
    const indptr_t edge_count, const indptr_t* const indptr,
    const indptr_t* const output_indptr,
    const indptr_t* const output_indptr_aligned, const indices_t* const indices,
    const coo_rows_t* const coo_aligned_rows, indices_t* const output_indices,
    const int64_t* const perm) {
  indptr_t idx = static_cast<indptr_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;

  while (idx < edge_count) {
    const auto permuted_row_pos = coo_aligned_rows[idx];
    const auto row_pos = perm ? perm[permuted_row_pos] : permuted_row_pos;
    const auto out_row = output_indptr[row_pos];
    const auto d = output_indptr[row_pos + 1] - out_row;
    const int offset =
        ((size_t)(indices + indptr[row_pos] - output_indptr_aligned[permuted_row_pos]) %
         GPU_CACHE_LINE_SIZE) /
        sizeof(indices_t);
    const auto rofs = idx - output_indptr_aligned[permuted_row_pos] - offset;
    if (rofs >= 0 && rofs < d) {
      const auto in_idx = indptr[row_pos] + rofs;
      assert((size_t)(indices + in_idx - idx) % GPU_CACHE_LINE_SIZE == 0);
      const auto u = indices[in_idx];
      output_indices[out_row + rofs] = u;
    }
    idx += stride_x;
  }
}

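// Sums two (indptr_t, indptr_t) tuples elementwise; used as the scan operator
// over the (actual, padded) in-degree pairs.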
struct PairSum {
  template <typename indptr_t>
  __host__ __device__ auto operator()(
      const thrust::tuple<indptr_t, indptr_t> a,
      const thrust::tuple<indptr_t, indptr_t> b) {
    return thrust::make_tuple(
        thrust::get<0>(a) + thrust::get<0>(b),
        thrust::get<1>(a) + thrust::get<1>(b));
  }
};

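// Builds the output indptr of the slice and copies the corresponding indices
// when the indices tensor lives in pinned (UVA) host memory. A second, padded
// indptr is computed alongside the real one so that the copy kernel can align
// its reads to GPU cache lines.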
template <typename indptr_t, typename indices_t>
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
    torch::Tensor indices, const int64_t num_nodes,
    const indptr_t* const in_degree, const indptr_t* const sliced_indptr,
    const int64_t* const perm, torch::TensorOptions options,
    torch::ScalarType indptr_scalar_type,
    torch::optional<int64_t> output_size) {
  auto allocator = cuda::GetAllocator();
  thrust::counting_iterator<int64_t> iota(0);

  // Output indptr for the slice indexed by nodes.
  auto output_indptr =
      torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));

  auto output_indptr_aligned =
      torch::empty(num_nodes + 1, options.dtype(indptr_scalar_type));
  auto output_indptr_aligned_ptr = output_indptr_aligned.data_ptr<indptr_t>();

  {
    // modified_in_degree yields pairs of (actual in_degree, padded
    // in_degree); the latter overestimates the actual in_degree to
    // leave slack for alignment.
    auto modified_in_degree = thrust::make_transform_iterator(
        iota, AlignmentFunc<indptr_t, indices_t>{in_degree, perm, num_nodes});
    auto output_indptr_pair = thrust::make_zip_iterator(
        output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr);
    thrust::tuple<indptr_t, indptr_t> zero_value{};
    // Compute the prefix sum over actual and modified indegrees.
    CUB_CALL(
        DeviceScan::ExclusiveScan, modified_in_degree, output_indptr_pair,
        PairSum{}, zero_value, num_nodes + 1);
  }

  // Copy the actual total number of edges.
  if (!output_size.has_value()) {
    auto edge_count =
        cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
    output_size = static_cast<indptr_t>(edge_count);
  }
  // Copy the modified number of edges.
  auto edge_count_aligned_ =
      cuda::CopyScalar{output_indptr_aligned_ptr + num_nodes};
  const int64_t edge_count_aligned = static_cast<indptr_t>(edge_count_aligned_);

  // Allocate output array with actual number of edges.
  torch::Tensor output_indices =
      torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
  const dim3 block(BLOCK_SIZE);
  const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);

  // Find the smallest integer type to store the coo_aligned_rows tensor.
  const int num_bits = cuda::NumberOfBits(num_nodes);
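  // 8, 15, 31 and 63 are the usable bits of kByte (unsigned) and the signed
  // 16, 32 and 64-bit types below; the extra kLong entry in types covers the
  // case where num_bits exceeds 63 and lower_bound returns the end iterator.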
  std::array<int, 4> type_bits = {8, 15, 31, 63};
  const auto type_index =
      std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
      type_bits.begin();
  std::array<torch::ScalarType, 5> types = {
      torch::kByte, torch::kInt16, torch::kInt32, torch::kLong, torch::kLong};
  auto coo_dtype = types[type_index];

  auto coo_aligned_rows = ExpandIndptrImpl(
      output_indptr_aligned, coo_dtype, torch::nullopt, edge_count_aligned);

  AT_DISPATCH_INTEGRAL_TYPES(
      coo_dtype, "UVAIndexSelectCSCCopyIndicesCOO", ([&] {
        using coo_rows_t = scalar_t;
        // Perform the actual copy of the indices array into output_indices
        // in a cache-line-aligned manner.
        CUDA_KERNEL_CALL(
            _CopyIndicesAlignedKernel, grid, block, 0,
            static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
            output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
            reinterpret_cast<indices_t*>(indices.data_ptr()),
            coo_aligned_rows.data_ptr<coo_rows_t>(),
            reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
      }));
  return {output_indptr, output_indices};
}

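// Variant of IndexSelectCSC for indices residing in pinned (UVA-accessible)
// host memory. The requested nodes are sorted so that accesses over PCI-e
// become more regular, and the work is dispatched over the indptr scalar type
// and the indices element size.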
std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
    torch::Tensor nodes, int num_bits, torch::optional<int64_t> output_size) {
  // Sorting nodes so that accesses over PCI-e are more regular.
  const auto sorted_idx = Sort(nodes, num_bits).second;
  const int64_t num_nodes = nodes.size(0);

  return AT_DISPATCH_INTEGRAL_TYPES(
      sliced_indptr.scalar_type(), "UVAIndexSelectCSCIndptr", ([&] {
        using indptr_t = scalar_t;
        return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
            indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
              return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
                  indices, num_nodes, in_degree.data_ptr<indptr_t>(),
                  sliced_indptr.data_ptr<indptr_t>(),
                  sorted_idx.data_ptr<int64_t>(), nodes.options(),
                  sliced_indptr.scalar_type(), output_size);
            }));
      }));
}

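// Maps row i to the address of its first index, indices + indptr[i]; used to
// build the source and destination buffer iterators for the batched memcpy.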
template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

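// Maps row i to the size of its copy in bytes, in_degree[i] * sizeof(indices_t).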
template <typename indptr_t, typename indices_t>
struct ConvertToBytes {
  const indptr_t* in_degree;
  __host__ __device__ indptr_t operator()(int64_t i) {
    return in_degree[i] * sizeof(indices_t);
  }
};

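// Copies each selected row's indices into output_indices using
// cub::DeviceMemcpy::Batched with one buffer per row, chunked so that a
// single call handles at most INT32_MAX buffers.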
template <typename indptr_t, typename indices_t>
void IndexSelectCSCCopyIndices(
    const int64_t num_nodes, indices_t* const indices,
    indptr_t* const sliced_indptr, const indptr_t* const in_degree,
    indptr_t* const output_indptr, indices_t* const output_indices) {
  thrust::counting_iterator<int64_t> iota(0);

  auto input_buffer_it = thrust::make_transform_iterator(
      iota, IteratorFunc<indptr_t, indices_t>{sliced_indptr, indices});
  auto output_buffer_it = thrust::make_transform_iterator(
      iota, IteratorFunc<indptr_t, indices_t>{output_indptr, output_indices});
  auto buffer_sizes = thrust::make_transform_iterator(
      iota, ConvertToBytes<indptr_t, indices_t>{in_degree});
  constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();

  // Performs the copy from indices into output_indices.
  for (int64_t i = 0; i < num_nodes; i += max_copy_at_once) {
    CUB_CALL(
        DeviceMemcpy::Batched, input_buffer_it + i, output_buffer_it + i,
        buffer_sizes + i, std::min(num_nodes - i, max_copy_at_once));
  }
}

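// Variant of IndexSelectCSC for indices residing in device memory: the output
// indptr is an exclusive sum of the in-degrees and the indices are copied
// with batched device-to-device memcpys (IndexSelectCSCCopyIndices above).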
std::tuple<torch::Tensor, torch::Tensor> DeviceIndexSelectCSCImpl(
    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
    torch::TensorOptions options, torch::optional<int64_t> output_size) {
  const int64_t num_nodes = sliced_indptr.size(0);
  return AT_DISPATCH_INTEGRAL_TYPES(
      sliced_indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
        using indptr_t = scalar_t;
        auto in_degree_ptr = in_degree.data_ptr<indptr_t>();
        auto sliced_indptr_ptr = sliced_indptr.data_ptr<indptr_t>();
        // Output indptr for the slice indexed by nodes.
        torch::Tensor output_indptr = torch::empty(
            num_nodes + 1, options.dtype(sliced_indptr.scalar_type()));

        // Compute the output indptr via an exclusive sum of the in-degrees.
        CUB_CALL(
            DeviceScan::ExclusiveSum, in_degree_ptr,
            output_indptr.data_ptr<indptr_t>(), num_nodes + 1);

        // Number of edges being copied.
        if (!output_size.has_value()) {
          auto edge_count =
              cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
          output_size = static_cast<indptr_t>(edge_count);
        }
        // Allocate the output array sized to the number of copied edges.
        torch::Tensor output_indices = torch::empty(
            output_size.value(), options.dtype(indices.scalar_type()));
        GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
            indices.element_size(), "IndexSelectCSCCopyIndices", ([&] {
              using indices_t = element_size_t;
              IndexSelectCSCCopyIndices<indptr_t, indices_t>(
                  num_nodes, reinterpret_cast<indices_t*>(indices.data_ptr()),
                  sliced_indptr_ptr, in_degree_ptr,
                  output_indptr.data_ptr<indptr_t>(),
                  reinterpret_cast<indices_t*>(output_indices.data_ptr()));
            }));
        return std::make_tuple(output_indptr, output_indices);
      }));
}

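// Dispatches to the UVA or device implementation depending on whether the
// indices tensor resides in pinned host memory.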
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
    torch::Tensor in_degree, torch::Tensor sliced_indptr, torch::Tensor indices,
    torch::Tensor nodes, int64_t nodes_max,
    torch::optional<int64_t> output_size) {
  if (indices.is_pinned()) {
    int num_bits = cuda::NumberOfBits(nodes_max + 1);
    return UVAIndexSelectCSCImpl(
        in_degree, sliced_indptr, indices, nodes, num_bits, output_size);
  } else {
    return DeviceIndexSelectCSCImpl(
        in_degree, sliced_indptr, indices, nodes.options(), output_size);
  }
}

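// Convenience overload that first slices indptr with respect to nodes and
// then forwards to the overload above; indptr.size(0) - 2 is the largest
// valid node id.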
std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    torch::optional<int64_t> output_size) {
  auto [in_degree, sliced_indptr] = SliceCSCIndptr(indptr, nodes);
  return IndexSelectCSCImpl(
      in_degree, sliced_indptr, indices, nodes, indptr.size(0) - 2,
      output_size);
}
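
// Example (a minimal usage sketch; the tensor setup is illustrative and not
// part of this file):
//   // indptr: CSC index pointer of shape (num_total_nodes + 1,),
//   // indices: edge array (CUDA tensor or pinned host tensor),
//   // nodes: ids of the columns to slice.
//   auto [sub_indptr, sub_indices] =
//       IndexSelectCSCImpl(indptr, indices, nodes, torch::nullopt);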

}  //  namespace ops
}  //  namespace graphbolt