/*!
 *  Copyright (c) 2021 by Contributors
 * \file graph/sampling/frequency_hashmap.cu
 * \brief frequency hashmap - used to select top-k frequency edges of each node
 */

#include <cub/cub.cuh>

#include <cmath>
#include <tuple>
#include <utility>

#include "../../../runtime/cuda/cuda_common.h"
#include "../../../array/cuda/atomic.cuh"
#include "frequency_hashmap.cuh"

namespace dgl {

namespace sampling {

namespace impl {

namespace {

int64_t _table_size(const int64_t num, const int64_t scale) {
  /**
   * Calculate the number of buckets in the hashtable. To guarantee we can
   * fill the hashtable in the worst case, we must use a number of buckets
   * which is a power of two.
   * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
   */
  const int64_t next_pow2 = 1 << static_cast<int64_t>(1 + std::log2(num >> 1));
  return next_pow2 << scale;
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _init_edge_table(void *edge_hashmap, int64_t edges_len) {
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
  auto edge_hashmap_t = static_cast<EdgeItem *>(edge_hashmap);
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < edges_len) {
      EdgeItem *edge = (edge_hashmap_t + idx);
      edge->src = static_cast<IdxType>(-1);
      edge->cnt = static_cast<IdxType>(0);
    }
  }
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _count_frequency(
    const IdxType *src_data, const int64_t num_edges,
    const int64_t num_edges_per_node, IdxType *edge_blocks_prefix,
    bool *is_first_position, DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
  IdxType count = 0;
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < num_edges) {
      IdxType src = src_data[idx];
      if (src == static_cast<IdxType>(-1)) {
        continue;
      }
      IdxType dst_idx = (idx / num_edges_per_node);
      if (device_edge_hashmap.InsertEdge(src, dst_idx) == 0) {
        is_first_position[idx] = true;
        ++count;
      }
    }
  }
  using BlockReduce = cub::BlockReduce<IdxType, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage temp_space;
  count = BlockReduce(temp_space).Sum(count);
  if (threadIdx.x == 0) {
    edge_blocks_prefix[blockIdx.x] = count;
    if (blockIdx.x == 0) {
      edge_blocks_prefix[gridDim.x] = 0;
    }
  }
}
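// Worked example of how the per-block counts produced by _count_frequency are
// used: each block writes the number of first-time insertions it performed to
// edge_blocks_prefix[blockIdx.x], and block 0 zeroes the extra trailing slot.
// With four blocks counting [5, 3, 0, 7], the array is [5, 3, 0, 7, 0]; the
// ExclusiveSum in step 2.1 of Topk() turns it into [0, 5, 8, 8, 15], giving
// each block its output offset for _compact_frequency and, in the last slot,
// the total number of unique (src, dst) pairs.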
/**
 * This structure is used with cub's block-level prefix scan in order to
 * keep a running sum as items are iteratively processed.
 */
template <typename T>
struct BlockPrefixCallbackOp {
  T _running_total;

  __device__ BlockPrefixCallbackOp(const T running_total)
      : _running_total(running_total) {}

  __device__ T operator()(const T block_aggregate) {
    const T old_prefix = _running_total;
    _running_total += block_aggregate;
    return old_prefix;
  }
};

template <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _compact_frequency(
    const IdxType *src_data, const IdxType *dst_data,
    const int64_t num_edges, const int64_t num_edges_per_node,
    const IdxType *edge_blocks_prefix, const bool *is_first_position,
    IdxType *num_unique_each_node, IdxType *unique_src_edges,
    Idx64Type *unique_frequency,
    DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
  const IdxType block_offset = edge_blocks_prefix[blockIdx.x];
  using BlockScan = cub::BlockScan<IdxType, BLOCK_SIZE>;
  __shared__ typename BlockScan::TempStorage temp_space;
  BlockPrefixCallbackOp<IdxType> prefix_op(0);
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    IdxType flag = 0;
    if (idx < num_edges) {
      IdxType src = src_data[idx];
      IdxType dst_idx = (idx / num_edges_per_node);
      if (idx % num_edges_per_node == 0) {
        num_unique_each_node[dst_idx] = device_edge_hashmap.GetDstCount(dst_idx);
      }
      if (is_first_position[idx] == true) {
        flag = 1;
      }
      BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);
      __syncthreads();
      if (is_first_position[idx] == true) {
        const IdxType pos = (block_offset + flag);
        unique_src_edges[pos] = src;
        if (sizeof(IdxType) != sizeof(Idx64Type) && sizeof(IdxType) == 4) {
          // IdxType is a 32-bit type: pack the reversed dst index into the
          // high 32 bits and the edge count into the low 32 bits.
          unique_frequency[pos] =
              ((static_cast<Idx64Type>(num_edges / num_edges_per_node - dst_idx) << 32) |
               device_edge_hashmap.GetEdgeCount(src, dst_idx));
        } else {
          unique_frequency[pos] = device_edge_hashmap.GetEdgeCount(src, dst_idx);
        }
      }
    }
  }
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _get_pick_num(IdxType *num_unique_each_node,
    const int64_t num_pick, const int64_t num_dst_nodes) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < num_dst_nodes) {
      IdxType &num_unique = num_unique_each_node[idx];
      num_unique = min(num_unique, static_cast<IdxType>(num_pick));
    }
  }
}

template <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _pick_data(
    const Idx64Type *unique_frequency, const IdxType *unique_src_edges,
    const IdxType *unique_input_offsets, const IdxType *dst_data,
    const int64_t num_edges_per_node, const int64_t num_dst_nodes,
    const int64_t num_edges, const IdxType *unique_output_offsets,
    IdxType *output_src, IdxType *output_dst, IdxType *output_frequency) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
  for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
    if (idx < num_dst_nodes) {
      const int64_t dst_pos = (idx * num_edges_per_node);
      assert(dst_pos < num_edges);
      const IdxType dst = dst_data[dst_pos];
      const IdxType last_output_offset = unique_output_offsets[idx + 1];
      assert((last_output_offset - unique_output_offsets[idx]) <=
             (unique_input_offsets[idx + 1] - unique_input_offsets[idx]));
      for (IdxType output_idx = unique_output_offsets[idx],
                   input_idx = unique_input_offsets[idx];
           output_idx < last_output_offset; ++output_idx, ++input_idx) {
        output_src[output_idx] = unique_src_edges[input_idx];
        output_dst[output_idx] = dst;
        output_frequency[output_idx] = static_cast<IdxType>(unique_frequency[input_idx]);
      }
    }
  }
}

}  // namespace
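// Note on the probing scheme used below: InsertEdge and GetEdgeCount probe the
// per-destination bucket range starting at EdgeHash(src) and then at positions
// EdgeHash(pos + delta) with delta = 1, 2, 3, ..., i.e. at cumulative offsets
// 1, 3, 6, 10, ... from the initial slot. Assuming EdgeHash wraps the index
// into the bucket range (it is defined in frequency_hashmap.cuh, not here),
// this is quadratic probing, which is why _table_size() makes the bucket
// count a power of two.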
// Returns the old cnt of this edge.
template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(
    const IdxType &src, const IdxType &dst_idx) {
  IdxType start_off = dst_idx * _num_items_each_dst;
  IdxType pos = EdgeHash(src);
  IdxType delta = 1;
  IdxType old_cnt = static_cast<IdxType>(-1);
  while (true) {
    IdxType old_src = dgl::aten::cuda::AtomicCAS(
        &_edge_hashmap[start_off + pos].src, static_cast<IdxType>(-1), src);
    if (old_src == static_cast<IdxType>(-1) || old_src == src) {
      // first insert
      old_cnt = dgl::aten::cuda::AtomicAdd(
          &_edge_hashmap[start_off + pos].cnt, static_cast<IdxType>(1));
      if (old_src == static_cast<IdxType>(-1)) {
        assert(dst_idx < _num_dst);
        dgl::aten::cuda::AtomicAdd(&_dst_unique_edges[dst_idx],
                                   static_cast<IdxType>(1));
      }
      break;
    }
    pos = EdgeHash(pos + delta);
    delta += 1;
  }
  return old_cnt;
}

template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetDstCount(
    const IdxType &dst_idx) {
  return _dst_unique_edges[dst_idx];
}

template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetEdgeCount(
    const IdxType &src, const IdxType &dst_idx) {
  IdxType start_off = dst_idx * _num_items_each_dst;
  IdxType pos = EdgeHash(src);
  IdxType delta = 1;
  while (_edge_hashmap[start_off + pos].src != src) {
    pos = EdgeHash(pos + delta);
    delta += 1;
  }
  return _edge_hashmap[start_off + pos].cnt;
}

template <typename IdxType>
FrequencyHashmap<IdxType>::FrequencyHashmap(
    int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
    cudaStream_t stream, int64_t edge_table_scale) {
  _ctx = ctx;
  _stream = stream;
  num_items_each_dst = _table_size(num_items_each_dst, edge_table_scale);
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  auto dst_unique_edges = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, (num_dst) * sizeof(IdxType)));
  auto edge_hashmap = static_cast<EdgeItem *>(
      device->AllocWorkspace(_ctx, (num_dst * num_items_each_dst) * sizeof(EdgeItem)));
  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 8;
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_dst * num_items_each_dst + TILE_SIZE - 1) / TILE_SIZE);
  cudaMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType));
  _init_edge_table<IdxType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, _stream>>>(
      edge_hashmap, (num_dst * num_items_each_dst));
  _device_edge_hashmap = new DeviceEdgeHashmap<IdxType>(
      num_dst, num_items_each_dst, dst_unique_edges, edge_hashmap);
  _dst_unique_edges = dst_unique_edges;
  _edge_hashmap = edge_hashmap;
}

template <typename IdxType>
FrequencyHashmap<IdxType>::~FrequencyHashmap() {
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  delete _device_edge_hashmap;
  _device_edge_hashmap = nullptr;
  device->FreeWorkspace(_ctx, _dst_unique_edges);
  _dst_unique_edges = nullptr;
  device->FreeWorkspace(_ctx, _edge_hashmap);
  _edge_hashmap = nullptr;
}

template <typename IdxType>
std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
    const IdxType *src_data, const IdxType *dst_data, DLDataType dtype,
    const int64_t num_edges, const int64_t num_edges_per_node,
    const int64_t num_pick) {
  using Idx64Type = int64_t;
  const int64_t num_dst_nodes = (num_edges / num_edges_per_node);
  constexpr int BLOCK_SIZE = 256;
  // XXX: an empirical value; gives the best performance on GV100.
  constexpr int TILE_SIZE = BLOCK_SIZE * 32;
  const dim3 block(BLOCK_SIZE);
  const dim3 edges_grid((num_edges + TILE_SIZE - 1) / TILE_SIZE);
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  const IdxType num_edge_blocks = static_cast<IdxType>(edges_grid.x);
  IdxType num_unique_edges = 0;
  // marks whether this position is the first inserting position of an edge
  // in _edge_hashmap
  bool *is_first_position = static_cast<bool *>(
      device->AllocWorkspace(_ctx, sizeof(bool) * (num_edges)));
  CUDA_CALL(cudaMemset(is_first_position, 0, sizeof(bool) * (num_edges)));
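  // Layout note for the workspace buffers allocated below: the cub
  // ExclusiveSum calls write their results into a second array rather than in
  // place, so each prefix buffer is over-allocated and split into consecutive
  // halves (or thirds):
  //   edge_blocks_prefix_data   : [ edge_blocks_prefix | ..._alternate ]
  //   num_unique_each_node_data : [ num_unique_each_node | ..._alternate | unique_output_offsets ]
  // The same trick provides the alternate arrays required by the
  // cub::DoubleBuffer pairs used for the radix sort in step 3.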
  // double space to use ExclusiveSum
  auto edge_blocks_prefix_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_edge_blocks + 1)));
  IdxType *edge_blocks_prefix = edge_blocks_prefix_data;
  IdxType *edge_blocks_prefix_alternate =
      (edge_blocks_prefix_data + (num_edge_blocks + 1));
  // triple space to use ExclusiveSum and unique_output_offsets
  auto num_unique_each_node_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 3 * sizeof(IdxType) * (num_dst_nodes + 1)));
  IdxType *num_unique_each_node = num_unique_each_node_data;
  IdxType *num_unique_each_node_alternate =
      (num_unique_each_node_data + (num_dst_nodes + 1));
  IdxType *unique_output_offsets =
      (num_unique_each_node_data + 2 * (num_dst_nodes + 1));

  // 1. Scan all the edges and count the unique edges, both in total and for
  //    each dst node.
  _count_frequency<IdxType, BLOCK_SIZE, TILE_SIZE><<<edges_grid, block, 0, _stream>>>(
      src_data, num_edges, num_edges_per_node, edge_blocks_prefix,
      is_first_position, *_device_edge_hashmap);

  // 2. Compact the unique edges and their frequency.
  // 2.1 ExclusiveSum the edge_blocks_prefix.
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
      edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
      edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1));
  device->FreeWorkspace(_ctx, d_temp_storage);
  std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);
  device->CopyDataFromTo(&edge_blocks_prefix[num_edge_blocks], 0,
      &num_unique_edges, 0, sizeof(num_unique_edges),
      _ctx, DGLContext{kDLCPU, 0}, dtype, _stream);
  device->StreamSync(_ctx, _stream);

  // 2.2 Allocate the data of unique edges and frequency.
  // double space to use SegmentedRadixSort
  auto unique_src_edges_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_unique_edges)));
  IdxType *unique_src_edges = unique_src_edges_data;
  IdxType *unique_src_edges_alternate = unique_src_edges_data + num_unique_edges;
  // double space to use SegmentedRadixSort
  auto unique_frequency_data = static_cast<Idx64Type *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(Idx64Type) * (num_unique_edges)));
  Idx64Type *unique_frequency = unique_frequency_data;
  Idx64Type *unique_frequency_alternate = unique_frequency_data + num_unique_edges;

  // 2.3 Compact the unique edges and their frequency.
  _compact_frequency<IdxType, Idx64Type, BLOCK_SIZE, TILE_SIZE>
      <<<edges_grid, block, 0, _stream>>>(
      src_data, dst_data, num_edges, num_edges_per_node, edge_blocks_prefix,
      is_first_position, num_unique_each_node, unique_src_edges,
      unique_frequency, *_device_edge_hashmap);
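  // How the sort in step 3 keeps edges grouped per destination node: when
  // IdxType is 32-bit, _compact_frequency packs (num_dst_nodes - dst_idx) into
  // the high 32 bits of the 64-bit key and the edge count into the low 32
  // bits. For example, with num_dst_nodes = 3, an edge of dst 0 seen 5 times
  // gets key (3 << 32) | 5 and an edge of dst 2 seen 9 times gets key
  // (1 << 32) | 9, so a single descending cub::DeviceRadixSort places all
  // edges of dst 0 first, each segment internally sorted by frequency. When
  // IdxType is 64-bit, the keys only hold the counts, so the segmented sort
  // with per-node offsets is used instead.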
  // 3. SegmentedRadixSort the unique edges by their frequency.
  // 3.1 ExclusiveSum the num_unique_each_node.
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
      num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
      num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1));
  device->FreeWorkspace(_ctx, d_temp_storage);

  // 3.2 SegmentedRadixSort the unique_src_edges by unique_frequency.
  // Create a set of DoubleBuffers to wrap pairs of device pointers.
  cub::DoubleBuffer<Idx64Type> d_unique_frequency(
      unique_frequency, unique_frequency_alternate);
  cub::DoubleBuffer<IdxType> d_unique_src_edges(
      unique_src_edges, unique_src_edges_alternate);
  // Determine temporary device storage requirements.
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  // DeviceRadixSort is faster than DeviceSegmentedRadixSort,
  // especially when num_dst_nodes is large (around 10000).
  if (dtype.bits == 32) {
    CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
        temp_storage_bytes, d_unique_frequency, d_unique_src_edges,
        num_unique_edges));
  } else {
    CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
        temp_storage_bytes, d_unique_frequency, d_unique_src_edges,
        num_unique_edges, num_dst_nodes,
        num_unique_each_node_alternate, num_unique_each_node_alternate + 1));
  }
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  if (dtype.bits == 32) {
    CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
        temp_storage_bytes, d_unique_frequency, d_unique_src_edges,
        num_unique_edges));
  } else {
    CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
        temp_storage_bytes, d_unique_frequency, d_unique_src_edges,
        num_unique_edges, num_dst_nodes,
        num_unique_each_node_alternate, num_unique_each_node_alternate + 1));
  }
  device->FreeWorkspace(_ctx, d_temp_storage);

  // 4. Get the final pick number for each dst node.
  // 4.1 Clamp num_unique_each_node to min(num_pick, num_unique_each_node).
  constexpr int NODE_TILE_SIZE = BLOCK_SIZE * 2;
  const dim3 nodes_grid((num_dst_nodes + NODE_TILE_SIZE - 1) / NODE_TILE_SIZE);
  _get_pick_num<IdxType, BLOCK_SIZE, NODE_TILE_SIZE><<<nodes_grid, block, 0, _stream>>>(
      num_unique_each_node, num_pick, num_dst_nodes);
  // 4.2 ExclusiveSum the new num_unique_each_node as unique_output_offsets.
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
      num_unique_each_node, unique_output_offsets, num_dst_nodes + 1));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
      num_unique_each_node, unique_output_offsets, num_dst_nodes + 1));
  device->FreeWorkspace(_ctx, d_temp_storage);
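  // Worked example for steps 4.1-4.2: with num_pick = 2 and per-node unique
  // counts [3, 1, 4], _get_pick_num clamps them to [2, 1, 2] and the
  // ExclusiveSum yields unique_output_offsets = [0, 2, 3, 5]; the final entry
  // (5) is the total number of output edges, copied back as num_output below.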
  // 5. Pick the data into the result arrays.
  IdxType num_output = 0;
  device->CopyDataFromTo(&unique_output_offsets[num_dst_nodes], 0,
      &num_output, 0, sizeof(num_output),
      _ctx, DGLContext{kDLCPU, 0}, dtype, _stream);
  device->StreamSync(_ctx, _stream);
  IdArray res_src = IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  IdArray res_dst = IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  IdArray res_cnt = IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  _pick_data<IdxType, Idx64Type, BLOCK_SIZE, NODE_TILE_SIZE>
      <<<nodes_grid, block, 0, _stream>>>(
      d_unique_frequency.Current(), d_unique_src_edges.Current(),
      num_unique_each_node_alternate, dst_data,
      num_edges_per_node, num_dst_nodes, num_edges,
      unique_output_offsets,
      res_src.Ptr<IdxType>(), res_dst.Ptr<IdxType>(), res_cnt.Ptr<IdxType>());

  device->FreeWorkspace(_ctx, is_first_position);
  device->FreeWorkspace(_ctx, edge_blocks_prefix_data);
  device->FreeWorkspace(_ctx, num_unique_each_node_data);
  device->FreeWorkspace(_ctx, unique_src_edges_data);
  device->FreeWorkspace(_ctx, unique_frequency_data);

  return std::make_tuple(res_src, res_dst, res_cnt);
}

template class FrequencyHashmap<int32_t>;
template class FrequencyHashmap<int64_t>;

}  // namespace impl
}  // namespace sampling
}  // namespace dgl
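// A minimal usage sketch (illustrative only; the surrounding variables such as
// ctx, stream, src_data, dst_data, dtype and the edge_table_scale value are
// hypothetical, not defined in this file). Given the flattened neighbors of
// num_dst destination nodes, each with num_edges_per_node sampled edges, the
// num_pick most frequent neighbors of every destination could be selected as:
//
//   FrequencyHashmap<int32_t> hashmap(
//       num_dst, num_edges_per_node, ctx, stream, /*edge_table_scale=*/3);
//   IdArray src, dst, cnt;
//   std::tie(src, dst, cnt) = hashmap.Topk(
//       src_data, dst_data, dtype,
//       num_dst * num_edges_per_node /*num_edges*/, num_edges_per_node, num_pick);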