// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/expand_indptr.cu
 * @brief ExpandIndptr operator implementation on CUDA.
 */
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <algorithm>
#include <cub/cub.cuh>
#include <limits>

#include "common.h"

namespace graphbolt {
namespace ops {

// Repeats the i-th node id (or i itself when nodes is nullptr) by returning a
// constant iterator over it.
template <typename indices_t, typename nodes_t>
struct RepeatIndex {
  const nodes_t* nodes;
  __host__ __device__ auto operator()(indices_t i) {
    return thrust::make_constant_iterator(nodes ? nodes[i] : i);
  }
};

// Maps row i to the start of its segment in the output buffer.
template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
  const indptr_t* indptr;
  indices_t* buffer;
  __host__ __device__ auto operator()(int64_t i) { return buffer + indptr[i]; }
};

// Computes the degree of row i, i.e. the length of its output segment.
template <typename indptr_t>
struct AdjacentDifference {
  const indptr_t* indptr;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i + 1] - indptr[i];
  }
};

// Expands indptr into row indices: row id i (or nodes[i] when nodes is given)
// is written indptr[i + 1] - indptr[i] times. E.g. indptr = [0, 2, 5] expands
// to [0, 0, 1, 1, 1].
torch::Tensor ExpandIndptrImpl(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> nodes,
    torch::optional<int64_t> output_size) {
  if (!output_size.has_value()) {
    // The output size is the last element of indptr; copy it to the host.
    output_size = AT_DISPATCH_INTEGRAL_TYPES(
        indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
          auto indptr_ptr = indptr.data_ptr<scalar_t>();
          auto output_size =
              cuda::CopyScalar{indptr_ptr + indptr.size(0) - 1};
          return static_cast<scalar_t>(output_size);
        }));
  }
  auto csc_rows =
      torch::empty(output_size.value(), indptr.options().dtype(dtype));

  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "ExpandIndptrIndptr", ([&] {
        using indptr_t = scalar_t;
        auto indptr_ptr = indptr.data_ptr<indptr_t>();
        AT_DISPATCH_INTEGRAL_TYPES(
            dtype, "ExpandIndptrIndices", ([&] {
              using indices_t = scalar_t;
              auto csc_rows_ptr = csc_rows.data_ptr<indices_t>();
              auto nodes_dtype = nodes ? nodes.value().scalar_type() : dtype;
              AT_DISPATCH_INTEGRAL_TYPES(
                  nodes_dtype, "ExpandIndptrNodes", ([&] {
                    using nodes_t = scalar_t;
                    auto nodes_ptr =
                        nodes ? nodes.value().data_ptr<nodes_t>() : nullptr;

                    thrust::counting_iterator<int64_t> iota(0);
                    auto input_buffer = thrust::make_transform_iterator(
                        iota, RepeatIndex<indices_t, nodes_t>{nodes_ptr});
                    auto output_buffer = thrust::make_transform_iterator(
                        iota, OutputBufferIndexer<indptr_t, indices_t>{
                                  indptr_ptr, csc_rows_ptr});
                    auto buffer_sizes = thrust::make_transform_iterator(
                        iota, AdjacentDifference<indptr_t>{indptr_ptr});

                    const auto num_rows = indptr.size(0) - 1;
                    // DeviceCopy::Batched takes a 32-bit buffer count, so
                    // process the rows in chunks of at most INT32_MAX.
                    constexpr int64_t max_copy_at_once =
                        std::numeric_limits<int32_t>::max();
                    for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                      CUB_CALL(
                          DeviceCopy::Batched, input_buffer + i,
                          output_buffer + i, buffer_sizes + i,
                          std::min(num_rows - i, max_copy_at_once));
                    }
                  }));
            }));
      }));
  return csc_rows;
}

}  // namespace ops
}  // namespace graphbolt