/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/array_cumsum.cu
 * \brief Array cumsum GPU implementation
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

template <DLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
  const int64_t len = array.NumElements();
  if (len == 0)
    return !prepend_zero ? array
                         : aten::Full(0, 1, array->dtype.bits, array->ctx);

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  auto device = runtime::DeviceAPI::Get(array->ctx);
  const IdType* in_d = array.Ptr<IdType>();
  IdArray ret;
  IdType* out_d = nullptr;
  if (prepend_zero) {
    // Output has one extra leading element fixed to zero; the inclusive scan
    // writes into the remaining len positions.
    ret = aten::Full(0, len + 1, array->dtype.bits, array->ctx);
    out_d = ret.Ptr<IdType>() + 1;
  } else {
    ret = aten::NewIdArray(len, array->ctx, array->dtype.bits);
    out_d = ret.Ptr<IdType>();
  }
  // Allocate workspace: the first call with a null workspace pointer only
  // queries the temporary storage size required by CUB.
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceScan::InclusiveSum(
      nullptr, workspace_size, in_d, out_d, len, thr_entry->stream));
  void* workspace = device->AllocWorkspace(array->ctx, workspace_size);
  // Compute cumsum
  CUDA_CALL(cub::DeviceScan::InclusiveSum(
      workspace, workspace_size, in_d, out_d, len, thr_entry->stream));
  device->FreeWorkspace(array->ctx, workspace);

  return ret;
}

template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);

}  // namespace impl
}  // namespace aten
}  // namespace dgl
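
// Usage sketch (illustrative only, not part of the upstream file). Assuming
// the public aten::CumSum dispatcher declared in <dgl/array.h> routes arrays
// on a CUDA context to the kernel above, a degree array can be turned into
// CSR-style offsets:
//
//   IdArray degrees = ...;  // e.g. an int64 IdArray on a GPU context
//   IdArray indptr = aten::CumSum(degrees, /*prepend_zero=*/true);
//   // indptr[0] == 0 and indptr[i + 1] == degrees[0] + ... + degrees[i],
//   // i.e. the indptr layout of a CSR matrix built from per-row degrees.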