"tests/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "7e311e4bcd9ba7d446d9c88f52143b0cd05ff228"
array_index_select.cuh 2.49 KB
Newer Older
/*!
 *  Copyright (c) 2021-2022 by Contributors
 * \file array/cuda/array_index_select.cuh
 * \brief Array index select (gather) and index scatter GPU kernel
 *        implementations
 */

#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_

namespace dgl {
namespace aten {
namespace impl {

/*!
 * \brief Gather kernel for 1-D data: out[i] = array[index[i]] for every
 *        i in [0, length).
 *
 * Uses a grid-stride loop over a 1-D launch: each thread starts at its
 * global thread id and advances by gridDim.x * blockDim.x. Any grid/block
 * size therefore covers all `length` outputs.
 *
 * \param array   Source values; indexed by entries of `index`.
 * \param index   Source positions, `length` entries. Each must lie in
 *                [0, arr_len); this is checked by device `assert` unless
 *                NDEBUG is defined.
 * \param length  Number of entries in `index` (and elements written to `out`).
 * \param arr_len Number of valid elements in `array` (bound for the assert).
 * \param out     Destination; receives `length` elements.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(
    const DType* array, const IdType* index, const int64_t length,
    const int64_t arr_len, DType* out) {
  // Use 64-bit positions. `length` is int64_t, and a 32-bit int counter
  // (or tx + stride) would overflow once it passes INT_MAX. Widen before
  // multiplying, because blockIdx.x * blockDim.x is 32-bit unsigned math.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    const IdType src = index[tx];
    assert(src >= 0 && src < arr_len);
    out[tx] = array[src];
    tx += stride_x;
  }
}

/*!
 * \brief Row gather kernel for row-major 2-D data with `num_feat` columns:
 *        output row r = array row index[r], for every r in [0, length).
 *
 * Launch layout: a 2-D thread block over a 1-D grid. The output row is
 * blockIdx.x * blockDim.y + threadIdx.y, and rows advance by
 * blockDim.y * gridDim.x. Within a row, threadIdx.x walks the columns with
 * step blockDim.x.
 *
 * \param array    Source matrix, row-major, `num_feat` elements per row.
 * \param num_feat Number of columns per row.
 * \param index    Source row ids, `length` entries. Each must lie in
 *                 [0, arr_len); this is checked by device `assert` unless
 *                 NDEBUG is defined.
 * \param length   Number of rows to gather (rows written to `out`).
 * \param arr_len  Number of valid rows in `array` (bound for the assert).
 * \param out      Destination matrix, row-major, `length * num_feat` elements.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
    const DType* const array, const int64_t num_feat, const IdType* const index,
    const int64_t length, const int64_t arr_len, DType* const out) {
  // Widen before multiplying. Otherwise blockIdx.x * blockDim.y and
  // blockDim.y * gridDim.x are computed in 32-bit unsigned arithmetic and
  // can wrap for very large grids.
  int64_t out_row =
      static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;

  while (out_row < length) {
    const int64_t in_row = index[out_row];
    assert(in_row >= 0 && in_row < arr_len);
    // The row base pointers do not change while walking the columns.
    const DType* const src_row = array + in_row * num_feat;
    DType* const dst_row = out + out_row * num_feat;
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      dst_row[col] = src_row[col];
    }
    out_row += stride;
  }
}

/*!
 * \brief Scatter kernel for 1-D data: out[index[i]] = array[i] for every
 *        i in [0, length).
 *
 * Uses a grid-stride loop over a 1-D launch: each thread starts at its
 * global thread id and advances by gridDim.x * blockDim.x. Positions of
 * `out` not named by `index` are left untouched.
 *
 * \param array   Source values, `length` elements.
 * \param index   Destination positions, `length` entries. Each must lie in
 *                [0, arr_len); this is checked by device `assert` unless
 *                NDEBUG is defined.
 * \param length  Number of entries in `array` / `index`.
 * \param arr_len Number of valid elements in `out` (bound for the assert).
 * \param out     Destination buffer.
 *
 * \note Writes are plain (non-atomic) stores. If `index` contains
 *       duplicates, which source element ends up at that destination is
 *       unspecified.
 */
template <typename DType, typename IdType>
__global__ void IndexScatterSingleKernel(
    const DType* array, const IdType* index, const int64_t length,
    const int64_t arr_len, DType* out) {
  // Use 64-bit positions. `length` is int64_t, and a 32-bit int counter
  // (or tx + stride) would overflow once it passes INT_MAX.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    const IdType dst = index[tx];
    assert(dst >= 0 && dst < arr_len);
    out[dst] = array[tx];
    tx += stride_x;
  }
}

/*!
 * \brief Row scatter kernel for row-major 2-D data with `num_feat` columns:
 *        out row index[r] = array row r, for every r in [0, length).
 *
 * Launch layout: a 2-D thread block over a 1-D grid. The source row is
 * blockIdx.x * blockDim.y + threadIdx.y, and rows advance by
 * blockDim.y * gridDim.x. Within a row, threadIdx.x walks the columns with
 * step blockDim.x. Rows of `out` not named by `index` are left untouched.
 *
 * \param array    Source matrix, row-major, `length * num_feat` elements.
 * \param num_feat Number of columns per row.
 * \param index    Destination row ids, `length` entries. Each must lie in
 *                 [0, arr_len); this is checked by device `assert` unless
 *                 NDEBUG is defined.
 * \param length   Number of source rows to scatter.
 * \param arr_len  Number of valid rows in `out` (bound for the assert).
 * \param out      Destination matrix, row-major, `num_feat` elements per row.
 *
 * \note Writes are plain (non-atomic) stores. If `index` contains
 *       duplicates, the contents of that destination row are unspecified.
 */
template <typename DType, typename IdType>
__global__ void IndexScatterMultiKernel(
    const DType* const array, const int64_t num_feat, const IdType* const index,
    const int64_t length, const int64_t arr_len, DType* const out) {
  // Widen before multiplying. Otherwise blockIdx.x * blockDim.y and
  // blockDim.y * gridDim.x are computed in 32-bit unsigned arithmetic and
  // can wrap for very large grids.
  int64_t in_row =
      static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;

  while (in_row < length) {
    const int64_t out_row = index[in_row];
    assert(out_row >= 0 && out_row < arr_len);
    // The row base pointers do not change while walking the columns.
    const DType* const src_row = array + in_row * num_feat;
    DType* const dst_row = out + out_row * num_feat;
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      dst_row[col] = src_row[col];
    }
    in_row += stride;
  }
}

80
81
82
83
}  // namespace impl
}  // namespace aten
}  // namespace dgl

84
#endif  // DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_