"docs/vscode:/vscode.git/clone" did not exist on "175f53decf58d6f27db01ad8052d1ff919550a81"
serialize.cc 1.14 KB
Newer Older
1
2
3
4
5
6
7
/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/src/serialize.cc
 * @brief Source file of serialize.
 */

#include <graphbolt/serialize.h>
8
#include <torch/torch.h>
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

namespace torch {

serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph) {
  graph.Load(archive);
  return archive;
}

serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph) {
  graph.Save(archive);
  return archive;
}

}  // namespace torch

namespace graphbolt {

/**
 * @brief Load a CSCSamplingGraph from a file.
 *
 * A default-constructed graph is allocated and then populated with
 * torch::load. This is the reading counterpart of SaveCSCSamplingGraph,
 * which writes with torch::save.
 *
 * @param filename Path of the file to read.
 * @return A newly allocated graph holding the file contents.
 */
c10::intrusive_ptr<sampling::CSCSamplingGraph> LoadCSCSamplingGraph(
    const std::string& filename) {
  // Hold the pointer by value rather than via `auto&&`. Before C++20, a
  // named rvalue reference is not implicitly moved by `return`, so the old
  // form copied the intrusive_ptr (an extra atomic refcount inc/dec).
  // By value, NRVO or an implicit move applies.
  auto graph = c10::make_intrusive<sampling::CSCSamplingGraph>();
  torch::load(*graph, filename);
  return graph;
}

/**
 * @brief Save a CSCSamplingGraph to a file with torch::save.
 *
 * @param graph Graph to serialize; must be non-null.
 * @param filename Destination path.
 * @throws c10::Error if `graph` is null.
 */
void SaveCSCSamplingGraph(
    c10::intrusive_ptr<sampling::CSCSamplingGraph> graph,
    const std::string& filename) {
  // Dereferencing an undefined intrusive_ptr is undefined behavior; report
  // a readable error instead of crashing.
  TORCH_CHECK(
      graph.defined(), "SaveCSCSamplingGraph: graph must not be null.");
  torch::save(*graph, filename);
}

/**
 * @brief Read the IValue stored under `key` in `archive`.
 *
 * @param archive Input archive to read from.
 * @param key Name of the entry to read.
 * @return The value filled in by InputArchive::read for `key`.
 */
torch::IValue read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key) {
  torch::IValue data;
  archive.read(key, data);
  return data;
}

}  // namespace graphbolt