#include "hip/hip_runtime.h"
/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/array_scatter.cu
 * \brief Array scatter GPU implementation
 */
// NOTE(review): every angle-bracket span (include target, template parameter
// lists, template arguments) was missing from the reviewed text. They are
// reconstructed below; confirm each against the project headers.
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

/*!
 * \brief Scatter kernel: for every i in [0, length), out[index[i]] = value[i].
 *
 * Works for any 1-D launch configuration: each thread starts at its global
 * thread id and advances by the total number of launched threads until it
 * passes \a length.
 *
 * The entries of \a index are not range-checked. If \a index holds repeated
 * entries, several threads perform plain (non-atomic) stores to the same slot
 * of \a out, so which value remains there is unspecified.
 *
 * \param index  Destination positions, \a length entries.
 * \param value  Values to write, \a length entries.
 * \param length Number of (index, value) pairs.
 * \param out    Destination buffer.
 */
template <typename DType, typename IdType>
__global__ void _ScatterKernel(const IdType* index, const DType* value,
                               int64_t length, DType* out) {
  // The index and the stride are kept in 64 bits: length is an int64_t, and a
  // 32-bit counter would overflow once the running index passes 2^31 - 1.
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    out[index[tx]] = value[tx];
  }
}

/*!
 * \brief In-place scatter: writes value[i] into out[index[i]] for every i.
 *
 * The kernel is enqueued on the stream returned by
 * runtime::getCurrentCUDAStream(); this function does not synchronize it.
 *
 * \param index Positions in \a out to overwrite; its first dimension gives the
 *              number of elements to scatter.
 * \param value Values to write, read as DType.
 * \param out   Destination array, written as DType.
 *
 * NOTE(review): the name/type of the device template parameter is
 * reconstructed (the parameter list was not visible) -- confirm against the
 * declaration of Scatter_ in the project headers.
 */
template <DLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
  const int64_t len = index->shape[0];
  // Nothing to scatter; also avoids requesting a launch with zero blocks.
  if (len == 0) return;

  const IdType* idx = index.Ptr<IdType>();
  const DType* val = value.Ptr<DType>();
  DType* outd = out.Ptr<DType>();

  hipStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = cuda::FindNumThreads(len);
  // Ceil-divide so that the launched threads cover all len elements.
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_ScatterKernel, nb, nt, 0, stream, idx, val, len, outd);
}

// Explicit instantiations.
// NOTE(review): the template argument lists were not visible in the reviewed
// text. The device enumerator and the (DType, IdType) combinations below are
// reconstructed from the number and grouping of the instantiation lines --
// verify them before merging.
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);

};  // namespace impl
};  // namespace aten
};  // namespace dgl