/*!
 *  Copyright (c) 2021-2022 by Contributors
 * \file array/cuda/array_index_select.cuh
 * \brief Array index select (gather) and index scatter GPU kernel implementations
 */

#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_

namespace dgl {
namespace aten {
namespace impl {

/*!
 * \brief Gather scalar elements: out[i] = array[index[i]] for i in [0, length).
 *
 * The kernel walks the output with a 1-D grid-stride loop, so any 1-D launch
 * configuration (even a single thread) covers all `length` elements.
 *
 * \param array   Source array.
 * \param index   `length` indices into `array`.
 * \param length  Number of elements to gather (size of `index` and `out`).
 * \param arr_len Exclusive upper bound for valid indices (presumably the
 *                length of `array`); only checked by the device-side assert.
 * \param out     Destination array with `length` elements.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(const DType* array,
                                        const IdType* index,
                                        const int64_t length,
                                        const int64_t arr_len,
                                        DType* out) {
  // Use 64-bit indexing: `length` is int64_t, and both a 32-bit `int` index and
  // the 32-bit product blockIdx.x * blockDim.x can overflow on large inputs.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    // Load the index once; it is used for both the check and the gather.
    const IdType src = index[tx];
    assert(src >= 0 && src < arr_len);
    out[tx] = array[src];
    tx += stride_x;
  }
}

/*!
 * \brief Gather whole rows: out[r, :] = array[index[r], :] for r in [0, length).
 *
 * Both `array` and `out` are indexed as row-major with `num_feat` columns per
 * row. Expected launch layout: threadIdx.y / blockIdx.x select output rows
 * (row-strided by blockDim.y * gridDim.x), and threadIdx.x strides across the
 * columns of a row by blockDim.x.
 *
 * \param array    Source rows, `num_feat` elements per row.
 * \param num_feat Number of columns per row.
 * \param index    `length` row indices into `array`.
 * \param length   Number of rows to gather (size of `index`, rows of `out`).
 * \param arr_len  Exclusive upper bound for valid row indices (presumably the
 *                 number of rows in `array`); only checked by the device assert.
 * \param out      Destination with `length * num_feat` elements.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
        const DType* const array,
        const int64_t num_feat,
        const IdType* const index,
        const int64_t length,
        const int64_t arr_len,
        DType* const out) {
  // Widen before multiplying: blockIdx.x * blockDim.y (and the stride) would
  // otherwise be evaluated in 32-bit unsigned arithmetic and could wrap.
  int64_t out_row = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;

  while (out_row < length) {
    const int64_t in_row = index[out_row];
    assert(in_row >= 0 && in_row < arr_len);
    const DType* const src = array + in_row * num_feat;
    DType* const dst = out + out_row * num_feat;
    // Adjacent threadIdx.x lanes touch adjacent columns of the same row.
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      dst[col] = src[col];
    }
    out_row += stride;
  }
}

/*!
 * \brief Scatter scalar elements: out[index[i]] = array[i] for i in [0, length).
 *
 * The kernel walks the input with a 1-D grid-stride loop, so any 1-D launch
 * configuration covers all `length` elements. Writes are plain stores (no
 * atomics): if `index` contains duplicates, which value lands is unspecified.
 *
 * \param array   Source array with `length` elements.
 * \param index   `length` destination indices into `out`.
 * \param length  Number of elements to scatter.
 * \param arr_len Exclusive upper bound for valid destination indices
 *                (presumably the length of `out`); only checked by the assert.
 * \param out     Destination array.
 */
template <typename DType, typename IdType>
__global__ void IndexScatterSingleKernel(const DType* array,
                                         const IdType* index,
                                         const int64_t length,
                                         const int64_t arr_len,
                                         DType* out) {
  // Use 64-bit indexing: `length` is int64_t, and both a 32-bit `int` index and
  // the 32-bit product blockIdx.x * blockDim.x can overflow on large inputs.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    // Load the index once; it is used for both the check and the store.
    const IdType dst = index[tx];
    assert(dst >= 0 && dst < arr_len);
    out[dst] = array[tx];
    tx += stride_x;
  }
}

/*!
 * \brief Scatter whole rows: out[index[r], :] = array[r, :] for r in [0, length).
 *
 * Both `array` and `out` are indexed as row-major with `num_feat` columns per
 * row. Expected launch layout: threadIdx.y / blockIdx.x select input rows
 * (row-strided by blockDim.y * gridDim.x), and threadIdx.x strides across the
 * columns of a row by blockDim.x. Writes are plain stores (no atomics): with
 * duplicate destination rows, which row's values land is unspecified.
 *
 * \param array    Source rows, `length * num_feat` elements.
 * \param num_feat Number of columns per row.
 * \param index    `length` destination row indices into `out`.
 * \param length   Number of rows to scatter.
 * \param arr_len  Exclusive upper bound for valid destination rows (presumably
 *                 the number of rows in `out`); only checked by the assert.
 * \param out      Destination rows, `num_feat` elements per row.
 */
template <typename DType, typename IdType>
__global__ void IndexScatterMultiKernel(
        const DType* const array,
        const int64_t num_feat,
        const IdType* const index,
        const int64_t length,
        const int64_t arr_len,
        DType* const out) {
  // Widen before multiplying: blockIdx.x * blockDim.y (and the stride) would
  // otherwise be evaluated in 32-bit unsigned arithmetic and could wrap.
  int64_t in_row = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;

  while (in_row < length) {
    const int64_t out_row = index[in_row];
    assert(out_row >= 0 && out_row < arr_len);
    const DType* const src = array + in_row * num_feat;
    DType* const dst = out + out_row * num_feat;
    // Adjacent threadIdx.x lanes touch adjacent columns of the same row.
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      dst[col] = src[col];
    }
    in_row += stride;
  }
}

}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif