// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "hip/hip_bf16.h"
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/index_select_impl.cu
 * @brief Index select operator implementation on CUDA.
 */
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "../random.h"
#include "../utils.h"
#include "common.h"
#include "utils.h"

namespace rocprim {
namespace detail {

template <>
struct float_bit_mask<__hip_bfloat16> {
  static constexpr uint16_t sign_bit = 0x8000;
  static constexpr uint16_t exponent = 0x7F80;
  static constexpr uint16_t mantissa = 0x007F;
  using bit_type = uint16_t;
};

template <>
struct radix_key_codec_base<__hip_bfloat16>
    : radix_key_codec_floating<__hip_bfloat16, unsigned short> {};

}  // namespace detail
}  // namespace rocprim

#if HIP_VERSION_MAJOR < 6
__host__ __device__ bool operator>(
    const __hip_bfloat16& a, const __hip_bfloat16& b) {
  return float(a) > float(b);
}
#endif

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

inline __device__ int64_t AtomicMax(
    int64_t* const address, const int64_t val) {
  // To match the type of "::atomicCAS", ignore lint warning.
  using Type = unsigned long long int;  // NOLINT
  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(
    int32_t* const address, const int32_t val) {
  // To match the type of "::atomicCAS", ignore lint warning.
  using Type = int;  // NOLINT
  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

/**
 * @brief Performs neighbor sampling and fills the edge_ids array with
 * original edge ids if sliced_indptr is valid. If not, then it fills the edge
 * ids array with numbers up to the node degree.
 */
template <typename indptr_t, typename indices_t>
__global__ void _ComputeRandomsNS(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indptr_t* const output_indptr,
    const indices_t* const csr_rows, const uint64_t random_seed,
    indptr_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  hiprandStatePhilox4_32_10_t rng;
  hiprand_init(random_seed, i, 0, &rng);

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto output_offset = output_indptr[row_position];
    const auto fanout = output_indptr[row_position + 1] - output_offset;
    const auto rnd =
        row_offset < fanout ? row_offset : hiprand(&rng) % (row_offset + 1);
    if (rnd < fanout) {
      const indptr_t edge_id =
          row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);
#if __CUDA_ARCH__ >= 700
      ::cuda::atomic_ref<indptr_t, ::cuda::thread_scope_device> a(
          edge_ids[output_offset + rnd]);
      a.fetch_max(edge_id, ::cuda::std::memory_order_relaxed);
#else
      AtomicMax(edge_ids + output_offset + rnd, edge_id);
#endif  // __CUDA_ARCH__
    }
    i += stride;
  }
}
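// The per-edge decision in _ComputeRandomsNS follows the classic
// reservoir-sampling update rule: the first fanout edges of a row are kept
// initially, and a later edge with offset r replaces a kept slot with
// probability fanout / (r + 1). Because all edges of a row are processed
// concurrently, conflicting writers to a slot are merged with an atomic max
// over the edge id instead of a sequential last-writer-wins overwrite.
// A minimal sequential sketch of the same idea, as a hypothetical host-side
// helper (not part of this file; assumes <algorithm>, <numeric>, <random>
// and <vector>):
//
//   std::vector<int64_t> ReservoirSample(
//       int64_t degree, int64_t fanout, std::mt19937_64& rng) {
//     std::vector<int64_t> kept(std::min(degree, fanout));
//     std::iota(kept.begin(), kept.end(), int64_t{0});
//     for (int64_t i = fanout; i < degree; i++) {
//       std::uniform_int_distribution<int64_t> dist(0, i);
//       const auto j = dist(rng);
//       if (j < fanout) kept[j] = i;  // replace a previously kept edge
//     }
//     return kept;
//   }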
/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * original edge ids. When random_arr is sorted along with edge_ids, the first
 * fanout elements of each row give us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t,
    typename weights_t, typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const continuous_seed random_seed, float_t* random_arr,
    // const unsigned long long random_seed, float_t* random_arr,
    edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  hiprandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;
    const auto rnd = random_seed.uniform(labor ? indices[in_idx] : i);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd =
        prob > 0 ? static_cast<float_t>(exp_rnd / prob)
                 : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;
    i += stride;
  }
}

struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  const int64_t* fanouts;
  size_t num_fanouts;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc_hip {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  indptr_t* segment_end;  // Device pointer storing the end position of each segment.
  __host__ __device__ void operator()(int64_t i) {
    segment_end[i] = indptr[i] + in_degree[i];  // Write directly through the device pointer.
  }
};

c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices,
    torch::optional<torch::Tensor> seeds,
    torch::optional<std::vector<int64_t>> seed_offsets,
    const std::vector<int64_t>& fanouts, bool replace, bool layer,
    bool return_eids, torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask,
    torch::optional<torch::Tensor> node_type_offset,
    torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
    torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id,
    torch::optional<torch::Tensor> random_seed_tensor,
    float seed2_contribution) {
  // When seed_offsets.has_value() in the hetero case, we compute the output of
  // sample_neighbors _convert_to_sampled_subgraph in a fused manner so that
  // _convert_to_sampled_subgraph only has to perform slices over the returned
  // indptr and indices tensors to form CSC outputs for each edge type.
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, seeds, type_per_edge and probs_or_mask
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
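  // num_rows is the number of seed vertices we sample for: when seeds is
  // provided we sample for exactly those vertices, otherwise every vertex of
  // the graph (indptr.size(0) - 1 of them) acts as a seed. A negative fanout
  // means "take all neighbors"; below this is implemented by clamping the
  // fanout to the largest representable int64_t before copying it to the
  // device.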
  auto num_rows =
      seeds.has_value() ? seeds.value().size(0) : indptr.size(0) - 1;
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(hipMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
      sizeof(int64_t) * fanouts.size(), hipMemcpyHostToDevice,
      cuda::GetCurrentStream()));
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, seeds);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
  // Protect access to max_in_degree with a CUDAEvent.
  at::cuda::CUDAEvent max_in_degree_event;
  max_in_degree_event.record();
  torch::optional<int64_t> num_edges;
  torch::Tensor sub_indptr;
  if (!seeds.has_value()) {
    num_edges = indices.size(0);
    sub_indptr = indptr;
  }
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
    if (seeds.has_value()) {
      torch::Tensor sliced_probs_or_mask_tensor;
      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, probs_or_mask.value(), seeds.value(),
          indptr.size(0) - 2, num_edges);
      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
      num_edges = sliced_probs_or_mask_tensor.size(0);
    } else {
      sliced_probs_or_mask = probs_or_mask;
    }
  }
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
    if (seeds.has_value()) {
      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, type_per_edge.value(), seeds.value(),
          indptr.size(0) - 2, num_edges);
    } else {
      sliced_type_per_edge = type_per_edge.value();
    }
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
    num_edges = sliced_type_per_edge.size(0);
  }
  // If sub_indptr was not computed in the two code blocks above:
  if (seeds.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
    sub_indptr = ExclusiveCumSum(in_degree);
  }
  const continuous_seed random_seed = [&] {
    if (random_seed_tensor.has_value()) {
      return continuous_seed(random_seed_tensor.value(), seed2_contribution);
    } else {
      return continuous_seed{RandomEngine::ThreadLocal()->RandInt(
          static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())};
    }
  }();
  auto output_indptr = torch::empty_like(sub_indptr);
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = index_t;
        if (probs_or_mask.has_value()) {  // Count nonzero probs into in_degree.
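          // in_degree is recomputed here as the number of edges with a
          // positive probability/mask value per seed, so the fanout is later
          // compared against the count of actually samplable edges. The
          // pattern below is a segmented sum over a boolean transform
          // iterator; conceptually (illustration only, hypothetical host-side
          // scalar loop, not part of this file):
          //
          //   for (int64_t r = 0; r < num_rows; r++) {
          //     indptr_t count = 0;
          //     for (auto j = sub_indptr_h[r]; j < sub_indptr_h[r + 1]; j++)
          //       count += sliced_probs_h[j] > 0;
          //     in_degree_h[r] = count;
          //   }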
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
              }));
        }
        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});
        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);
        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
        // This operation is placed after the num_sampled_edges copy is started
        // to hide the latency of the copy synchronization later.
        auto coo_rows = ExpandIndptrImpl(
            sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
        num_edges = coo_rows.size(0);
        // Find the smallest integer type to store the edge id offsets. We
        // synchronize on the CUDAEvent so that the access is safe.
        auto compute_num_bits = [&] {
          max_in_degree_event.synchronize();
          return cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        };
        if (layer || probs_or_mask.has_value()) {
          const int num_bits = compute_num_bits();
          std::array<int, 4> type_bits = {8, 16, 32, 64};
          const auto type_index =
              std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
              type_bits.begin();
          std::array<torch::ScalarType, 5> types = {
              torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
              torch::kLong};
          auto edge_id_dtype = types[type_index];
          AT_DISPATCH_INTEGRAL_TYPES(
              edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
                using edge_id_t = std::make_unsigned_t<scalar_t>;
                TORCH_CHECK(
                    num_bits <= sizeof(edge_id_t) * 8,
                    "Selected edge_id_t must be capable of storing edge_ids.");
                // Using bfloat16 for random numbers works just as reliably as
                // float32 and provides around a 30% speedup.
                using rnd_t = hip_bfloat16;
                auto randoms =
                    allocator.AllocateStorage<rnd_t>(num_edges.value());
                auto randoms_sorted =
                    allocator.AllocateStorage<rnd_t>(num_edges.value());
                auto edge_id_segments =
                    allocator.AllocateStorage<edge_id_t>(num_edges.value());
                auto sorted_edge_id_segments =
                    allocator.AllocateStorage<edge_id_t>(num_edges.value());
                AT_DISPATCH_INDEX_TYPES(
                    indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                      using indices_t = index_t;
                      auto probs_or_mask_scalar_type = torch::kFloat32;
                      if (probs_or_mask.has_value()) {
                        probs_or_mask_scalar_type =
                            probs_or_mask.value().scalar_type();
                      }
                      GRAPHBOLT_DISPATCH_ALL_TYPES(
                          probs_or_mask_scalar_type, "SampleNeighborsProbs",
                          ([&] {
                            using probs_t = scalar_t;
                            probs_t* sliced_probs_ptr = nullptr;
                            if (sliced_probs_or_mask.has_value()) {
                              sliced_probs_ptr = sliced_probs_or_mask.value()
                                                     .data_ptr<probs_t>();
                            }
                            const indices_t* indices_ptr =
                                layer ? indices.data_ptr<indices_t>()
                                      : nullptr;
                            const dim3 block(BLOCK_SIZE);
                            const dim3 grid(
                                (num_edges.value() + BLOCK_SIZE - 1) /
                                BLOCK_SIZE);
                            // Compute row and random number pairs.
                            CUDA_KERNEL_CALL(
                                _ComputeRandoms, grid, block, 0,
                                num_edges.value(),
                                sliced_indptr.data_ptr<indptr_t>(),
                                sub_indptr.data_ptr<indptr_t>(),
                                coo_rows.data_ptr<indices_t>(),
                                sliced_probs_ptr, indices_ptr, random_seed,
                                randoms.get(), edge_id_segments.get());
                          }));
                    }));
                // Sort the random numbers along with the edge ids; after
                // sorting, the first fanout elements of each row give us the
                // sampled edges.
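                // DeviceSegmentedSort::SortPairs sorts the keys (the random
                // numbers) within every segment [sub_indptr[r],
                // sub_indptr[r + 1]) and permutes the edge id offsets along
                // with them. Each key is -log(u) / prob, i.e. an exponential
                // random variable with rate equal to the edge weight, so the
                // fanout smallest keys of a segment form a weighted sample
                // without replacement. When layer is true, _ComputeRandoms
                // keys the randomness by the neighbor's global node id, so
                // overlapping neighborhoods reuse the same random numbers
                // (the labor sampling idea).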
                CUB_CALL(
                    DeviceSegmentedSort::SortPairs, randoms.get(),
                    randoms_sorted.get(), edge_id_segments.get(),
                    sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
                picked_eids = torch::empty(
                    static_cast<indptr_t>(num_sampled_edges),
                    sub_indptr.options());
                // Need to sort the sampled edges only when fanouts.size() == 1
                // since the multiple-fanout sampling case is automatically
                // sorted.
                if (type_per_edge && fanouts.size() == 1) {
                  // Ensure the sort result still ends up in
                  // sorted_edge_id_segments.
                  std::swap(edge_id_segments, sorted_edge_id_segments);
                  // ************* //
                  // Allocate memory for segment_end.
                  thrust::device_vector<indptr_t> segment_end(num_rows);
                  auto segment_end_ptr = segment_end.data().get();
                  // Compute the end position of each segment and write it
                  // directly into the device buffer.
                  thrust::for_each(
                      thrust::make_counting_iterator<int64_t>(0),
                      thrust::make_counting_iterator<int64_t>(num_rows),
                      SegmentEndFunc_hip<indptr_t, decltype(sampled_degree)>{
                          sub_indptr.data_ptr<indptr_t>(), sampled_degree,
                          segment_end_ptr});
                  // ***************** //
                  auto sampled_segment_end_it = thrust::make_transform_iterator(
                      iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                                sub_indptr.data_ptr<indptr_t>(),
                                sampled_degree});
                  CUB_CALL(
                      DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                      sorted_edge_id_segments.get(), picked_eids.size(0),
                      num_rows, sub_indptr.data_ptr<indptr_t>(),
                      segment_end_ptr);
                  // sub_indptr.data_ptr<indptr_t>() + 1);
                }
                auto input_buffer_it = thrust::make_transform_iterator(
                    iota, IteratorFunc<indptr_t, edge_id_t>{
                              sub_indptr.data_ptr<indptr_t>(),
                              sorted_edge_id_segments.get()});
                auto output_buffer_it = thrust::make_transform_iterator(
                    iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                              output_indptr.data_ptr<indptr_t>(),
                              sliced_indptr.data_ptr<indptr_t>(),
                              picked_eids.data_ptr<indptr_t>()});
                constexpr int64_t max_copy_at_once =
                    std::numeric_limits<int32_t>::max();
                // Copy the sampled edge ids into the picked_eids tensor.
                for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                  CUB_CALL(
                      DeviceCopy::Batched, input_buffer_it + i,
                      output_buffer_it + i, sampled_degree + i,
                      std::min(num_rows - i, max_copy_at_once));
                }
              }));
        } else {  // Non-weighted neighbor sampling.
          picked_eids = torch::zeros(num_edges.value(), sub_indptr.options());
          const auto sort_needed = type_per_edge && fanouts.size() == 1;
          const auto sliced_indptr_ptr =
              sort_needed ? nullptr : sliced_indptr.data_ptr<indptr_t>();
          const dim3 block(BLOCK_SIZE);
          const dim3 grid(
              (std::min(num_edges.value(), static_cast<int64_t>(1 << 20)) +
               BLOCK_SIZE - 1) /
              BLOCK_SIZE);
          AT_DISPATCH_INDEX_TYPES(
              indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                using indices_t = index_t;
                // Compute row and random number pairs.
                CUDA_KERNEL_CALL(
                    _ComputeRandomsNS, grid, block, 0, num_edges.value(),
                    sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
                    output_indptr.data_ptr<indptr_t>(),
                    coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),
                    picked_eids.data_ptr<indptr_t>());
              }));
          picked_eids = picked_eids.slice(
              0, 0, static_cast<indptr_t>(num_sampled_edges));
          // Need to sort the sampled edges only when fanouts.size() == 1 since
          // the multiple-fanout sampling case is automatically sorted.
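          // When sorting is needed, sliced_indptr was intentionally not added
          // inside _ComputeRandomsNS, so picked_eids currently holds per-row
          // edge id offsets that are 0-based within each seed's neighborhood.
          // The branch below sorts those offsets per segment and then adds the
          // sliced_indptr base back (via ExpandIndptrImpl) to recover global
          // edge ids, keeping edge ids increasing within each seed's
          // neighborhood, which the hetero regrouping further down appears to
          // rely on.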
          if (sort_needed) {
            const int num_bits = compute_num_bits();
            std::array<int, 4> type_bits = {8, 15, 31, 63};
            const auto type_index =
                std::lower_bound(
                    type_bits.begin(), type_bits.end(), num_bits) -
                type_bits.begin();
            std::array<torch::ScalarType, 5> types = {
                torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
                torch::kLong};
            auto edge_id_dtype = types[type_index];
            AT_DISPATCH_INTEGRAL_TYPES(
                edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
                  using edge_id_t = scalar_t;
                  TORCH_CHECK(
                      num_bits <= sizeof(edge_id_t) * 8,
                      "Selected edge_id_t must be capable of storing "
                      "edge_ids.");
                  auto picked_offsets = picked_eids.to(edge_id_dtype);
                  auto sorted_offsets = torch::empty_like(picked_offsets);
                  CUB_CALL(
                      DeviceSegmentedSort::SortKeys,
                      picked_offsets.data_ptr<edge_id_t>(),
                      sorted_offsets.data_ptr<edge_id_t>(),
                      picked_eids.size(0), num_rows,
                      output_indptr.data_ptr<indptr_t>(),
                      output_indptr.data_ptr<indptr_t>() + 1);
                  auto edge_id_offsets = ExpandIndptrImpl(
                      output_indptr, picked_eids.scalar_type(), sliced_indptr,
                      picked_eids.size(0));
                  picked_eids = sorted_offsets.to(picked_eids.scalar_type()) +
                                edge_id_offsets;
                }));
          }
        }
        output_indices = Gather(indices, picked_eids);
      }));
  torch::optional<torch::Tensor> output_type_per_edge;
  torch::optional<torch::Tensor> edge_offsets;
  if (type_per_edge && seed_offsets) {
    const int64_t num_etypes =
        edge_type_to_id.has_value() ? edge_type_to_id->size() : 1;
    // If we performed homogeneous sampling on a hetero graph, we have to look
    // at the type_per_edge of the sampled edges, determine the offsets of the
    // different sampled etypes and convert to the fused hetero indptr
    // representation.
    if (fanouts.size() == 1) {
      output_type_per_edge = Gather(*type_per_edge, picked_eids);
      torch::Tensor output_in_degree, sliced_output_indptr;
      sliced_output_indptr =
          output_indptr.slice(0, 0, output_indptr.size(0) - 1);
      std::tie(output_indptr, output_in_degree, sliced_output_indptr) =
          SliceCSCIndptrHetero(
              output_indptr, output_type_per_edge.value(),
              sliced_output_indptr, num_etypes);
      // We use num_rows to hold num_seeds * num_etypes. So, it needs to be
      // updated when sampling with a single fanout value when the graph is
      // heterogeneous.
      num_rows = sliced_output_indptr.size(0);
    }
    // Here, we check what the dst node types are for the given seeds so that
    // we can compute the output indptr space later.
    std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
    // Here, we check what the src node types are for the given seeds so that
    // we can subtract the source node offset from indices later.
    auto etype_id_to_src_ntype_id = torch::empty(
        2 * num_etypes,
        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
    auto etype_id_to_src_ntype_id_ptr =
        etype_id_to_src_ntype_id.data_ptr<int64_t>();
    for (auto& etype_and_id : edge_type_to_id.value()) {
      auto etype = etype_and_id.key();
      auto id = etype_and_id.value();
      auto [src_type, dst_type] = utils::parse_src_dst_ntype_from_etype(etype);
      etype_id_to_dst_ntype_id[id] = node_type_to_id->at(dst_type);
      etype_id_to_src_ntype_id_ptr[2 * id] =
          etype_id_to_src_ntype_id_ptr[2 * id + 1] =
              node_type_to_id->at(src_type);
    }
    auto indices_offsets_device = torch::empty(
        etype_id_to_src_ntype_id.size(0),
        output_indices.options().dtype(torch::kLong));
    AT_DISPATCH_INDEX_TYPES(
        node_type_offset->scalar_type(), "SampleNeighborsNodeTypeOffset",
        ([&] {
          THRUST_CALL(
              gather, etype_id_to_src_ntype_id_ptr,
              etype_id_to_src_ntype_id_ptr + etype_id_to_src_ntype_id.size(0),
              node_type_offset->data_ptr<index_t>(),
              indices_offsets_device.data_ptr<int64_t>());
        }));
    // For each edge type, we compute the start and end offsets to index into
    // indptr to form the final output_indptr.
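    // A worked example with hypothetical numbers: take num_etypes = 2, three
    // seeds grouped by node type with seed_offsets = {0, 2, 3} (two seeds of
    // node type 0, one of node type 1), etype 0 having dst node type 0,
    // etype 1 having dst node type 1, and num_rows = 3 * 2 = 6. The loop
    // below then yields indptr_offsets = {0 + 0, 0 + 2, 3 + 2, 3 + 3} =
    // {0, 2, 5, 6}: etype 0 occupies indptr rows [0, 2) and etype 1 occupies
    // rows [5, 6).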
    auto indptr_offsets = torch::empty(
        num_etypes * 2,
        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
    auto indptr_offsets_ptr = indptr_offsets.data_ptr<int64_t>();
    // We compute the indptr offsets here. Right now, output_indptr has size
    // # seeds * num_etypes + 1, so we can simply take slices to get the
    // correct output indptr. The final output_indptr is the same as the
    // current indptr except that some intermediate values are removed to
    // change the node id space from all of the seed vertices to the node id
    // space of the dst node type of each edge type.
    for (int i = 0; i < num_etypes; i++) {
      indptr_offsets_ptr[2 * i] =
          num_rows / num_etypes * i +
          seed_offsets->at(etype_id_to_dst_ntype_id[i]);
      indptr_offsets_ptr[2 * i + 1] =
          num_rows / num_etypes * i +
          seed_offsets->at(etype_id_to_dst_ntype_id[i] + 1);
    }
    auto permutation = torch::arange(
        0, num_rows * num_etypes, num_etypes, output_indptr.options());
    permutation =
        permutation.remainder(num_rows) + permutation.div(num_rows, "floor");
    // This permutation, when applied, sorts the sampled edges with respect to
    // edge types.
    auto [output_in_degree, sliced_output_indptr] =
        SliceCSCIndptr(output_indptr, permutation);
    std::tie(output_indptr, picked_eids) = IndexSelectCSCImpl(
        output_in_degree, sliced_output_indptr, picked_eids, permutation,
        num_rows - 1, picked_eids.size(0));
    edge_offsets = torch::empty(
        num_etypes * 2,
        c10::TensorOptions()
            .dtype(output_indptr.scalar_type())
            .pinned_memory(true));
    auto edge_offsets_device =
        torch::empty(num_etypes * 2, output_indptr.options());
    at::cuda::CUDAEvent edge_offsets_event;
    AT_DISPATCH_INDEX_TYPES(
        indptr.scalar_type(), "SampleNeighborsEdgeOffsets", ([&] {
          auto edge_offsets_pinned_device_pair =
              thrust::make_transform_output_iterator(
                  thrust::make_zip_iterator(
                      edge_offsets->data_ptr<index_t>(),
                      edge_offsets_device.data_ptr<index_t>()),
                  [=] __device__(index_t x) {
                    return thrust::make_tuple(x, x);
                  });
          THRUST_CALL(
              gather, indptr_offsets_ptr,
              indptr_offsets_ptr + indptr_offsets.size(0),
              output_indptr.data_ptr<index_t>(),
              edge_offsets_pinned_device_pair);
        }));
    edge_offsets_event.record();
    auto indices_offset_subtract = ExpandIndptrImpl(
        edge_offsets_device, indices.scalar_type(), indices_offsets_device,
        output_indices.size(0));
    // The output_indices is permuted here.
    std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
        output_in_degree, sliced_output_indptr, output_indices, permutation,
        num_rows - 1, output_indices.size(0));
    output_indices -= indices_offset_subtract;
    auto output_indptr_offsets = torch::empty(
        num_etypes * 2,
        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
    auto output_indptr_offsets_ptr = output_indptr_offsets.data_ptr<int64_t>();
    std::vector<torch::Tensor> indptr_list;
    for (int i = 0; i < num_etypes; i++) {
      indptr_list.push_back(output_indptr.slice(
          0, indptr_offsets_ptr[2 * i], indptr_offsets_ptr[2 * i + 1] + 1));
      output_indptr_offsets_ptr[2 * i] =
          i == 0 ? 0 : output_indptr_offsets_ptr[2 * i - 1];
      output_indptr_offsets_ptr[2 * i + 1] =
          output_indptr_offsets_ptr[2 * i] + indptr_list.back().size(0);
    }
    auto output_indptr_offsets_device = torch::empty(
        output_indptr_offsets.size(0),
        output_indptr.options().dtype(torch::kLong));
    THRUST_CALL(
        copy_n, output_indptr_offsets_ptr, output_indptr_offsets.size(0),
        output_indptr_offsets_device.data_ptr<int64_t>());
    // We form the final output indptr by concatenating the pieces for the
    // different edge types.
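    // Each slice in indptr_list still counts edges in the global (all edge
    // types) numbering. After concatenation we therefore subtract, per edge
    // type, the edge offset at which that type's edges begin
    // (edge_offsets_device expanded over output_indptr_offsets_device). This
    // rebases every per-type indptr piece to start at zero so that, together
    // with edge_offsets, each edge type's CSC slice can be read independently.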
    output_indptr = torch::cat(indptr_list);
    auto indptr_offset_subtract = ExpandIndptrImpl(
        output_indptr_offsets_device, indptr.scalar_type(), edge_offsets_device,
        output_indptr.size(0));
    output_indptr -= indptr_offset_subtract;
    edge_offsets_event.synchronize();
    // We read the edge_offsets here. They come in pairs, but we do not need
    // the pairs: we remove the duplicate information and turn them into a
    // real offsets array.
    AT_DISPATCH_INDEX_TYPES(
        indptr.scalar_type(), "SampleNeighborsEdgeOffsetsCheck", ([&] {
          auto edge_offsets_ptr = edge_offsets->data_ptr<index_t>();
          TORCH_CHECK(edge_offsets_ptr[0] == 0, "edge_offsets is incorrect.");
          for (int i = 1; i < num_etypes; i++) {
            TORCH_CHECK(
                edge_offsets_ptr[2 * i - 1] == edge_offsets_ptr[2 * i],
                "edge_offsets is incorrect.");
          }
          TORCH_CHECK(
              edge_offsets_ptr[2 * num_etypes - 1] == picked_eids.size(0),
              "edge_offsets is incorrect.");
          for (int i = 0; i < num_etypes; i++) {
            edge_offsets_ptr[i + 1] = edge_offsets_ptr[2 * i + 1];
          }
        }));
    edge_offsets = edge_offsets->slice(0, 0, num_etypes + 1);
  } else {
    // Convert output_indptr back to homo by discarding intermediate offsets.
    output_indptr =
        output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
    if (type_per_edge)
      output_type_per_edge = Gather(*type_per_edge, picked_eids);
  }
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, seeds, torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge, edge_offsets);
}

}  // namespace ops
}  // namespace graphbolt