/*!
 *  Copyright (c) 2019-2022 by Contributors
 * \file array/cuda/uvm/array_index_select_uvm.cu
 * \brief Array index select GPU implementation
 */
#include <dgl/array.h>
#include "../../../runtime/cuda/cuda_common.h"
#include "../array_index_select.cuh"
#include "./array_index_select_uvm.cuh"
#include "../utils.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

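// Gathers rows from a pinned (page-locked) host array into a newly allocated
// array on the GPU context of `index`. The host buffer is read directly from
// the device (UVM / zero-copy access), so `array` must be pinned and `index`
// must reside on the GPU.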
template<typename DType, typename IdType>
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const DType* array_data = static_cast<DType*>(array->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = array->shape[0];
  const int64_t len = index->shape[0];
  int64_t num_feat = 1;
  std::vector<int64_t> shape{len};

  CHECK(array.IsPinned());
  CHECK_EQ(index->ctx.device_type, kDLROCM);

  for (int d = 1; d < array->ndim; ++d) {
    num_feat *= array->shape[d];
    shape.emplace_back(array->shape[d]);
  }

  NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
  if (len == 0)
    return ret;
  DType* ret_data = static_cast<DType*>(ret->data);

  if (num_feat == 1) {
      const int nt = cuda::FindNumThreads(len);
      const int nb = (len + nt - 1) / nt;
      CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0,
          stream, array_data, idx_data, len, arr_len, ret_data);
  } else {
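      // Shape a 256-thread block: halve block.x (feature lanes) and double
      // block.y (rows per block) until block.x drops below 2 * num_feat, so
      // the x-dimension roughly matches the feature width.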
      dim3 block(256, 1);
      while (static_cast<int64_t>(block.x) >= 2*num_feat) {
          block.x /= 2;
          block.y *= 2;
      }
      const dim3 grid((len+block.y-1)/block.y);
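      // Short rows use the plain multi-element kernel; rows spanning at least
      // two cache lines go through the Aligned variant, which is intended to
      // keep reads from the pinned host buffer cache-line aligned.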
      if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
        CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0,
            stream, array_data, num_feat, idx_data,
            len, arr_len, ret_data);
      } else {
        CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0,
            stream, array_data, num_feat, idx_data,
            len, arr_len, ret_data);
      }
  }
  return ret;
}

// Floating-point types are treated as their equal-width integer types.
template NDArray IndexSelectCPUFromGPU<int8_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int8_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int16_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int16_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
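
// Illustrative only (not part of the build): a hypothetical caller holding a
// pinned float32 host feature matrix `feat` and an int64 GPU index `rows`
// would dispatch on the equal-width integer type, e.g.
//
//   NDArray picked = IndexSelectCPUFromGPU<int32_t, int64_t>(feat, rows);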


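// Scatters rows of the GPU-resident `source` array into the pinned host
// `dest` array at the positions given by the GPU `index` array, again
// addressing the host memory directly from the device.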
template<typename DType, typename IdType>
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
  hipStream_t stream = runtime::getCurrentCUDAStream();
  DType* dest_data = static_cast<DType*>(dest->data);
  const DType* source_data = static_cast<DType*>(source->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = dest->shape[0];
  const int64_t len = index->shape[0];
  int64_t num_feat = 1;
  std::vector<int64_t> shape{len};

  CHECK(dest.IsPinned());
  CHECK_EQ(index->ctx.device_type, kDLROCM);
  CHECK_EQ(source->ctx.device_type, kDLROCM);

  for (int d = 1; d < source->ndim; ++d) {
    num_feat *= source->shape[d];
  }

  if (len == 0)
    return;

  if (num_feat == 1) {
      const int nt = cuda::FindNumThreads(len);
      const int nb = (len + nt - 1) / nt;
      CUDA_KERNEL_CALL(IndexScatterSingleKernel, nb, nt, 0,
          stream, source_data, idx_data, len, arr_len, dest_data);
  } else {
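      // Same block-shaping scheme as in IndexSelectCPUFromGPU above.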
      dim3 block(256, 1);
      while (static_cast<int64_t>(block.x) >= 2*num_feat) {
          block.x /= 2;
          block.y *= 2;
      }
      const dim3 grid((len+block.y-1)/block.y);
      CUDA_KERNEL_CALL(IndexScatterMultiKernel, grid, block, 0,
          stream, source_data, num_feat, idx_data,
          len, arr_len, dest_data);
  }
}

// Floating-point types are treated as their equal-width integer types.
template void IndexScatterGPUToCPU<int8_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int8_t, int64_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int16_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int16_t, int64_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int32_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int32_t, int64_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int64_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int64_t, int64_t>(NDArray, IdArray, NDArray);

}  // namespace impl
}  // namespace aten
}  // namespace dgl