/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/csr_to_coo.cu
 * @brief CSRToCOO operator implementation on CUDA.
 */
// NOTE(review): the original text of this file had all `<...>` angle-bracket
// content stripped (include targets and template arguments). The template
// arguments below are reconstructed from the `using indptr_t = scalar_t` /
// `using indices_t = scalar_t` aliases visible in each dispatch lambda, and
// the include list from the thrust/cub/std names actually used — confirm
// against upstream before merging.
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <algorithm>  // std::min — was relied on transitively before
#include <cub/cub.cuh>
#include <limits>

#include "./common.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Functor mapping a row id to an (unbounded) iterator that repeats
 * that row id forever. Used as the per-row "input buffer" of the batched
 * copy: row i's output range is filled with the constant value i.
 */
template <typename indices_t>
struct RepeatIndex {
  __host__ __device__ auto operator()(indices_t i) {
    return thrust::make_constant_iterator(i);
  }
};

/**
 * @brief Functor mapping a row id to the start of that row's destination
 * range inside the output COO row-ids buffer, i.e. buffer + indptr[i].
 */
template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
  const indptr_t* indptr;
  indices_t* buffer;
  __host__ __device__ auto operator()(int64_t i) { return buffer + indptr[i]; }
};

/**
 * @brief Functor mapping a row id to its degree (number of nonzeros),
 * computed as the adjacent difference of the CSR index pointer.
 */
template <typename indptr_t>
struct AdjacentDifference {
  const indptr_t* indptr;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i + 1] - indptr[i];
  }
};

/**
 * @brief Expands a CSR index pointer into the COO row-ids array on the GPU.
 *
 * For each row i, writes the value i into positions [indptr[i], indptr[i+1])
 * of the result, using cub::DeviceCopy::Batched with one "buffer" per row:
 * the source is a constant iterator repeating i, the destination is
 * result + indptr[i], and the size is indptr[i+1] - indptr[i].
 *
 * @param indptr       CSR index pointer tensor of length num_rows + 1; any
 *                     integral dtype (dispatched via AT_DISPATCH_INTEGRAL_TYPES).
 * @param output_dtype Integral dtype of the returned row-ids tensor.
 * @return Tensor of length indptr[num_rows] holding the COO row ids.
 */
torch::Tensor CSRToCOO(torch::Tensor indptr, torch::ScalarType output_dtype) {
  const auto num_rows = indptr.size(0) - 1;
  thrust::counting_iterator<int64_t> iota(0);

  return AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "CSRToCOOIndptr", ([&] {
        using indptr_t = scalar_t;
        auto indptr_ptr = indptr.data_ptr<indptr_t>();
        // Number of edges is the last indptr entry; CopyScalar fetches it
        // from device memory (blocking read happens on the cast below).
        auto num_edges =
            cuda::CopyScalar<indptr_t>{indptr.data_ptr<indptr_t>() + num_rows};
        auto csr_rows = torch::empty(
            static_cast<indptr_t>(num_edges),
            indptr.options().dtype(output_dtype));
        AT_DISPATCH_INTEGRAL_TYPES(
            output_dtype, "CSRToCOOIndices", ([&] {
              using indices_t = scalar_t;
              auto csr_rows_ptr = csr_rows.data_ptr<indices_t>();
              // Source of row i: an endless stream of the value i.
              auto input_buffer = thrust::make_transform_iterator(
                  iota, RepeatIndex<indices_t>{});
              // Destination of row i: csr_rows + indptr[i].
              auto output_buffer = thrust::make_transform_iterator(
                  iota, OutputBufferIndexer<indptr_t, indices_t>{
                            indptr_ptr, csr_rows_ptr});
              // Size of row i: indptr[i + 1] - indptr[i].
              auto buffer_sizes = thrust::make_transform_iterator(
                  iota, AdjacentDifference<indptr_t>{indptr_ptr});
              // DeviceCopy::Batched takes a 32-bit buffer count, so chunk
              // the rows when num_rows exceeds INT32_MAX.
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer + i, output_buffer + i,
                    buffer_sizes + i,
                    std::min(num_rows - i, max_copy_at_once));
              }
            }));
        return csr_rows;
      }));
}

}  // namespace ops
}  // namespace graphbolt