/*!
 *  Copyright (c) 2019 by Contributors
 * \file array/cuda/array_op_impl.cu
 * \brief Array operator GPU implementation
 */
#include <vector>

#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

/*!
 * \brief Pick a per-block thread count for a launch over \p dim elements.
 *
 * Repeatedly halves \p max_nthrs until the value is <= \p dim, and never goes
 * below 1. For example, FindNumThreads(300, 1024) == 256.
 *
 * \param dim Number of elements to be processed.
 * \param max_nthrs Starting (maximum) thread count.
 * \return max_nthrs >> k for the smallest k with result <= dim, or 1 if that
 *         would require going below 1 (e.g. dim <= 0).
 */
int FindNumThreads(int dim, int max_nthrs) {
  int ret = max_nthrs;
  // Stop at 1. Without this bound, dim == 0 yields a 0-thread launch (and a
  // division by zero in the callers' block-count formula), and a negative dim
  // loops forever because 0 >> 1 == 0 is still > dim.
  while (ret > 1 && ret > dim) {
    ret >>= 1;
  }
  return ret;
}

namespace {

/*! \brief Upper bound on threads per block for the launches in this file. */
constexpr int kArrayOpMaxThreads = 1024;

/*!
 * \brief Compute a 1-D launch configuration covering \p length elements.
 *
 * The block count is computed in int64_t so that lengths near the top of the
 * int32 range do not overflow before the division.
 *
 * \param length Number of elements; callers return early when it is 0.
 * \param nb Output: number of blocks, ceil(length / *nt).
 * \param nt Output: threads per block, chosen by FindNumThreads.
 */
void ComputeLaunchConfig1D(int64_t length, int* nb, int* nt) {
  // Clamp before narrowing to int. FindNumThreads only compares against its
  // first argument, so any value >= kArrayOpMaxThreads gives the same result,
  // and this avoids a negative dim when length > INT_MAX.
  const int clamped = length < kArrayOpMaxThreads
      ? static_cast<int>(length) : kArrayOpMaxThreads;
  *nt = FindNumThreads(clamped, kArrayOpMaxThreads);
  *nb = static_cast<int>((length + *nt - 1) / *nt);
}

/*!
 * \brief Fail via CHECK if the most recent kernel launch reported an error.
 * \param kernel_name Name used in the error message.
 */
void CheckArrayOpLaunch(const char* kernel_name) {
  const cudaError_t err = cudaGetLastError();
  CHECK(err == cudaSuccess) << "Failed to launch " << kernel_name << ": "
                            << cudaGetErrorString(err);
}

}  // namespace

///////////////////////////// Range /////////////////////////////

/*!
 * \brief Write out[i] = low + i for i in [0, length).
 *
 * Works with any 1-D launch configuration. Each thread starts at its global
 * index and advances by gridDim.x * blockDim.x. Indices are 64-bit so that
 * more than 2^31 elements do not overflow.
 */
template <typename IdType>
__global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride_x = static_cast<int64_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    out[tx] = low + static_cast<IdType>(tx);
    tx += stride_x;
  }
}

/*!
 * \brief Create an id array holding low, low + 1, ..., high - 1 on \p ctx.
 *
 * The kernel is enqueued on the calling thread's CUDA stream
 * (CUDAThreadEntry::ThreadLocal()->stream); this function does not synchronize.
 *
 * \param low First value (inclusive).
 * \param high Last value (exclusive); must satisfy high >= low.
 * \param ctx Device context of the result.
 * \return An array of length high - low with sizeof(IdType) * 8 bits per element.
 */
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
  CHECK(high >= low) << "high must be bigger than low";
  const IdType length = high - low;
  IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
  if (length == 0)
    return ret;
  IdType* ret_data = static_cast<IdType*>(ret->data);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int nb = 0, nt = 0;
  ComputeLaunchConfig1D(length, &nb, &nt);
  _RangeKernel<IdType><<<nb, nt, 0, thr_entry->stream>>>(ret_data, low, length);
  CheckArrayOpLaunch("_RangeKernel");
  return ret;
}

template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);

///////////////////////////// AsNumBits /////////////////////////////

/*!
 * \brief Element-wise conversion out[i] = in[i] for i in [0, length).
 *
 * Works with any 1-D launch configuration, using the same strided walk as
 * _RangeKernel. `in` and `out` must not overlap.
 */
template <typename InType, typename OutType>
__global__ void _CastKernel(const InType* __restrict__ in,
                            OutType* __restrict__ out, size_t length) {
  size_t tx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride_x = static_cast<size_t>(gridDim.x) * blockDim.x;
  while (tx < length) {
    out[tx] = in[tx];
    tx += stride_x;
  }
}

/*!
 * \brief Return a copy of \p arr converted to a \p bits-wide signed integer array.
 *
 * The result has the same shape and device context as \p arr. The cast kernel
 * is enqueued on the calling thread's CUDA stream; this function does not
 * synchronize.
 *
 * \param arr Input array, read as elements of IdType.
 * \param bits Target width; must be 32 or 64.
 * \return A new array with dtype {kDLInt, bits, 1}.
 */
template <DLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  // Any other width would allocate a narrower buffer than the int64_t writes
  // the fallback branch performs, i.e. an out-of-bounds device write.
  CHECK(bits == 32 || bits == 64)
      << "AsNumBits only supports 32 or 64 bits, but got "
      << static_cast<int>(bits);
  const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
  IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
  const int64_t length = ret.NumElements();
  // Nothing to convert; also avoids a 0-thread launch configuration.
  if (length == 0)
    return ret;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int nb = 0, nt = 0;
  ComputeLaunchConfig1D(length, &nb, &nt);
  const IdType* in_data = static_cast<const IdType*>(arr->data);
  if (bits == 32) {
    _CastKernel<IdType, int32_t><<<nb, nt, 0, thr_entry->stream>>>(
        in_data, static_cast<int32_t*>(ret->data), length);
  } else {
    _CastKernel<IdType, int64_t><<<nb, nt, 0, thr_entry->stream>>>(
        in_data, static_cast<int64_t*>(ret->data), length);
  }
  CheckArrayOpLaunch("_CastKernel");
  return ret;
}

template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);

}  // namespace impl
}  // namespace aten
}  // namespace dgl