/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/array_index_select.cu
 * \brief Array index select GPU implementation
 */
#include <dgl/array.h>
7

8
#include "../../runtime/cuda/cuda_common.h"
9
#include "./array_index_select.cuh"
10
#include "./utils.h"
11
12
13
14
15
16

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

17
/*!
 * \brief Gather rows of \p array selected by \p index on the GPU.
 *
 * The result has shape (len(index), array->shape[1], ...,
 * array->shape[ndim - 1]); each output row i is a copy of input row
 * index[i]. The kernels are launched on the current CUDA stream and no
 * synchronization is performed here.
 *
 * \param array Source array. One "row" is the num_feat contiguous elements
 *        spanned by dimensions 1..ndim-1.
 * \param index Row ids to gather; its data pointer is passed to the kernel.
 * \return Newly allocated array on index->ctx holding the gathered rows.
 */
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType* array_data = static_cast<DType*>(array->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = array->shape[0];
  const int64_t len = index->shape[0];

  // Output shape is (len, array->shape[1:]...); num_feat is the number of
  // elements in a single row.
  int64_t num_feat = 1;
  std::vector<int64_t> shape{len};
  for (int d = 1; d < array->ndim; ++d) {
    num_feat *= array->shape[d];
    shape.emplace_back(array->shape[d]);
  }

  // use index->ctx for pinned array
  NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
  // Nothing to copy if no rows are selected or every row is empty. The
  // num_feat == 0 case must bail out here: otherwise the block-shaping loop
  // below would test `block.x >= 0`, which never becomes false.
  if (len == 0 || num_feat == 0) return ret;
  DType* ret_data = static_cast<DType*>(ret->data);

  if (num_feat == 1) {
    // One element per row: 1-D launch over the index entries.
    const int nt = cuda::FindNumThreads(len);
    const int nb = (len + nt - 1) / nt;
    CUDA_KERNEL_CALL(
        IndexSelectSingleKernel, nb, nt, 0, stream, array_data, idx_data, len,
        arr_len, ret_data);
  } else {
    // Start with 256 threads along x (feature axis). While x is at least
    // twice the row width, halve x and double y (row axis), so narrow rows
    // are spread across more rows per block instead of idling x-threads.
    dim3 block(256, 1);
    while (static_cast<int64_t>(block.x) >= 2 * num_feat) {
      block.x /= 2;
      block.y *= 2;
    }
    const dim3 grid((len + block.y - 1) / block.y);
    CUDA_KERNEL_CALL(
        IndexSelectMultiKernel, grid, block, 0, stream, array_data, num_feat,
        idx_data, len, arr_len, ret_data);
  }

  return ret;
}

56
57
58
59
// Explicit CUDA instantiations of the row-gather IndexSelect.
// 32-bit row ids:
template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
// 64-bit row ids:
template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
// Half precision is only available when built with USE_FP16.
#ifdef USE_FP16
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#endif
68

69
/*!
 * \brief Copy one element of \p array into host memory and return it.
 *
 * The element at offset \p index (counted in DType units from
 * array->data) is copied with the DeviceAPI of array->ctx into a local
 * variable on the CPU context.
 *
 * \param array Array to read from.
 * \param index Element offset from array->data.
 * \return The value stored at that offset.
 */
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
#ifdef USE_FP16
  // __half's initializing constructor may be device-only on some setups,
  // and this function runs on the host, so stage a half value through a
  // same-sized uint16_t instead.
  using HostScratch = typename std::conditional<
      std::is_same<DType, __half>::value, uint16_t, DType>::type;
#else
  using HostScratch = DType;
#endif
  HostScratch value = 0;

  DType* source = static_cast<DType*>(array->data) + index;
  const DGLContext host_ctx{kDGLCPU, 0};
  runtime::DeviceAPI::Get(array->ctx)->CopyDataFromTo(
      source, 0, reinterpret_cast<DType*>(&value), 0, sizeof(DType),
      array->ctx, host_ctx, array->dtype);
  return reinterpret_cast<DType&>(value);
}

89
90
91
92
// Explicit CUDA instantiations of the single-element IndexSelect.
// Signed integers:
template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray, int64_t);
template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray, int64_t);
// Unsigned integers:
template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray, int64_t);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray, int64_t);
// Floating point:
template float IndexSelect<kDGLCUDA, float>(NDArray, int64_t);
template double IndexSelect<kDGLCUDA, double>(NDArray, int64_t);
// Half precision, only when built with USE_FP16:
#ifdef USE_FP16
template __half IndexSelect<kDGLCUDA, __half>(NDArray, int64_t);
#endif
98
99
100
101

}  // namespace impl
}  // namespace aten
}  // namespace dgl