/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/array_sort.cu
 * \brief Array sort GPU implementation
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {

using runtime::NDArray;

namespace aten {
namespace impl {

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
  const auto& ctx = array->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  const int64_t nitems = array->shape[0];
  IdArray orig_idx = Range(0, nitems, 64, ctx);
  IdArray sorted_array = NewIdArray(nitems, ctx, array->dtype.bits);
  IdArray sorted_idx = NewIdArray(nitems, ctx, 64);

  const IdType* keys_in = array.Ptr<IdType>();
  const int64_t* values_in = orig_idx.Ptr<int64_t>();
  IdType* keys_out = sorted_array.Ptr<IdType>();
  int64_t* values_out = sorted_idx.Ptr<int64_t>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Sort on the full key width when the caller does not restrict the bit range.
  if (num_bits == 0) {
    num_bits = sizeof(IdType) * 8;
  }

  // Allocate workspace: the first call with a null buffer only queries the
  // temporary storage size required by CUB.
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      nullptr, workspace_size, keys_in, keys_out, values_in, values_out,
      nitems, 0, num_bits, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Compute: radix-sort the keys and carry the original indices along as values.
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      workspace, workspace_size, keys_in, keys_out, values_in, values_out,
      nitems, 0, num_bits, stream));

  device->FreeWorkspace(ctx, workspace);

  return std::make_pair(sorted_array, sorted_idx);
}

template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(
    IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(
    IdArray, int num_bits);

}  // namespace impl
}  // namespace aten
}  // namespace dgl
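
#if 0
// Usage sketch (illustrative only, excluded from compilation): callers normally
// reach the Sort<kDGLCUDA, IdType>() instantiations above through the generic
// aten::Sort() entry point declared in <dgl/array.h>. The function name
// `Example` and the variable names below are hypothetical; the dispatch details
// are an assumption based on DGL's usual device/dtype switch pattern.
void Example(dgl::IdArray gpu_ids) {
  // Sort the array and also obtain the permutation of original positions.
  // num_bits == 0 means the full key width is used for the radix sort.
  auto ret = dgl::aten::Sort(gpu_ids, /*num_bits=*/0);
  dgl::IdArray sorted = ret.first;   // keys in ascending order
  dgl::IdArray order = ret.second;   // int64 indices: sorted[i] == gpu_ids[order[i]]
}
#endif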