/** * Copyright (c) 2023 by Contributors * @file csc_sampling_graph.cc * @brief Source file of sampling graph. */ #include #include namespace graphbolt { namespace sampling { CSCSamplingGraph::CSCSamplingGraph( torch::Tensor& indptr, torch::Tensor& indices, const torch::optional& node_type_offset, const torch::optional& type_per_edge) : indptr_(indptr), indices_(indices), node_type_offset_(node_type_offset), type_per_edge_(type_per_edge) { TORCH_CHECK(indptr.dim() == 1); TORCH_CHECK(indices.dim() == 1); TORCH_CHECK(indptr.device() == indices.device()); } c10::intrusive_ptr CSCSamplingGraph::FromCSC( torch::Tensor indptr, torch::Tensor indices, const torch::optional& node_type_offset, const torch::optional& type_per_edge) { if (node_type_offset.has_value()) { auto& offset = node_type_offset.value(); TORCH_CHECK(offset.dim() == 1); } if (type_per_edge.has_value()) { TORCH_CHECK(type_per_edge.value().dim() == 1); TORCH_CHECK(type_per_edge.value().size(0) == indices.size(0)); } return c10::make_intrusive( indptr, indices, node_type_offset, type_per_edge); } void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { const int64_t magic_num = read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt(); TORCH_CHECK( magic_num == kCSCSamplingGraphSerializeMagic, "Magic numbers mismatch when loading CSCSamplingGraph."); indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor(); indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor(); } void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const { archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic); archive.write("CSCSamplingGraph/indptr", indptr_); archive.write("CSCSamplingGraph/indices", indices_); } } // namespace sampling } // namespace graphbolt