/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/include/csc_sampling_graph.cc
 * @brief Source file of sampling graph.
 */

#include "csc_sampling_graph.h"

namespace graphbolt {
namespace sampling {

/**
 * @brief Construct a sampling graph from CSC (compressed sparse column)
 * arrays and optional heterogeneous-graph information.
 *
 * @param num_nodes Number of nodes; must be non-negative.
 * @param indptr 1-D column pointer tensor with exactly `num_nodes + 1`
 * elements.
 * @param indices 1-D row index tensor; must live on the same device as
 * `indptr`.
 * @param hetero_info Heterogeneous-graph information, or nullptr for a
 * homogeneous graph.
 *
 * @throws c10::Error (via TORCH_CHECK) if any of the above constraints is
 * violated.
 */
CSCSamplingGraph::CSCSamplingGraph(
    int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
    const std::shared_ptr<HeteroInfo>& hetero_info)
    : num_nodes_(num_nodes),
      indptr_(indptr),
      indices_(indices),
      hetero_info_(hetero_info) {
  TORCH_CHECK(
      num_nodes >= 0, "num_nodes must be non-negative, but got ", num_nodes,
      ".");
  TORCH_CHECK(
      indptr.dim() == 1, "indptr must be a 1-D tensor, but got ",
      indptr.dim(), " dimension(s).");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimension(s).");
  // The dim() check above guarantees size(0) is a valid query here.
  TORCH_CHECK(
      indptr.size(0) == num_nodes + 1, "indptr must have num_nodes + 1 (",
      num_nodes + 1, ") elements, but got ", indptr.size(0), ".");
  TORCH_CHECK(
      indptr.device() == indices.device(),
      "indptr and indices must be on the same device, but got ",
      indptr.device(), " and ", indices.device(), ".");
}

/**
 * @brief Create a homogeneous sampling graph from CSC arrays.
 *
 * @param num_nodes Number of nodes.
 * @param indptr Column pointer tensor.
 * @param indices Row index tensor.
 *
 * @return A new graph with no heterogeneous information attached.
 * Argument validation is performed by the constructor.
 */
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
    int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices) {
  return c10::make_intrusive<CSCSamplingGraph>(
      num_nodes, indptr, indices, nullptr);
}

/**
 * @brief Create a heterogeneous sampling graph from CSC arrays plus
 * node/edge type information.
 *
 * @param num_nodes Number of nodes.
 * @param indptr Column pointer tensor.
 * @param indices Row index tensor.
 * @param ntypes Node type names; must be non-empty.
 * @param etypes Edge type names; must be non-empty.
 * @param node_type_offset Non-empty 1-D tensor of node type offsets.
 * @param type_per_edge 1-D tensor with one entry per edge, i.e. the same
 * length as `indices`.
 *
 * @return A new graph carrying a HeteroInfo built from the type arguments.
 * @throws c10::Error (via TORCH_CHECK) if any constraint is violated; the
 * remaining CSC checks are performed by the constructor.
 */
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
    int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
    const StringList& ntypes, const StringList& etypes,
    torch::Tensor node_type_offset, torch::Tensor type_per_edge) {
  TORCH_CHECK(!ntypes.empty(), "ntypes must not be empty.");
  TORCH_CHECK(!etypes.empty(), "etypes must not be empty.");
  // Check dimensionality before calling size(0): size(0) on a 0-d tensor
  // throws an indexing error instead of this descriptive failure.
  TORCH_CHECK(
      node_type_offset.dim() == 1,
      "node_type_offset must be a 1-D tensor, but got ",
      node_type_offset.dim(), " dimension(s).");
  TORCH_CHECK(
      node_type_offset.size(0) > 0, "node_type_offset must not be empty.");
  TORCH_CHECK(
      type_per_edge.dim() == 1,
      "type_per_edge must be a 1-D tensor, but got ", type_per_edge.dim(),
      " dimension(s).");
  TORCH_CHECK(type_per_edge.size(0) > 0, "type_per_edge must not be empty.");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimension(s).");
  // Each edge of the CSC structure (one entry of `indices`) needs exactly
  // one edge type.
  TORCH_CHECK(
      type_per_edge.size(0) == indices.size(0),
      "type_per_edge must have one entry per edge (", indices.size(0),
      "), but got ", type_per_edge.size(0), ".");
  TORCH_CHECK(
      node_type_offset.device() == type_per_edge.device(),
      "node_type_offset and type_per_edge must be on the same device, but "
      "got ",
      node_type_offset.device(), " and ", type_per_edge.device(), ".");
  TORCH_CHECK(
      type_per_edge.device() == indices.device(),
      "type_per_edge and indices must be on the same device, but got ",
      type_per_edge.device(), " and ", indices.device(), ".");
  auto hetero_info = std::make_shared<HeteroInfo>(
      ntypes, etypes, node_type_offset, type_per_edge);
  return c10::make_intrusive<CSCSamplingGraph>(
      num_nodes, indptr, indices, hetero_info);
}

}  // namespace sampling
}  // namespace graphbolt