/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/index_select_impl.cu
 * @brief Index select operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAStream.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cub/cub.cuh>

#include <numeric>

#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

/** @brief Index select operator implementation for feature size 1. */
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(
    const DType* input, const int64_t input_len, const IdType* index,
    const int64_t output_len, DType* output,
    const int64_t* permutation = nullptr) {
  int64_t out_row_index = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
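  // Grid-stride loop: each thread copies one scalar per output row it is
  // assigned. When a permutation is given, the value gathered for index[i] is
  // written to output position permutation[i].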
  while (out_row_index < output_len) {
    assert(index[out_row_index] >= 0 && index[out_row_index] < input_len);
    const auto out_row =
        permutation ? permutation[out_row_index] : out_row_index;
    output[out_row] = input[index[out_row_index]];
    out_row_index += stride;
  }
}

/**
 * @brief Index select operator implementation for feature size > 1.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
    const DType* const input, const int64_t input_len,
    const int64_t feature_size, const IdType* const index,
    const int64_t output_len, DType* const output,
    const int64_t* permutation = nullptr) {
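  // Each (blockIdx.x, threadIdx.y) pair walks over output rows in a
  // grid-stride fashion, while threadIdx.x strides across the feature columns
  // of the current row.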
  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;

  const int64_t stride = blockDim.y * gridDim.x;

  while (out_row_index < output_len) {
    int64_t column = threadIdx.x;
    const int64_t in_row = index[out_row_index];
    assert(in_row >= 0 && in_row < input_len);
    const auto out_row =
        permutation ? permutation[out_row_index] : out_row_index;
    while (column < feature_size) {
      output[out_row * feature_size + column] =
          input[in_row * feature_size + column];
      column += blockDim.x;
    }
    out_row_index += stride;
  }
}

/**
 * @brief Index select operator implementation for feature size > 1.
 *
 * @note This is a cross-device access version of IndexSelectMultiKernel. Since
 * the memory access over PCIe is more sensitive to the data access alignment
 * (cacheline), we need a separate version here.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernelAligned(
    const DType* const input, const int64_t input_len,
    const int64_t feature_size, const IdType* const index,
    const int64_t output_len, DType* const output,
    const int64_t* permutation = nullptr) {
  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;

  const int64_t stride = blockDim.y * gridDim.x;

  while (out_row_index < output_len) {
    int64_t col = threadIdx.x;
    const int64_t in_row = index[out_row_index];
    assert(in_row >= 0 && in_row < input_len);
    const int64_t idx_offset =
        ((uint64_t)(&input[in_row * feature_size]) % GPU_CACHE_LINE_SIZE) /
        sizeof(DType);
    col = col - idx_offset;
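    // Shifting col down by idx_offset lines the thread group up with the
    // cache-line-aligned address at or below the row start; threads whose
    // shifted col is negative simply skip that column. Illustrative numbers
    // (assuming a 128-byte GPU_CACHE_LINE_SIZE and a 4-byte DType): a row
    // starting 32 bytes past a line boundary gives idx_offset = 8, so threads
    // 0-7 sit out the first pass and thread 8 copies column 0.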
    const auto out_row =
        permutation ? permutation[out_row_index] : out_row_index;
    while (col < feature_size) {
      if (col >= 0)
        output[out_row * feature_size + col] =
            input[in_row * feature_size + col];
      col += blockDim.x;
    }
    out_row_index += stride;
  }
}

template <typename DType, typename IdType>
torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
  const int64_t input_len = input.size(0);
  const int64_t return_len = index.size(0);
  const int64_t original_feature_size = std::accumulate(
      input.sizes().begin() + 1, input.sizes().end(), 1ll, std::multiplies<>());
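  // DType may be wider than the tensor's element type (see the element-size
  // dispatch in UVAIndexSelectImpl below), so the feature length is
  // re-expressed in units of DType for the copy kernels.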
  const auto aligned_feature_size =
      input.element_size() * original_feature_size / sizeof(DType);
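  // The gathered rows are allocated on the GPU; the input itself is read in
  // place, typically from pinned host memory accessed through UVA.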
  torch::Tensor ret = torch::empty(
      {return_len, original_feature_size}, torch::TensorOptions()
                                               .dtype(input.dtype())
                                               .device(c10::DeviceType::CUDA));
  DType* input_ptr = reinterpret_cast<DType*>(input.data_ptr());
  DType* ret_ptr = reinterpret_cast<DType*>(ret.data_ptr());

  // Sort the index to improve the memory access pattern.
  torch::Tensor sorted_index, permutation;
  std::tie(sorted_index, permutation) =
      Sort(index, cuda::NumberOfBits(input_len));
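  // permutation[i] records where sorted_index[i] came from in the original
  // index tensor, so the kernels can scatter each gathered row back to its
  // original output slot. Passing cuda::NumberOfBits(input_len) limits the
  // radix sort to the bits a valid index can actually set.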
  const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
  const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();

  cudaStream_t stream = c10::cuda::getDefaultCUDAStream();

  if (aligned_feature_size == 1) {
    // Use a single thread to process each output row to avoid wasting threads.
    const int num_threads = cuda::FindNumThreads(return_len);
    const int num_blocks = (return_len + num_threads - 1) / num_threads;
    CUDA_KERNEL_CALL(
        IndexSelectSingleKernel, num_blocks, num_threads, 0, stream, input_ptr,
        input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
  } else {
    dim3 block(512, 1);
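    // Start with 512 threads along the feature dimension and trade x-threads
    // for y-rows until block.x is below twice the feature size, so each block
    // covers several output rows without leaving most threads idle.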
    while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
      block.x >>= 1;
      block.y <<= 1;
    }
    const dim3 grid((return_len + block.y - 1) / block.y);
    if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
      // When the feature size is smaller than the GPU cache line size, use the
      // unaligned version for lower SM usage, which is more resource efficient.
      CUDA_KERNEL_CALL(
          IndexSelectMultiKernel, grid, block, 0, stream, input_ptr, input_len,
          aligned_feature_size, index_sorted_ptr, return_len, ret_ptr,
          permutation_ptr);
    } else {
      // Use the aligned version to improve the memory access pattern.
      CUDA_KERNEL_CALL(
          IndexSelectMultiKernelAligned, grid, block, 0, stream, input_ptr,
          input_len, aligned_feature_size, index_sorted_ptr, return_len,
          ret_ptr, permutation_ptr);
    }
  }

  auto return_shape = std::vector<int64_t>({return_len});
  return_shape.insert(
      return_shape.end(), input.sizes().begin() + 1, input.sizes().end());
  ret = ret.reshape(return_shape);
  return ret;
}

/**
 * @brief UVA index select operator implementation on CUDA.
 *
 * All basic torch types are supported for input.
 * The supported index types are: int, int64_t.
 */
torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index) {
  return AT_DISPATCH_INDEX_TYPES(
      index.scalar_type(), "UVAIndexSelectImpl", ([&] {
        const auto ptr = (size_t)input.data_ptr();
        const int64_t feature_size = std::accumulate(
            input.sizes().begin() + 1, input.sizes().end(), 1ll,
            std::multiplies<>());
        // We perform the copy with datatype of size powers of 2, and the
        // maximum data type we use has 16 bytes. We check the alignment of the
        // pointer and the feature dimensionality to determine the largest
        // type to use for the copy to minimize the number of CUDA threads used.
        // Alignment denotes the maximum suitable alignment and datatype size
        // for the copies.
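        // For example (illustrative numbers): a 16-byte-aligned pointer
        // holding float32 features with feature_size = 3 gives
        // element_size * feature_size = 12 bytes, so the gcd works out to 4
        // and the copy is dispatched with a 4-byte element type.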
        const int aligned_access_size =
            std::gcd(16, std::gcd(ptr, input.element_size() * feature_size));
        return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
            aligned_access_size, "UVAIndexSelectImplElementSize", ([&] {
              return UVAIndexSelectImpl_<element_size_t, index_t>(input, index);
            }));
      }));
}
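
// Usage sketch (illustrative only; `features` and `ids` are hypothetical
// tensors, with the feature tensor pinned so the GPU can read it over UVA):
//
//   torch::Tensor features = torch::rand({1000000, 128}).pin_memory();
//   torch::Tensor ids = torch::randint(
//       1000000, {4096},
//       torch::dtype(torch::kInt64).device(torch::kCUDA));
//   torch::Tensor gathered = UVAIndexSelectImpl(features, ids);
//   // gathered has shape {4096, 128} and lives on the GPU.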

}  //  namespace ops
}  //  namespace graphbolt