array_scatter.hip 2.31 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
2
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
3
/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/array_scatter.cu
 * @brief Array scatter GPU implementation
 */
#include <dgl/array.h>
sangwzh's avatar
sangwzh committed
9
10
#include "../../../include/dgl/array.h"

11

12
#include "../../runtime/cuda/cuda_common.h"
sangwzh's avatar
sangwzh committed
13
#include "utils.h"
14
15
16
17
18
19
20

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

template <typename DType, typename IdType>
21
22
__global__ void _ScatterKernel(
    const IdType* index, const DType* value, int64_t length, DType* out) {
23
24
25
26
27
28
29
30
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  int stride_x = gridDim.x * blockDim.x;
  while (tx < length) {
    out[index[tx]] = value[tx];
    tx += stride_x;
  }
}

31
template <DGLDeviceType XPU, typename DType, typename IdType>
32
33
34
35
36
37
void Scatter_(IdArray index, NDArray value, NDArray out) {
  const int64_t len = index->shape[0];
  const IdType* idx = index.Ptr<IdType>();
  const DType* val = value.Ptr<DType>();
  DType* outd = out.Ptr<DType>();

sangwzh's avatar
sangwzh committed
38
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
39
40
  const int nt = cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
41
  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
42
43
}

44
45
46
// Explicit instantiations of Scatter_ for kDGLCUDA. Every listed value type
// is instantiated once per index type (int32_t and int64_t); the bfloat16
// variants are only emitted when BF16_ENABLED is set.
#define DGL_INSTANTIATE_SCATTER_(DType, IdType) \
  template void Scatter_<kDGLCUDA, DType, IdType>(IdArray, NDArray, NDArray)

// 32-bit index type.
DGL_INSTANTIATE_SCATTER_(int32_t, int32_t);
DGL_INSTANTIATE_SCATTER_(int64_t, int32_t);
DGL_INSTANTIATE_SCATTER_(__half, int32_t);
#if BF16_ENABLED
DGL_INSTANTIATE_SCATTER_(__hip_bfloat16, int32_t);
#endif  // BF16_ENABLED
DGL_INSTANTIATE_SCATTER_(float, int32_t);
DGL_INSTANTIATE_SCATTER_(double, int32_t);

// 64-bit index type.
DGL_INSTANTIATE_SCATTER_(int32_t, int64_t);
DGL_INSTANTIATE_SCATTER_(int64_t, int64_t);
DGL_INSTANTIATE_SCATTER_(__half, int64_t);
#if BF16_ENABLED
DGL_INSTANTIATE_SCATTER_(__hip_bfloat16, int64_t);
#endif  // BF16_ENABLED
DGL_INSTANTIATE_SCATTER_(float, int64_t);
DGL_INSTANTIATE_SCATTER_(double, int64_t);

#undef DGL_INSTANTIATE_SCATTER_
62
63
64
65

};  // namespace impl
};  // namespace aten
};  // namespace dgl