// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "hip/hip_bf16.h"
/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/neighbor_sampler.cu
 * @brief Neighbor sampling operator implementation on CUDA/HIP.
 */
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/ScalarType.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/fused_sampled_subgraph.h>
#include <hiprand/hiprand_kernel.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <hipcub/hipcub.hpp>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "common.h"
#include "utils.h"

namespace rocprim {
namespace detail {

// rocPRIM needs to know the bit layout of __hip_bfloat16 so that it can
// radix-sort keys of that type.
template <>
struct float_bit_mask<__hip_bfloat16> {
  static constexpr uint16_t sign_bit = 0x8000;
  static constexpr uint16_t exponent = 0x7F80;
  static constexpr uint16_t mantissa = 0x007F;
  using bit_type = uint16_t;
};

template <>
struct radix_key_codec_base<__hip_bfloat16>
    : radix_key_codec_floating<__hip_bfloat16, unsigned short> {};

}  // namespace detail
}  // namespace rocprim

#if HIP_VERSION_MAJOR < 6
__host__ __device__ bool operator>(
    const __hip_bfloat16& a, const __hip_bfloat16& b) {
  return float(a) > float(b);
}
#endif

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * original edge ids. When random_arr is sorted along with edge_ids, the first
 * fanout elements of each row give us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t,
    typename weights_t, typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  hiprandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;

  if (!labor) {
    hiprand_init(random_seed, i, 0, &rng);
  }

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;

    if (labor) {
      // LABOR sampling: seed the generator with the global neighbor id, so
      // the same neighbor draws the same random number regardless of which
      // seed vertex is sampling it.
      constexpr uint64_t kCurandSeed = 999961;
      hiprand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
    }

    const auto rnd = hiprand_uniform(&rng);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
    const auto exp_rnd = -__logf(rnd);
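    // Weighted sampling via exponential races: hiprand_uniform returns a
    // uniform u in (0, 1], so exp_rnd = -log(u) is an Exponential(1) draw and
    // exp_rnd / prob is Exponential(prob). Keeping the fanout smallest keys
    // of a row (via the segmented sort later on) therefore draws neighbors
    // without replacement, choosing at each step with probability
    // proportional to prob. Edges with non-positive prob receive an infinite
    // key, so they sort last and are never picked.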
    const float_t adjusted_rnd =
        prob > 0 ? static_cast<float_t>(exp_rnd / prob)
                 : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;

    i += stride;
  }
}

struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  const int64_t* fanouts;
  size_t num_fanouts;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices,
    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
    bool replace, bool layer, bool return_eids,
    torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask) {
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask are
  // all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
  auto num_rows =
      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
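  // The host staging buffer above is pinned, so the hipMemcpyAsync below can
  // run asynchronously on the current stream instead of falling back to a
  // synchronous copy from pageable memory. Illustrative example: fanouts
  // {-1, 10} are staged as {INT64_MAX, 10}, i.e. a negative fanout means
  // "keep every neighbor" of the corresponding edge type.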
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(hipMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
      sizeof(int64_t) * fanouts.size(), hipMemcpyHostToDevice,
      cuda::GetCurrentStream()));
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
  // Protect access to max_in_degree with a CUDAEvent.
  at::cuda::CUDAEvent max_in_degree_event;
  max_in_degree_event.record();
  torch::optional<int64_t> num_edges;
  torch::Tensor sub_indptr;
  if (!nodes.has_value()) {
    num_edges = indices.size(0);
    sub_indptr = indptr;
  }
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
    if (nodes.has_value()) {
      torch::Tensor sliced_probs_or_mask_tensor;
      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
      num_edges = sliced_probs_or_mask_tensor.size(0);
    } else {
      sliced_probs_or_mask = probs_or_mask;
    }
  }
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
    if (nodes.has_value()) {
      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
    } else {
      sliced_type_per_edge = type_per_edge.value();
    }
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
    num_edges = sliced_type_per_edge.size(0);
  }
  // If sub_indptr was not computed in the two code blocks above:
  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
    sub_indptr = ExclusiveCumSum(in_degree);
  }
  auto coo_rows = ExpandIndptrImpl(
      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
  num_edges = coo_rows.size(0);
  const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
  auto output_indptr = torch::empty_like(sub_indptr);
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  torch::optional<torch::Tensor> output_type_per_edge;
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = index_t;
        if (probs_or_mask.has_value()) {
          // Count nonzero probs into in_degree.
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
              }));
        }
        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});
        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);
        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
        // Find the smallest integer type to store the edge id offsets. We
        // synchronize the CUDAEvent so that the access to max_in_degree is
        // safe.
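        // Illustrative example: a maximum in-degree of 3000 needs
        // num_bits == 12, std::lower_bound over {8, 16, 32, 64} picks 16, and
        // the edge id offsets are stored as torch::kInt16 (used as uint16_t
        // below). Narrower edge id buffers reduce the memory traffic of the
        // segmented sorts.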
        max_in_degree_event.synchronize();
        const int num_bits =
            cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for the random numbers works just as reliably
              // as float32 and provides around a 30% speedup.
              using rnd_t = __hip_bfloat16;
              auto randoms =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto randoms_sorted =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = index_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* sliced_probs_ptr = nullptr;
                          if (sliced_probs_or_mask.has_value()) {
                            sliced_probs_ptr = sliced_probs_or_mask.value()
                                                   .data_ptr<probs_t>();
                          }
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges.value() + BLOCK_SIZE - 1) /
                              BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0,
                              num_edges.value(),
                              // sliced_indptr.data_ptr<indptr_t>(),
                              cuda::getTensorDevicePointer<indptr_t>(
                                  sliced_indptr),
                              // sub_indptr.data_ptr<indptr_t>(),
                              cuda::getTensorDevicePointer<indptr_t>(
                                  sub_indptr),
                              coo_rows.data_ptr<indices_t>(),
                              sliced_probs_ptr, indices_ptr, random_seed,
                              randoms.get(), edge_id_segments.get());
                        }));
                  }));
              // Sort the random numbers along with the edge ids; after
              // sorting, the first fanout elements of each row give us the
              // sampled edges.
              CUB_CALL(
                  DeviceSegmentedSort::SortPairs, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1);
              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  sub_indptr.options());
              // The sampled edge ids need to be sorted only when
              // fanouts.size() == 1, since the multiple-fanout case is
              // already produced in sorted order.
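              // With multiple fanouts, each (seed vertex, edge type) pair is
              // its own segment and the segments are laid out in edge type
              // order, so the picked edge ids already come out grouped by
              // type. With a single fanout, the key-only sort below restores
              // ascending edge id order within each seed, which (assuming a
              // vertex's edges are stored grouped by edge type, as the hetero
              // slicing above requires) keeps the gathered type_per_edge
              // values sorted per seed vertex.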
              if (type_per_edge && fanouts.size() == 1) {
                // Ensure that the sort result still ends up in
                // sorted_edge_id_segments.
                std::swap(edge_id_segments, sorted_edge_id_segments);
                auto sampled_segment_end_it = thrust::make_transform_iterator(
                    iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                              sub_indptr.data_ptr<indptr_t>(),
                              sampled_degree});
                // Sort only the first sampled_degree[i] edge ids of each
                // segment, i.e. the picked ones.
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                    sorted_edge_id_segments.get(), picked_eids.size(0),
                    num_rows, sub_indptr.data_ptr<indptr_t>(),
                    sampled_segment_end_it);
              }
              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota,
                  IteratorFuncAddOffset<indptr_t, indptr_t>{
                      cuda::getTensorDevicePointer<indptr_t>(output_indptr),
                      cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
                      cuda::getTensorDevicePointer<indptr_t>(picked_eids)});
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();
              // Copy the sampled edge ids into the picked_eids tensor,
              // translating per-row offsets into global edge ids by adding
              // sliced_indptr.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    ::min(num_rows - i, max_copy_at_once));
              }
            }));
        output_indices = torch::empty(
            picked_eids.size(0),
            picked_eids.options().dtype(indices.scalar_type()));
        // Compute: output_indices = indices.gather(0, picked_eids);
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
              using indices_t = index_t;
              THRUST_CALL(
                  gather, picked_eids.data_ptr<indptr_t>(),
                  picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                  cuda::getTensorDevicePointer<indices_t>(indices),
                  output_indices.data_ptr<indices_t>());
            }));
        if (type_per_edge) {
          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
          // The commented-out torch equivalent above does not work when
          // type_per_edge is on pinned memory. That is why we have to
          // reimplement it, similar to the indices gather operation above.
          auto types = type_per_edge.value();
          output_type_per_edge = torch::empty(
              picked_eids.size(0),
              picked_eids.options().dtype(types.scalar_type()));
          AT_DISPATCH_INTEGRAL_TYPES(
              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
                THRUST_CALL(
                    gather, picked_eids.data_ptr<indptr_t>(),
                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                    types.data_ptr<scalar_t>(),
                    output_type_per_edge.value().data_ptr<scalar_t>());
              }));
        }
      }));
  // Convert output_indptr back to the homogeneous form by discarding the
  // intermediate per-edge-type offsets.
  output_indptr =
      output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
  if (!nodes.has_value()) {
    nodes = torch::arange(indptr.size(0) - 1, indices.options());
  }
  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes.value(), torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge);
}

}  // namespace ops
}  // namespace graphbolt