"vscode:/vscode.git/clone" did not exist on "7a587349943325e667866971a36996a56fcff143"
array_index_select_uvm.hip 4.99 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2019-2022 by Contributors
 * @file array/cuda/uvm/array_index_select_uvm.cu
 * @brief Array index select GPU implementation
 */
#include <dgl/array.h>

#include "../../../runtime/cuda/cuda_common.h"
#include "../array_index_select.cuh"
#include "../utils.h"
#include "array_index_select_uvm.cuh"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

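// Gathers rows of a pinned host array into a new array allocated on the GPU,
// driven by an index array that resides on the GPU. The pinned buffer is read
// through its device-visible (zero-copy) address.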
template <typename DType, typename IdType>
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const int64_t arr_len = array->shape[0];
  const int64_t len = index->shape[0];
  int64_t num_feat = 1;
  std::vector<int64_t> shape{len};

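  // The source must be pinned so the GPU can dereference its host memory.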
  CHECK(array.IsPinned());
  const DType* array_data = static_cast<DType*>(cuda::GetDevicePointer(array));
  CHECK_EQ(index->ctx.device_type, kDGLCUDA);

  for (int d = 1; d < array->ndim; ++d) {
    num_feat *= array->shape[d];
    shape.emplace_back(array->shape[d]);
  }

  NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
  if (len == 0 || arr_len * num_feat == 0) return ret;
  DType* ret_data = static_cast<DType*>(ret->data);

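  // Sort the indices first (keyed on just enough bits to cover arr_len), so
  // that consecutive threads read neighboring rows of the pinned array; the
  // returned permutation maps each sorted index back to its output slot.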
  auto res = Sort(index, cuda::_NumberOfBits(arr_len));
  const IdType* idx_data = static_cast<IdType*>(res.first->data);
  const int64_t* perm_data = static_cast<int64_t*>(res.second->data);

  if (num_feat == 1) {
    const int nt = cuda::FindNumThreads(len);
    const int nb = (len + nt - 1) / nt;
    CUDA_KERNEL_CALL(
        IndexSelectSingleKernel, nb, nt, 0, stream, array_data, idx_data, len,
        arr_len, ret_data, perm_data);
  } else {
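    // Start with 256 threads across the feature dimension, then trade threads
    // from x (features) to y (rows) while a half-width block would still span
    // an entire feature vector.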
    dim3 block(256, 1);
    while (static_cast<int64_t>(block.x) >= 2 * num_feat) {
      block.x /= 2;
      block.y *= 2;
    }
    const dim3 grid((len + block.y - 1) / block.y);
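    // Rows narrower than two cache lines use the plain kernel; wider rows use
    // the variant that keeps its accesses aligned to cache-line boundaries.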
    if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
      CUDA_KERNEL_CALL(
          IndexSelectMultiKernel, grid, block, 0, stream, array_data, num_feat,
          idx_data, len, arr_len, ret_data, perm_data);
    } else {
      CUDA_KERNEL_CALL(
          IndexSelectMultiKernelAligned, grid, block, 0, stream, array_data,
          num_feat, idx_data, len, arr_len, ret_data, perm_data);
    }
  }
  return ret;
}

// floating point types are treated as their equal width integer types
template NDArray IndexSelectCPUFromGPU<int8_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int8_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int16_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int16_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);

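// Scatters rows from a GPU-resident source into a pinned host destination at
// the positions given by a GPU index array; the inverse of the gather above.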
template <typename DType, typename IdType>
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  const DType* source_data = static_cast<DType*>(source->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = dest->shape[0];
  const int64_t len = index->shape[0];
  int64_t num_feat = 1;
  std::vector<int64_t> shape{len};

  CHECK(dest.IsPinned());
  DType* dest_data = static_cast<DType*>(cuda::GetDevicePointer(dest));
  CHECK_EQ(index->ctx.device_type, kDGLCUDA);
  CHECK_EQ(source->ctx.device_type, kDGLCUDA);

  for (int d = 1; d < source->ndim; ++d) {
    num_feat *= source->shape[d];
  }

  if (len == 0) return;

  if (num_feat == 1) {
    const int nt = cuda::FindNumThreads(len);
    const int nb = (len + nt - 1) / nt;
    CUDA_KERNEL_CALL(
        IndexScatterSingleKernel, nb, nt, 0, stream, source_data, idx_data, len,
        arr_len, dest_data);
  } else {
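    // Same block-shape heuristic as the gather path: shrink the feature
    // dimension of the block while doubling the number of rows it covers.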
    dim3 block(256, 1);
    while (static_cast<int64_t>(block.x) >= 2 * num_feat) {
      block.x /= 2;
      block.y *= 2;
    }
    const dim3 grid((len + block.y - 1) / block.y);
    CUDA_KERNEL_CALL(
        IndexScatterMultiKernel, grid, block, 0, stream, source_data, num_feat,
        idx_data, len, arr_len, dest_data);
  }
}

// floating point types are treated as their equal width integer types
template void IndexScatterGPUToCPU<int8_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int8_t, int64_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int16_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int16_t, int64_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int32_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int32_t, int64_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int64_t, int32_t>(NDArray, IdArray, NDArray);
template void IndexScatterGPUToCPU<int64_t, int64_t>(NDArray, IdArray, NDArray);

}  // namespace impl
}  // namespace aten
}  // namespace dgl