/**
 *  Copyright (c) 2021-2022 by Contributors
 * @file array/cuda/array_index_select.cuh
 * @brief Array index select GPU kernel implementation
 */

#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_

#include <cassert>
#include <cstdint>

namespace dgl {
namespace aten {
namespace impl {

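/**
 * @brief Gather kernel for 1D (single-feature) arrays. Each thread handles
 *        one index at a time in a grid-stride loop, writing
 *        out[perm ? perm[i] : i] = array[index[i]].
 */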
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(
    const DType* array, const IdType* index, const int64_t length,
    const int64_t arr_len, DType* out, const int64_t* perm = nullptr) {
  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    assert(index[tx] >= 0 && index[tx] < arr_len);
    const auto out_row = perm ? perm[tx] : tx;
    out[out_row] = array[index[tx]];
    tx += stride_x;
  }
}

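/**
 * @brief Gather kernel for arrays with num_feat features per row. Each group
 *        of threads sharing a threadIdx.y picks one output row per iteration
 *        of a grid-stride loop; threads along x copy that row's feature
 *        columns from array[index[i]] into out.
 */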
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
    const DType* const array, const int64_t num_feat, const IdType* const index,
    const int64_t length, const int64_t arr_len, DType* const out,
    const int64_t* perm = nullptr) {
  int64_t out_row_index = blockIdx.x * blockDim.y + threadIdx.y;

  const int64_t stride = blockDim.y * gridDim.x;

  while (out_row_index < length) {
    int64_t col = threadIdx.x;
    const int64_t in_row = index[out_row_index];
    assert(in_row >= 0 && in_row < arr_len);
    const auto out_row = perm ? perm[out_row_index] : out_row_index;
    while (col < num_feat) {
      out[out_row * num_feat + col] = array[in_row * num_feat + col];
      col += blockDim.x;
    }
    out_row_index += stride;
  }
}

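/**
 * @brief Scatter kernel for 1D arrays. Each thread handles one position at a
 *        time in a grid-stride loop, writing out[index[i]] = array[i]; with
 *        duplicate indices, which write survives is unspecified.
 */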
template <typename DType, typename IdType>
__global__ void IndexScatterSingleKernel(
    const DType* array, const IdType* index, const int64_t length,
    const int64_t arr_len, DType* out) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    assert(index[tx] >= 0 && index[tx] < arr_len);
    out[index[tx]] = array[tx];
    tx += stride_x;
  }
}

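/**
 * @brief Scatter kernel for arrays with num_feat features per row. Row i of
 *        array is copied to row index[i] of out, with threads along x
 *        copying the feature columns of each row.
 */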
template <typename DType, typename IdType>
__global__ void IndexScatterMultiKernel(
    const DType* const array, const int64_t num_feat, const IdType* const index,
    const int64_t length, const int64_t arr_len, DType* const out) {
  int64_t in_row = blockIdx.x * blockDim.y + threadIdx.y;

  const int64_t stride = blockDim.y * gridDim.x;

  while (in_row < length) {
    int64_t col = threadIdx.x;
    const int64_t out_row = index[in_row];
    assert(out_row >= 0 && out_row < arr_len);
    while (col < num_feat) {
      out[out_row * num_feat + col] = array[in_row * num_feat + col];
      col += blockDim.x;
    }
    in_row += stride;
  }
}

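// A minimal host-side launch sketch (illustrative only, not DGL's actual
// calling code; `stream`, the device pointers, and the block size of 256 are
// assumptions). Template parameters are deduced from the pointer arguments:
//
//   const int nt = 256;
//   const int nb = (length + nt - 1) / nt;
//   IndexSelectSingleKernel<<<nb, nt, 0, stream>>>(
//       in_ptr, idx_ptr, length, arr_len, out_ptr);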
}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_