/**
 *  Copyright (c) 2023 by Contributors
 * @file csc_sampling_graph.cc
 * @brief Source file of the CSC sampling graph.
 */

#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/serialize.h>
namespace graphbolt {
namespace sampling {

/**
 * @brief Construct a sampling graph from CSC tensors.
 *
 * The tensors are stored as given. Only the invariants on indptr/indices are
 * checked here; the optional tensors are stored without validation.
 *
 * @param indptr CSC offset tensor; must be 1-D. Not modified.
 * @param indices CSC index tensor; must be 1-D and on the same device as
 * indptr. Not modified.
 * @param node_type_offset Optional node type offset tensor.
 * @param type_per_edge Optional per-edge type tensor.
 * @throws c10::Error if any of the checks below fails.
 */
// NOTE(review): indptr/indices are only read here and could be taken by
// const reference, but that requires changing the declaration in the header.
CSCSamplingGraph::CSCSamplingGraph(
    torch::Tensor& indptr, torch::Tensor& indices,
    const torch::optional<torch::Tensor>& node_type_offset,
    const torch::optional<torch::Tensor>& type_per_edge)
    : indptr_(indptr),
      indices_(indices),
      node_type_offset_(node_type_offset),
      type_per_edge_(type_per_edge) {
  TORCH_CHECK(
      indptr.dim() == 1, "indptr must be a 1-D tensor, but got ", indptr.dim(),
      " dimensions.");
  TORCH_CHECK(
      indices.dim() == 1, "indices must be a 1-D tensor, but got ",
      indices.dim(), " dimensions.");
  TORCH_CHECK(
      indptr.device() == indices.device(),
      "indptr and indices must be on the same device, but got ",
      indptr.device(), " and ", indices.device(), ".");
}

/**
 * @brief Validate the optional tensors and build a reference-counted graph.
 *
 * @param indptr CSC offset tensor; validated by the constructor.
 * @param indices CSC index tensor; must be 1-D.
 * @param node_type_offset Optional; when present it must be 1-D.
 * @param type_per_edge Optional; when present it must be 1-D and have the
 * same length as indices (one entry per edge).
 * @return A new c10::intrusive_ptr owning the constructed graph.
 * @throws c10::Error if any validation fails.
 */
c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::FromCSC(
    torch::Tensor indptr, torch::Tensor indices,
    const torch::optional<torch::Tensor>& node_type_offset,
    const torch::optional<torch::Tensor>& type_per_edge) {
  if (node_type_offset.has_value()) {
    const auto& offset = node_type_offset.value();
    TORCH_CHECK(
        offset.dim() == 1, "node_type_offset must be a 1-D tensor, but got ",
        offset.dim(), " dimensions.");
  }
  if (type_per_edge.has_value()) {
    const auto& edge_types = type_per_edge.value();
    TORCH_CHECK(
        edge_types.dim() == 1, "type_per_edge must be a 1-D tensor, but got ",
        edge_types.dim(), " dimensions.");
    // Check the rank of indices before calling size(0) on it; otherwise a 0-d
    // tensor fails with an unrelated indexing error instead of this message.
    TORCH_CHECK(
        indices.dim() == 1, "indices must be a 1-D tensor, but got ",
        indices.dim(), " dimensions.");
    TORCH_CHECK(
        edge_types.size(0) == indices.size(0),
        "type_per_edge must have the same length as indices (",
        indices.size(0), "), but got ", edge_types.size(0), ".");
  }

  return c10::make_intrusive<CSCSamplingGraph>(
      indptr, indices, node_type_offset, type_per_edge);
}

/**
 * @brief Restore this graph from an archive produced by Save().
 *
 * Reads and verifies the magic number, then reads indptr and indices. The
 * optional node_type_offset and type_per_edge tensors are read only when the
 * archive carries a true presence flag for them. Otherwise they are cleared,
 * so an archive that never stored them (for example, one written before these
 * fields were serialized) loads as a graph without them.
 *
 * @param archive Archive to read from.
 * @throws c10::Error if the stored magic number does not match
 * kCSCSamplingGraphSerializeMagic.
 */
void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
  const int64_t magic_num =
      read_from_archive(archive, "CSCSamplingGraph/magic_num").toInt();
  TORCH_CHECK(
      magic_num == kCSCSamplingGraphSerializeMagic,
      "Magic numbers mismatch when loading CSCSamplingGraph.");
  indptr_ = read_from_archive(archive, "CSCSamplingGraph/indptr").toTensor();
  indices_ = read_from_archive(archive, "CSCSamplingGraph/indices").toTensor();

  // Use try_read for the presence flags: an absent flag (older archive) is
  // treated the same as a false flag instead of raising.
  const auto has_field = [&archive](const std::string& flag_key) {
    c10::IValue flag;
    return archive.try_read(flag_key, flag) && flag.toBool();
  };

  // Reset first so that values held before Load() do not leak through.
  node_type_offset_.reset();
  if (has_field("CSCSamplingGraph/has_node_type_offset")) {
    node_type_offset_ =
        read_from_archive(archive, "CSCSamplingGraph/node_type_offset")
            .toTensor();
  }
  type_per_edge_.reset();
  if (has_field("CSCSamplingGraph/has_type_per_edge")) {
    type_per_edge_ =
        read_from_archive(archive, "CSCSamplingGraph/type_per_edge").toTensor();
  }
}

/**
 * @brief Serialize this graph into an archive.
 *
 * Writes the magic number, indptr and indices. For each optional tensor
 * (node_type_offset, type_per_edge) it writes a boolean presence flag and,
 * when present, the tensor itself.
 *
 * @param archive Archive to write into.
 */
void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
  archive.write("CSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic);
  archive.write("CSCSamplingGraph/indptr", indptr_);
  archive.write("CSCSamplingGraph/indices", indices_);
  // The flag is always written so a reader can tell whether the tensor key
  // exists in the archive.
  archive.write(
      "CSCSamplingGraph/has_node_type_offset", node_type_offset_.has_value());
  if (node_type_offset_.has_value()) {
    archive.write(
        "CSCSamplingGraph/node_type_offset", node_type_offset_.value());
  }
  archive.write(
      "CSCSamplingGraph/has_type_per_edge", type_per_edge_.has_value());
  if (type_per_edge_.has_value()) {
    archive.write("CSCSamplingGraph/type_per_edge", type_per_edge_.value());
  }
}

}  // namespace sampling
}  // namespace graphbolt