#include "hip/hip_runtime.h" /*! * Copyright (c) 2021 by Contributors * \file ndarray_partition.h * \brief Operations on partition implemented in CUDA. */ #include "../partition_op.h" #include #include "../../array/cuda/dgl_cub.cuh" #include "../../runtime/cuda/cuda_common.h" #include "../../runtime/workspace.h" using namespace dgl::runtime; namespace dgl { namespace partition { namespace impl { namespace { /** * @brief Kernel to map global element IDs to partition IDs by remainder. * * @tparam IdType The type of ID. * @param global The global element IDs. * @param num_elements The number of element IDs. * @param num_parts The number of partitions. * @param part_id The mapped partition ID (outupt). */ template __global__ void _MapProcByRemainderKernel( const IdType * const global, const int64_t num_elements, const int64_t num_parts, IdType * const part_id) { assert(num_elements <= gridDim.x*blockDim.x); const int64_t idx = blockDim.x*static_cast(blockIdx.x)+threadIdx.x; if (idx < num_elements) { part_id[idx] = global[idx] % num_parts; } } /** * @brief Kernel to map global element IDs to partition IDs, using a bit-mask. * The number of partitions must be a power a two. * * @tparam IdType The type of ID. * @param global The global element IDs. * @param num_elements The number of element IDs. * @param mask The bit-mask with 1's for each bit to keep from the element ID to * extract the partition ID (e.g., an 8 partition mask would be 0x07). * @param part_id The mapped partition ID (outupt). */ template __global__ void _MapProcByMaskRemainderKernel( const IdType * const global, const int64_t num_elements, const IdType mask, IdType * const part_id) { assert(num_elements <= gridDim.x*blockDim.x); const int64_t idx = blockDim.x*static_cast(blockIdx.x)+threadIdx.x; if (idx < num_elements) { part_id[idx] = global[idx] & mask; } } /** * @brief Kernel to map global element IDs to local element IDs. * * @tparam IdType The type of ID. * @param global The global element IDs. * @param num_elements The number of IDs. * @param num_parts The number of partitions. * @param local The local element IDs (output). */ template __global__ void _MapLocalIndexByRemainderKernel( const IdType * const global, const int64_t num_elements, const int num_parts, IdType * const local) { assert(num_elements <= gridDim.x*blockDim.x); const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x; if (idx < num_elements) { local[idx] = global[idx] / num_parts; } } /** * @brief Kernel to map local element IDs within a partition to their global * IDs, using the remainder over the number of partitions. * * @tparam IdType The type of ID. * @param local The local element IDs. * @param part_id The partition to map local elements from. * @param num_elements The number of elements to map. * @param num_parts The number of partitions. * @param global The global element IDs (output). */ template __global__ void _MapGlobalIndexByRemainderKernel( const IdType * const local, const int part_id, const int64_t num_elements, const int num_parts, IdType * const global) { assert(num_elements <= gridDim.x*blockDim.x); const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x; assert(part_id < num_parts); if (idx < num_elements) { global[idx] = (local[idx] * num_parts) + part_id; } } /** * @brief Device function to perform a binary search to find to which partition a * given ID belongs. * * @tparam RangeType The type of range. * @param range The prefix-sum of IDs assigned to partitions. * @param num_parts The number of partitions. 
/**
 * @brief Device function to perform a binary search to find to which partition
 * a given ID belongs.
 *
 * @tparam RangeType The type of range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param num_parts The number of partitions.
 * @param target The element ID to find the partition of.
 *
 * @return The partition.
 */
template <typename RangeType>
__device__ RangeType _SearchRange(
    const RangeType * const range,
    const int num_parts,
    const RangeType target) {
  int start = 0;
  int end = num_parts;
  int cur = (end+start)/2;

  assert(range[0] == 0);
  assert(target < range[num_parts]);

  while (start+1 < end) {
    if (target < range[cur]) {
      end = cur;
    } else {
      start = cur;
    }
    cur = (start+end)/2;
  }

  return cur;
}

/**
 * @brief Kernel to map element IDs to partition IDs.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The partition ID assigned to each element (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel(
    const RangeType * const range,
    const IdType * const global,
    const int64_t num_elements,
    const int64_t num_parts,
    IdType * const part_id) {
  assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    part_id[idx] = static_cast<IdType>(_SearchRange(
        range,
        static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx])));
  }
}

/**
 * @brief Kernel to map global element IDs to their ID within their respective
 * partition.
 *
 * @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param global The global element IDs.
 * @param num_elements The number of elements.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel(
    const RangeType * const range,
    const IdType * const global,
    const int64_t num_elements,
    const int num_parts,
    IdType * const local) {
  assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    const int proc = _SearchRange(
        range,
        static_cast<int>(num_parts),
        static_cast<RangeType>(global[idx]));
    local[idx] = global[idx] - range[proc];
  }
}
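// Worked example of the range-based mapping: given the prefix-sum
// range = {0, 5, 9, 12} for 3 partitions, _SearchRange locates the interval
// containing a global ID, so ID 7 falls in partition 1 (since 5 <= 7 < 9) and
// its local ID is 7 - range[1] = 2.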
/**
 * @brief Kernel to map local element IDs within a partition to their global
 * IDs.
 *
 * @tparam IdType The type of ID.
 * @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
 * @param local The local element IDs.
 * @param part_id The partition to map local elements from.
 * @param num_elements The number of elements to map.
 * @param num_parts The number of partitions.
 * @param global The global element IDs (output).
 */
template <typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel(
    const RangeType * const range,
    const IdType * const local,
    const int part_id,
    const int64_t num_elements,
    const int num_parts,
    IdType * const global) {
  assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;

  assert(part_id < num_parts);

  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    global[idx] = local[idx] + range[part_id];
  }
}

}  // namespace

// Remainder Based Partition Operations

template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder(
        int64_t array_size,
        int num_parts,
        IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  hipStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
      ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
  int64_t * out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in+block.x-1)/block.x);

    if (num_parts < (1 << part_bits)) {
      // num_parts is not a power of 2
      CUDA_KERNEL_CALL(_MapProcByRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data),
          num_in,
          num_parts,
          proc_id_in.get());
    } else {
      // num_parts is a power of 2
      CUDA_KERNEL_CALL(_MapProcByMaskRemainderKernel, grid, block, 0, stream,
          static_cast<const IdType*>(in_idx->data),
          num_in,
          static_cast<IdType>(num_parts-1),  // bit mask
          proc_id_in.get());
    }
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType * perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size,
        proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(sort_workspace.get(),
        sort_workspace_size,
        proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in hipcub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
        "number of values to insert into histogram must be less than max "
        "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        nullptr,
        hist_workspace_size,
        proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts),
        num_parts+1,
        static_cast<IdType>(0),
        static_cast<IdType>(num_parts+1),
        static_cast<int>(num_in),
        stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        hist_workspace.get(),
        hist_workspace_size,
        proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts),
        num_parts+1,
        static_cast<IdType>(0),
        static_cast<IdType>(num_parts+1),
        static_cast<int>(num_in),
        stream));
  }

  return result;
}
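// Worked example of the permutation pipeline above: for in_idx = {7, 3, 6, 2}
// and num_parts = 2, the remainder kernel produces partition IDs {1, 1, 0, 0};
// stably radix-sorting the index sequence {0, 1, 2, 3} by those IDs (using
// only the lowest part_bits bits as keys) yields the permutation {2, 3, 0, 1},
// and the histogram over the sorted partition IDs gives the per-partition
// counts {2, 2}.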
template std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder<kDGLCUDA, int32_t>(
        int64_t array_size,
        int num_parts,
        IdArray in_idx);
template std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder<kDGLCUDA, int64_t>(
        int64_t array_size,
        int num_parts,
        IdArray in_idx);


template <DGLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
    const int num_parts,
    IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1) {
    IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
        sizeof(IdType)*8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRemainderKernel,
        grid, block, 0, stream,
        static_cast<const IdType*>(global_idx->data),
        global_idx->shape[0],
        num_parts,
        static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}

template IdArray
MapToLocalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts,
    IdArray in_idx);
template IdArray
MapToLocalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts,
    IdArray in_idx);

template <DGLDeviceType XPU, typename IdType>
IdArray MapToGlobalFromRemainder(
    const int num_parts,
    IdArray local_idx,
    const int part_id) {
  CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
      "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
      "/" << num_parts;

  const auto& ctx = local_idx->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1) {
    IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
        sizeof(IdType)*8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRemainderKernel,
        grid, block, 0, stream,
        static_cast<const IdType*>(local_idx->data),
        part_id,
        global_idx->shape[0],
        num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}

template IdArray
MapToGlobalFromRemainder<kDGLCUDA, int32_t>(
    int num_parts,
    IdArray in_idx,
    int part_id);
template IdArray
MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
    int num_parts,
    IdArray in_idx,
    int part_id);
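// Illustrative host-side round trip (a sketch only, not part of this file's
// API surface; the <kDGLCUDA, int64_t> instantiation is assumed): given a
// device IdArray `global_idx` and `num_parts` partitions,
//
//   IdArray local = MapToLocalFromRemainder<kDGLCUDA, int64_t>(
//       num_parts, global_idx);
//   IdArray global = MapToGlobalFromRemainder<kDGLCUDA, int64_t>(
//       num_parts, local, part_id);
//
// recovers the original IDs for every element that actually belongs to
// partition `part_id`.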
// Range Based Partition Operations

template <DGLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray>
GeneratePermutationFromRange(
        int64_t array_size,
        int num_parts,
        IdArray range,
        IdArray in_idx) {
  std::pair<IdArray, NDArray> result;

  const auto& ctx = in_idx->ctx;
  auto device = DeviceAPI::Get(ctx);
  hipStream_t stream = runtime::getCurrentCUDAStream();

  const int64_t num_in = in_idx->shape[0];

  CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
      ") must be at least 1.";
  if (num_parts == 1) {
    // no permutation
    result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
    result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx);

    return result;
  }

  result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
  result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
  int64_t * out_counts = static_cast<int64_t*>(result.second->data);
  if (num_in == 0) {
    // now that we've zero'd out_counts, nothing left to do for an empty
    // mapping
    return result;
  }

  const int64_t part_bits =
      static_cast<int64_t>(std::ceil(std::log2(num_parts)));

  // First, generate a mapping of indexes to processors
  Workspace<IdType> proc_id_in(device, ctx, num_in);
  {
    const dim3 block(256);
    const dim3 grid((num_in+block.x-1)/block.x);

    CUDA_KERNEL_CALL(_MapProcByRangeKernel, grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(in_idx->data),
        num_in,
        num_parts,
        proc_id_in.get());
  }

  // then create a permutation array that groups processors together by
  // performing a radix sort
  Workspace<IdType> proc_id_out(device, ctx, num_in);
  IdType * perm_out = static_cast<IdType*>(result.first->data);
  {
    IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);

    size_t sort_workspace_size;
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size,
        proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));

    Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
    CUDA_CALL(hipcub::DeviceRadixSort::SortPairs(sort_workspace.get(),
        sort_workspace_size,
        proc_id_in.get(), proc_id_out.get(),
        static_cast<IdType*>(perm_in->data), perm_out,
        num_in, 0, part_bits, stream));
  }
  // explicitly free so workspace can be re-used
  proc_id_in.free();

  // perform a histogram and then prefixsum on the sorted proc_id vector

  // Count the number of values to be sent to each processor
  {
    using AtomicCount = unsigned long long;  // NOLINT
    static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
        "AtomicCount must be the same width as int64_t for atomicAdd "
        "in hipcub::DeviceHistogram::HistogramEven() to work");

    // TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
    // add a compile time check against the cub version to allow
    // num_in > (2 << 31).
    CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
        "number of values to insert into histogram must be less than max "
        "value of int.";

    size_t hist_workspace_size;
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        nullptr,
        hist_workspace_size,
        proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts),
        num_parts+1,
        static_cast<IdType>(0),
        static_cast<IdType>(num_parts+1),
        static_cast<int>(num_in),
        stream));

    Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
    CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(
        hist_workspace.get(),
        hist_workspace_size,
        proc_id_out.get(),
        reinterpret_cast<AtomicCount*>(out_counts),
        num_parts+1,
        static_cast<IdType>(0),
        static_cast<IdType>(num_parts+1),
        static_cast<int>(num_in),
        stream));
  }

  return result;
}

template std::pair<IdArray, NDArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int32_t>(
        int64_t array_size,
        int num_parts,
        IdArray range,
        IdArray in_idx);
template std::pair<IdArray, NDArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int32_t>(
        int64_t array_size,
        int num_parts,
        IdArray range,
        IdArray in_idx);
template std::pair<IdArray, NDArray>
GeneratePermutationFromRange<kDGLCUDA, int32_t, int64_t>(
        int64_t array_size,
        int num_parts,
        IdArray range,
        IdArray in_idx);
template std::pair<IdArray, NDArray>
GeneratePermutationFromRange<kDGLCUDA, int64_t, int64_t>(
        int64_t array_size,
        int num_parts,
        IdArray range,
        IdArray in_idx);

template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
    const int num_parts,
    IdArray range,
    IdArray global_idx) {
  const auto& ctx = global_idx->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1 && global_idx->shape[0] > 0) {
    IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
        sizeof(IdType)*8);

    const dim3 block(128);
    const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);

    CUDA_KERNEL_CALL(
        _MapLocalIndexByRangeKernel,
        grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(global_idx->data),
        global_idx->shape[0],
        num_parts,
        static_cast<IdType*>(local_idx->data));

    return local_idx;
  } else {
    // no mapping to be done
    return global_idx;
  }
}

template IdArray
MapToLocalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx);
template IdArray
MapToLocalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx);
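// Worked example of the range-based local/global mapping: with
// range = {0, 5, 9, 12}, global ID 7 belongs to partition 1 and maps to local
// ID 7 - range[1] = 2; MapToGlobalFromRange below inverts this for
// part_id = 1 as 2 + range[1] = 7.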
template <DGLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
    const int num_parts,
    IdArray range,
    IdArray local_idx,
    const int part_id) {
  CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
      "/" << num_parts;
  CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
      "/" << num_parts;

  const auto& ctx = local_idx->ctx;
  hipStream_t stream = runtime::getCurrentCUDAStream();

  if (num_parts > 1 && local_idx->shape[0] > 0) {
    IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
        sizeof(IdType)*8);

    const dim3 block(128);
    const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);

    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRangeKernel,
        grid, block, 0, stream,
        static_cast<const RangeType*>(range->data),
        static_cast<const IdType*>(local_idx->data),
        part_id,
        global_idx->shape[0],
        num_parts,
        static_cast<IdType*>(global_idx->data));

    return global_idx;
  } else {
    // no mapping to be done
    return local_idx;
  }
}

template IdArray
MapToGlobalFromRange<kDGLCUDA, int32_t, int32_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx,
    int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int64_t, int32_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx,
    int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int32_t, int64_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx,
    int part_id);
template IdArray
MapToGlobalFromRange<kDGLCUDA, int64_t, int64_t>(
    int num_parts,
    IdArray range,
    IdArray in_idx,
    int part_id);

}  // namespace impl
}  // namespace partition
}  // namespace dgl