/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/array_index_select.cu
 * \brief Array index select GPU implementation
 */
// NOTE(review): the header name of the first include was lost when this file
// was extracted; <dgl/array.h> is presumed (NDArray / IdArray) -- confirm.
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

// NOTE(review): all angle-bracketed template parameter/argument lists in this
// file were stripped by extraction and have been reconstructed. DType/IdType
// follow from their use in the bodies; the leading `DLDeviceType XPU`
// parameter and the `kDLGPU` instantiation arguments are presumed from the
// device-dispatch convention -- confirm against the declarations in the
// corresponding header before merging.

/*!
 * \brief Gather kernel: out[i] = array[index[i]] for every i in [0, length).
 *
 * Written as a grid-stride loop, so any 1-D launch configuration covers the
 * whole range. The loop counter and the stride are 64-bit: `length` is an
 * int64_t, and a 32-bit counter would wrap once the element count (or
 * blockIdx.x * blockDim.x) no longer fits, producing out-of-bounds accesses
 * or a loop that never terminates.
 *
 * The values stored in `index` are not range-checked here.
 *
 * \param array  Source elements.
 * \param index  Positions to read from `array`; holds `length` entries.
 * \param length Number of entries in `index` and `out`.
 * \param out    Destination buffer with room for `length` elements.
 */
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(const DType* array, const IdType* index,
                                   int64_t length, DType* out) {
  // Widen before multiplying so the product is computed in 64 bits.
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       tx < length; tx += stride_x) {
    out[tx] = array[index[tx]];
  }
}

/*!
 * \brief Select the elements of `array` at the positions listed in `index`.
 *
 * The result has index->shape[0] elements and is allocated with the dtype and
 * context of `array`. An empty `index` yields an empty result without
 * launching a kernel. The kernel is enqueued on the calling thread's CUDA
 * stream; this function issues no explicit synchronization afterwards.
 *
 * \param array Source array.
 * \param index Positions to gather.
 * \return A new array holding array[index[i]] at position i.
 */
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
  const int64_t len = index->shape[0];
  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  if (len == 0) {
    return ret;
  }

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  const DType* array_data = static_cast<const DType*>(array->data);
  const IdType* idx_data = static_cast<const IdType*>(index->data);
  DType* ret_data = static_cast<DType*>(ret->data);

  // One thread per output element, rounded up to whole blocks.
  const int nt = cuda::FindNumThreads(len);
  const int nb = (len + nt - 1) / nt;
  CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, thr_entry->stream,
                   array_data, idx_data, len, ret_data);
  return ret;
}

template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);

/*!
 * \brief Fetch the single element array[index] into a host-side value.
 *
 * Copies sizeof(DType) bytes, starting `index` elements past array->data,
 * from the array's context into a local variable on the CPU context.
 * `index` is not range-checked here.
 *
 * NOTE(review): the local is returned right after CopyDataFromTo, which
 * presumes the copy has completed when called with a null stream -- confirm.
 *
 * \param array Source array.
 * \param index Element offset into `array`.
 * \return The element at that offset.
 */
template <DLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
  auto device = runtime::DeviceAPI::Get(array->ctx);
  DType ret = 0;
  device->CopyDataFromTo(
      static_cast<DType*>(array->data) + index, 0, &ret, 0,
      sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
      array->dtype, nullptr);
  return ret;
}

template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);

}  // namespace impl
}  // namespace aten
}  // namespace dgl