/**
 *  Copyright (c) 2023 by Contributors
 * @file csc_sampling_graph.cc
 * @brief Source file of the CSC sampling graph: construction, input
 * validation, and archive (de)serialization of CSCSamplingGraph.
 */

7
8
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/serialize.h>
9

10
11
12
13
namespace graphbolt {
namespace sampling {

/**
 * @brief Builds a sampling graph from CSC arrays.
 *
 * Stores the given tensors and validates the mandatory CSC structure:
 * indptr and indices must both be 1-D and live on the same device.
 * This constructor does not validate node_type_offset / type_per_edge;
 * FromCSC checks those before delegating here.
 *
 * @param indptr CSC index pointer tensor (must be 1-D).
 * @param indices CSC indices tensor (must be 1-D, same device as indptr).
 * @param node_type_offset Optional node type offset tensor, stored as given.
 * @param type_per_edge Optional per-edge type tensor, stored as given.
 */
CSCSamplingGraph::CSCSamplingGraph(
    torch::Tensor& indptr, torch::Tensor& indices,
    const torch::optional<torch::Tensor>& node_type_offset,
    const torch::optional<torch::Tensor>& type_per_edge)
    : indptr_(indptr),
      indices_(indices),
      node_type_offset_(node_type_offset),
      type_per_edge_(type_per_edge) {
  // Report the actual offending values so callers can tell which input is
  // malformed instead of getting a generic "Expected ... to be true" error.
  TORCH_CHECK(
      indptr.dim() == 1, "indptr must be a 1-D tensor, but got ",
      indptr.dim(), " dimensions.");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimensions.");
  TORCH_CHECK(
      indptr.device() == indices.device(),
      "indptr and indices must be on the same device, but indptr is on ",
      indptr.device(), " and indices is on ", indices.device(), ".");
}

/**
 * @brief Validates the optional type tensors and creates a sampling graph.
 *
 * When present, node_type_offset must be 1-D, and type_per_edge must be 1-D
 * with exactly one entry per element of indices. The mandatory indptr/indices
 * checks are performed by the constructor this delegates to.
 *
 * @param indptr CSC index pointer tensor.
 * @param indices CSC indices tensor.
 * @param node_type_offset Optional node type offset tensor.
 * @param type_per_edge Optional per-edge type tensor.
 * @return Intrusive pointer to the newly constructed graph.
 */
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
    torch::Tensor indptr, torch::Tensor indices,
    const torch::optional<torch::Tensor>& node_type_offset,
    const torch::optional<torch::Tensor>& type_per_edge) {
  if (node_type_offset.has_value()) {
    const auto& offset = node_type_offset.value();
    TORCH_CHECK(
        offset.dim() == 1, "node_type_offset must be a 1-D tensor, but got ",
        offset.dim(), " dimensions.");
  }
  if (type_per_edge.has_value()) {
    const auto& edge_types = type_per_edge.value();
    TORCH_CHECK(
        edge_types.dim() == 1, "type_per_edge must be a 1-D tensor, but got ",
        edge_types.dim(), " dimensions.");
    // type_per_edge is aligned element-wise with indices.
    TORCH_CHECK(
        edge_types.size(0) == indices.size(0),
        "type_per_edge must have one entry per element of indices: expected ",
        indices.size(0), " entries, but got ", edge_types.size(0), ".");
  }

  return c10::make_intrusive<CSCSamplingGraph>(
      indptr, indices, node_type_offset, type_per_edge);
}

/**
 * @brief Restores this graph from an archive written by Save().
 *
 * Verifies the magic number, reads indptr/indices, and reads
 * node_type_offset / type_per_edge only when the archive's corresponding
 * "has_*" flag is true; otherwise those members are left empty.
 *
 * @param archive Input archive to read from.
 */
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
  const int64_t magic_num =
      read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt();
  TORCH_CHECK(
      magic_num == kCSCSamplingGraphSerializeMagic,
      "Magic numbers mismatch when loading CSCSamplingGraph.");

  indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
  indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();
  // Apply the same 1-D invariants the constructor enforces, so a malformed
  // archive is rejected at load time.
  TORCH_CHECK(
      indptr_.dim() == 1, "Loaded indptr must be a 1-D tensor, but got ",
      indptr_.dim(), " dimensions.");
  TORCH_CHECK(
      indices_.dim() == 1, "Loaded indices must be a 1-D tensor, but got ",
      indices_.dim(), " dimensions.");

  // Clear the optional members before (possibly) reading them: otherwise
  // loading an archive without them into a previously populated object
  // would silently keep stale tensors.
  node_type_offset_.reset();
  if (read_from_archive(archive, "CSCSamplingGraph/has_node_type_offset")
          .toBool()) {
    node_type_offset_ =
        read_from_archive(archive, "CSCSamplingGraph/node_type_offset")
            .toTensor();
  }
  type_per_edge_.reset();
  if (read_from_archive(archive, "CSCSamplingGraph/has_type_per_edge")
          .toBool()) {
    type_per_edge_ =
        read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor();
  }
}

/**
 * @brief Writes this graph to an archive using the keys Load() reads.
 *
 * The magic number and the two CSC tensors are always written. Each optional
 * tensor is stored as a boolean "has_*" flag, followed by the tensor itself
 * only when it is present.
 *
 * @param archive Output archive to write into.
 */
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
  archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
  archive.write("CSCSamplingGraph/indptr", indptr_);
  archive.write("CSCSamplingGraph/indices", indices_);

  // Shared encoding for optional tensors: presence flag, then the payload.
  const auto write_optional_tensor =
      [&archive](
          const char* presence_key, const char* tensor_key,
          const torch::optional<torch::Tensor>& maybe_tensor) {
        const bool present = maybe_tensor.has_value();
        archive.write(presence_key, present);
        if (present) {
          archive.write(tensor_key, *maybe_tensor);
        }
      };

  write_optional_tensor(
      "CSCSamplingGraph/has_node_type_offset",
      "CSCSamplingGraph/node_type_offset", node_type_offset_);
  write_optional_tensor(
      "CSCSamplingGraph/has_type_per_edge", "CSCSamplingGraph/type_per_edge",
      type_per_edge_);
}

}  // namespace sampling
}  // namespace graphbolt