/**
 *  Copyright (c) 2023 by Contributors
 * @file csc_sampling_graph.cc
 * @brief Source file of sampling graph.
 */

#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/serialize.h>

namespace graphbolt {
namespace sampling {

/**
 * @brief Restores the heterogeneous type information from an archive written
 * by HeteroInfo::Save.
 *
 * @param archive Archive to read the "HeteroInfo/..." entries from.
 * @throws c10::Error if the stored magic number is not
 * kHeteroInfoSerializeMagic.
 */
void HeteroInfo::Load(torch::serialize::InputArchive& archive) {
  // Shorthand for fetching one named entry from the archive.
  auto read = [&archive](const char* key) {
    return read_from_archive(archive, key);
  };

  // Refuse to interpret archives that were not produced by HeteroInfo::Save.
  TORCH_CHECK(
      read("HeteroInfo/magic_num").toInt() == kHeteroInfoSerializeMagic,
      "Magic numbers mismatch when loading HeteroInfo.");

  // Type name lists.
  node_types = read("HeteroInfo/node_types").to<decltype(node_types)>();
  edge_types = read("HeteroInfo/edge_types").to<decltype(edge_types)>();

  // Type tensors.
  node_type_offset = read("HeteroInfo/node_type_offset").toTensor();
  type_per_edge = read("HeteroInfo/type_per_edge").toTensor();
}

/**
 * @brief Writes the heterogeneous type information into an archive.
 *
 * Every key is prefixed with "HeteroInfo/" because the owning
 * CSCSamplingGraph::Save passes its own archive here, so both sets of entries
 * share one archive.
 *
 * @param archive Archive that receives the entries.
 */
void HeteroInfo::Save(torch::serialize::OutputArchive& archive) const {
  // Forwards a single key/value pair to the archive.
  const auto put = [&archive](const char* key, const auto& value) {
    archive.write(key, value);
  };

  // Written first so HeteroInfo::Load can validate it before anything else.
  put("HeteroInfo/magic_num", kHeteroInfoSerializeMagic);
  put("HeteroInfo/node_types", node_types);
  put("HeteroInfo/edge_types", edge_types);
  put("HeteroInfo/node_type_offset", node_type_offset);
  put("HeteroInfo/type_per_edge", type_per_edge);
}

/**
 * @brief Constructs a sampling graph from CSC arrays.
 *
 * @param num_nodes Number of nodes; must be non-negative.
 * @param indptr 1-D tensor with exactly num_nodes + 1 entries.
 * @param indices 1-D tensor on the same device as indptr.
 * @param hetero_info Optional type information; nullptr for a homogeneous
 * graph. It is stored as-is and not validated here.
 * @throws c10::Error if any of the structural checks below fail.
 */
CSCSamplingGraph::CSCSamplingGraph(
    int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
    const std::shared_ptr<HeteroInfo>& hetero_info)
    : num_nodes_(num_nodes),
      indptr_(indptr),
      indices_(indices),
      hetero_info_(hetero_info) {
  TORCH_CHECK(
      num_nodes >= 0, "num_nodes must be non-negative, but got ", num_nodes,
      ".");
  TORCH_CHECK(
      indptr.dim() == 1, "indptr must be a 1-D tensor, but got ",
      indptr.dim(), " dimensions.");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimensions.");
  // Checked after the rank check above, so size(0) is always valid here.
  TORCH_CHECK(
      indptr.size(0) == num_nodes + 1, "indptr must have num_nodes + 1 = ",
      num_nodes + 1, " elements, but got ", indptr.size(0), ".");
  TORCH_CHECK(
      indptr.device() == indices.device(),
      "indptr and indices must be on the same device, but got ",
      indptr.device(), " and ", indices.device(), ".");
}

/**
 * @brief Creates a homogeneous sampling graph (no type information) from CSC
 * arrays.
 *
 * @param num_nodes Number of nodes.
 * @param indptr CSC index pointer tensor.
 * @param indices CSC indices tensor.
 * @return The new graph; validation is done by the constructor.
 */
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
    int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices) {
  // An empty pointer marks the graph as homogeneous.
  const std::shared_ptr<HeteroInfo> no_hetero_info;
  return c10::make_intrusive<CSCSamplingGraph>(
      num_nodes, indptr, indices, no_hetero_info);
}

/**
 * @brief Creates a heterogeneous sampling graph from CSC arrays plus type
 * information.
 *
 * @param num_nodes Number of nodes; forwarded to the constructor.
 * @param indptr CSC index pointer tensor; forwarded to the constructor.
 * @param indices CSC indices tensor; must be 1-D.
 * @param ntypes Node type list; must not be empty.
 * @param etypes Edge type list; must not be empty.
 * @param node_type_offset Non-empty 1-D tensor on the same device as
 * type_per_edge.
 * @param type_per_edge Non-empty 1-D tensor with one entry per edge (same
 * length as indices), on the same device as indices.
 * @return The new graph, owning a HeteroInfo built from the type arguments.
 * @throws c10::Error if any check below or in the constructor fails.
 */
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSCWithHeteroInfo(
    int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
    const NodeTypeList& ntypes, const EdgeTypeList& etypes,
    torch::Tensor node_type_offset, torch::Tensor type_per_edge) {
  // Rank is checked before size(0): size(0) on a 0-dim tensor would throw an
  // unrelated indexing error instead of a meaningful validation message.
  TORCH_CHECK(
      node_type_offset.dim() == 1,
      "node_type_offset must be a 1-D tensor, but got ",
      node_type_offset.dim(), " dimensions.");
  TORCH_CHECK(
      node_type_offset.size(0) > 0, "node_type_offset must not be empty.");
  TORCH_CHECK(
      type_per_edge.dim() == 1, "type_per_edge must be a 1-D tensor, but got ",
      type_per_edge.dim(), " dimensions.");
  TORCH_CHECK(type_per_edge.size(0) > 0, "type_per_edge must not be empty.");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimensions.");
  // Every edge (one entry of indices) needs exactly one type id.
  TORCH_CHECK(
      type_per_edge.size(0) == indices.size(0),
      "type_per_edge must have one entry per edge (", indices.size(0),
      "), but got ", type_per_edge.size(0), ".");
  TORCH_CHECK(
      node_type_offset.device() == type_per_edge.device(),
      "node_type_offset and type_per_edge must be on the same device, but got ",
      node_type_offset.device(), " and ", type_per_edge.device(), ".");
  TORCH_CHECK(
      type_per_edge.device() == indices.device(),
      "type_per_edge and indices must be on the same device, but got ",
      type_per_edge.device(), " and ", indices.device(), ".");
  TORCH_CHECK(!ntypes.empty(), "ntypes must not be empty.");
  TORCH_CHECK(!etypes.empty(), "etypes must not be empty.");

  auto hetero_info = std::make_shared<HeteroInfo>(
      ntypes, etypes, node_type_offset, type_per_edge);

  // indptr/indices are passed as lvalues on purpose: the constructor takes
  // them by non-const reference, so they cannot be moved in.
  return c10::make_intrusive<CSCSamplingGraph>(
      num_nodes, indptr, indices, hetero_info);
}

/**
 * @brief Restores this graph from an archive written by
 * CSCSamplingGraph::Save.
 *
 * Everything is first read and validated into locals. It is committed to the
 * members only once all steps succeed, so a failed load leaves this object
 * unchanged.
 *
 * @param archive Archive to read the "CSCSamplingGraph/..." entries (and, for
 * heterogeneous graphs, the "HeteroInfo/..." entries) from.
 * @throws c10::Error on a magic-number mismatch, or if the loaded structure
 * violates the invariants the constructor enforces.
 */
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
  const int64_t magic_num =
      read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt();
  TORCH_CHECK(
      magic_num == kCSCSamplingGraphSerializeMagic,
      "Magic numbers mismatch when loading CSCSamplingGraph.");

  const int64_t num_nodes =
      read_from_archive(archive, "CSCSamplingGraph/num_nodes").toInt();
  torch::Tensor indptr =
      read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
  torch::Tensor indices =
      read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();

  // Archive contents are external data: re-check the same invariants the
  // constructor enforces for freshly built graphs.
  TORCH_CHECK(
      num_nodes >= 0, "num_nodes must be non-negative, but got ", num_nodes,
      ".");
  TORCH_CHECK(
      indptr.dim() == 1, "indptr must be a 1-D tensor, but got ",
      indptr.dim(), " dimensions.");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimensions.");
  TORCH_CHECK(
      indptr.size(0) == num_nodes + 1, "indptr must have num_nodes + 1 = ",
      num_nodes + 1, " elements, but got ", indptr.size(0), ".");
  TORCH_CHECK(
      indptr.device() == indices.device(),
      "indptr and indices must be on the same device, but got ",
      indptr.device(), " and ", indices.device(), ".");

  const bool is_heterogeneous =
      read_from_archive(archive, "CSCSamplingGraph/is_hetero").toBool();
  // Stays empty for homogeneous archives, so reloading into an object that
  // previously held a heterogeneous graph does not keep stale type info.
  decltype(hetero_info_) hetero_info;
  if (is_heterogeneous) {
    hetero_info = std::make_shared<HeteroInfo>();
    hetero_info->Load(archive);
  }

  // Commit only after every read and check above has succeeded.
  num_nodes_ = num_nodes;
  indptr_ = indptr;
  indices_ = indices;
  hetero_info_ = hetero_info;
}

/**
 * @brief Writes this graph into an archive readable by CSCSamplingGraph::Load.
 *
 * The structural entries are always written. A heterogeneous graph also
 * appends its HeteroInfo entries into the same archive.
 *
 * @param archive Archive that receives the entries.
 */
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
  const bool has_type_info = IsHeterogeneous();

  // Written first so CSCSamplingGraph::Load can validate it before anything
  // else.
  archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
  archive.write("CSCSamplingGraph/num_nodes", num_nodes_);
  archive.write("CSCSamplingGraph/indptr", indptr_);
  archive.write("CSCSamplingGraph/indices", indices_);
  // Tells Load whether HeteroInfo entries follow.
  archive.write("CSCSamplingGraph/is_hetero", has_type_info);

  if (!has_type_info) {
    return;
  }
  hetero_info_->Save(archive);
}

}  // namespace sampling
}  // namespace graphbolt