/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cuda/array_scatter.cu
 * @brief Array scatter GPU implementation
 */
#include <dgl/array.h>
7

8
9
10
11
12
13
14
15
16
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

/**
 * @brief Scatter kernel: out[index[i]] = value[i] for every i in [0, length).
 *
 * Grid-stride loop: each thread starts at its global thread id and advances by
 * the total number of launched threads (gridDim.x * blockDim.x). Any 1-D
 * grid/block configuration with at least one thread therefore visits every
 * element.
 *
 * @param index  `length` destination offsets into `out`. They are not
 *               bounds-checked, so the caller must guarantee they are valid.
 * @param value  `length` source values.
 * @param length Number of (index, value) pairs to scatter.
 * @param out    Destination buffer, written in place.
 *
 * @note The thread index and stride are 64-bit. With 32-bit `int` arithmetic,
 *       inputs longer than 2^31 elements would wrap to negative offsets
 *       (out-of-bounds access) and overflow the loop counter.
 * @note Writes are plain stores, not atomics. If `index` contains duplicates,
 *       this kernel does not define which value survives.
 */
template <typename DType, typename IdType>
__global__ void _ScatterKernel(
    const IdType* index, const DType* value, int64_t length, DType* out) {
  // Promote to 64-bit before multiplying; blockIdx.x * blockDim.x alone is
  // computed in 32-bit unsigned arithmetic.
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    out[index[tx]] = value[tx];
    tx += stride_x;
  }
}

27
/**
 * @brief GPU scatter: out[index[i]] = value[i] for i in [0, index->shape[0]).
 *
 * Enqueues `_ScatterKernel` on `runtime::getCurrentCUDAStream()`. This
 * function itself performs no explicit stream synchronization.
 *
 * @param index Destination offsets into `out`. Only its first dimension
 *              (`shape[0]`) is read to decide how many elements to scatter.
 * @param value Source values. Presumably holds at least `index->shape[0]`
 *              elements of DType; this is not checked here.
 * @param out   Destination array, modified in place.
 */
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
  const int64_t len = index->shape[0];
  // Nothing to scatter. Returning early also avoids computing a zero-block
  // grid below, which CUDA rejects as an invalid launch configuration.
  if (len == 0) return;

  const IdType* idx = index.Ptr<IdType>();
  const DType* val = value.Ptr<DType>();
  DType* outd = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = cuda::FindNumThreads(len);
  // Ceil-divide so every element is covered by at least one thread.
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
}

40
41
42
// Explicit instantiations of Scatter_ for the CUDA device. Every supported
// value dtype is instantiated once with 32-bit and once with 64-bit indices.
// The bfloat16 variants exist only when BF16_ENABLED is set.
#define DGL_SCATTER_CUDA_INSTANTIATE(DType, IdType) \
  template void Scatter_<kDGLCUDA, DType, IdType>(IdArray, NDArray, NDArray)

// 32-bit indices.
DGL_SCATTER_CUDA_INSTANTIATE(int32_t, int32_t);
DGL_SCATTER_CUDA_INSTANTIATE(int64_t, int32_t);
DGL_SCATTER_CUDA_INSTANTIATE(__half, int32_t);
#if BF16_ENABLED
DGL_SCATTER_CUDA_INSTANTIATE(__nv_bfloat16, int32_t);
#endif  // BF16_ENABLED
DGL_SCATTER_CUDA_INSTANTIATE(float, int32_t);
DGL_SCATTER_CUDA_INSTANTIATE(double, int32_t);

// 64-bit indices.
DGL_SCATTER_CUDA_INSTANTIATE(int32_t, int64_t);
DGL_SCATTER_CUDA_INSTANTIATE(int64_t, int64_t);
DGL_SCATTER_CUDA_INSTANTIATE(__half, int64_t);
#if BF16_ENABLED
DGL_SCATTER_CUDA_INSTANTIATE(__nv_bfloat16, int64_t);
#endif  // BF16_ENABLED
DGL_SCATTER_CUDA_INSTANTIATE(float, int64_t);
DGL_SCATTER_CUDA_INSTANTIATE(double, int64_t);

#undef DGL_SCATTER_CUDA_INSTANTIATE
58
59
60
61

};  // namespace impl
};  // namespace aten
};  // namespace dgl