// array_index_select.cuh (last committed by sangwzh)
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021-2022 by Contributors
 * @file array/cuda/array_index_select.cuh
 * @brief GPU kernels that gather (index select) and scatter array elements
 *        or rows according to an index array.
 */

#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_

namespace dgl {
namespace aten {
namespace impl {

/**
 * @brief Gather kernel for 1-D data:
 *        out[perm ? perm[i] : i] = array[index[i]] for i in [0, length).
 *
 * Launch layout: 1-D grid of 1-D blocks. Each thread starts at its global
 * thread index and advances by gridDim.x * blockDim.x (grid-stride loop), so
 * any grid size covers all `length` elements.
 *
 * @param array   Source array; valid positions are [0, arr_len).
 * @param index   `length` positions into `array`. Each must lie in
 *                [0, arr_len); checked with a device-side assert (compiled
 *                out when NDEBUG is defined).
 * @param length  Number of elements to gather.
 * @param arr_len Number of elements in `array`.
 * @param out     Destination. Element i of the gather is stored at
 *                out[perm[i]] when `perm` is non-null, otherwise at out[i].
 * @param perm    Optional output positions (one per gathered element), or
 *                nullptr for identity placement.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(
    const DType* array, const IdType* index, const int64_t length,
    const int64_t arr_len, DType* out, const int64_t* perm = nullptr) {
  // Widen before multiplying: blockIdx/blockDim/gridDim are 32-bit unsigned,
  // so their product would otherwise be computed (and wrap) in 32 bits.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    const IdType in_pos = index[tx];
    assert(in_pos >= 0 && in_pos < arr_len);
    const int64_t out_pos = perm ? perm[tx] : tx;
    out[out_pos] = array[in_pos];
    tx += stride_x;
  }
}

/**
 * @brief Gather kernel for 2-D data: copies whole rows of `num_feat`
 *        elements.
 *
 * `array` and `out` are addressed as row-major with `num_feat` contiguous
 * elements per row. Result row i is source row index[i], written to output
 * row perm[i] (or row i when `perm` is nullptr).
 *
 * Launch layout: 1-D grid of 2-D blocks. threadIdx.y picks a row within the
 * block and rows advance by blockDim.y * gridDim.x; threadIdx.x walks the
 * columns of that row in steps of blockDim.x.
 *
 * @param array    Source rows; valid row ids are [0, arr_len).
 * @param num_feat Number of elements per row.
 * @param index    `length` source row ids, each in [0, arr_len) (device-side
 *                 assert, compiled out when NDEBUG is defined).
 * @param length   Number of rows to gather.
 * @param arr_len  Number of rows in `array`.
 * @param out      Destination rows.
 * @param perm     Optional output row ids (one per gathered row), or nullptr
 *                 for identity placement.
 */
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
    const DType* const array, const int64_t num_feat, const IdType* const index,
    const int64_t length, const int64_t arr_len, DType* const out,
    const int64_t* perm = nullptr) {
  // 64-bit row index/stride: see IndexSelectSingleKernel for why we widen
  // before multiplying the 32-bit launch coordinates.
  int64_t out_row_index =
      static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;
  while (out_row_index < length) {
    const int64_t in_row = index[out_row_index];
    assert(in_row >= 0 && in_row < arr_len);
    const int64_t out_row = perm ? perm[out_row_index] : out_row_index;
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      out[out_row * num_feat + col] = array[in_row * num_feat + col];
    }
    out_row_index += stride;
  }
}

/**
 * @brief Scatter kernel for 1-D data: out[index[i]] = array[i] for i in
 *        [0, length).
 *
 * Launch layout: 1-D grid of 1-D blocks with a grid-stride loop, as in
 * IndexSelectSingleKernel.
 *
 * No atomics are used: if `index` contains duplicates, several threads store
 * to the same element of `out` and which value remains is not determined by
 * this kernel.
 *
 * @param array   Source values, `length` of them.
 * @param index   `length` destination positions, each in [0, arr_len)
 *                (device-side assert, compiled out when NDEBUG is defined).
 * @param length  Number of elements to scatter.
 * @param arr_len Number of elements in `out` (the bound checked per index).
 * @param out     Destination array.
 */
template <typename DType, typename IdType>
__global__ void IndexScatterSingleKernel(
    const DType* array, const IdType* index, const int64_t length,
    const int64_t arr_len, DType* out) {
  // Previously `int`: that overflowed for length > INT_MAX and was
  // inconsistent with the 64-bit indices used by the other kernels.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    const IdType out_pos = index[tx];
    assert(out_pos >= 0 && out_pos < arr_len);
    out[out_pos] = array[tx];
    tx += stride_x;
  }
}

/**
 * @brief Scatter kernel for 2-D data: copies source row i (of `num_feat`
 *        elements) to output row index[i].
 *
 * Rows are addressed as row-major with `num_feat` contiguous elements.
 * Launch layout: 1-D grid of 2-D blocks, as in IndexSelectMultiKernel
 * (threadIdx.y -> row, threadIdx.x -> column).
 *
 * No atomics are used: duplicate entries in `index` make the final contents
 * of the affected output rows undetermined by this kernel.
 *
 * @param array    Source rows, `length` of them.
 * @param num_feat Number of elements per row.
 * @param index    `length` destination row ids, each in [0, arr_len)
 *                 (device-side assert, compiled out when NDEBUG is defined).
 * @param length   Number of rows to scatter.
 * @param arr_len  Number of rows in `out`.
 * @param out      Destination rows.
 */
template <typename DType, typename IdType>
__global__ void IndexScatterMultiKernel(
    const DType* const array, const int64_t num_feat, const IdType* const index,
    const int64_t length, const int64_t arr_len, DType* const out) {
  int64_t in_row = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  const int64_t stride = static_cast<int64_t>(blockDim.y) * gridDim.x;
  while (in_row < length) {
    const int64_t out_row = index[in_row];
    assert(out_row >= 0 && out_row < arr_len);
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      out[out_row * num_feat + col] = array[in_row * num_feat + col];
    }
    in_row += stride;
  }
}

}  // namespace impl
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_