serialize.h 3.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/include/serialize.h
 * @brief Utility functions for serialize and deserialize.
 */

#ifndef GRAPHBOLT_INCLUDE_SERIALIZE_H_
#define GRAPHBOLT_INCLUDE_SERIALIZE_H_

#include <torch/torch.h>

#include <string>
#include <vector>

namespace graphbolt {
namespace utils {

/**
 * @brief Utility function to write to archive.
 * @param archive Output archive.
 * @param key Key name used in saving.
 * @param data Data that could be constructed as `torch::IValue`.
23
 */
24
25
26
27
28
29
30
31
32
33
34
35
template <typename DataT>
void write_to_archive(
    torch::serialize::OutputArchive& archive, const std::string& key,
    const DataT& data) {
  archive.write(key, data);
}

/**
 * @brief Specialization utility function to save string vector.
 * @param archive Output archive.
 * @param key Key name used in saving.
 * @param data Vector of string.
36
 */
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
template <>
void write_to_archive<std::vector<std::string>>(
    torch::serialize::OutputArchive& archive, const std::string& key,
    const std::vector<std::string>& data) {
  archive.write(
      key + "/size", torch::tensor(static_cast<int64_t>(data.size())));
  for (const auto index : c10::irange(data.size())) {
    archive.write(key + "/" + std::to_string(index), data[index]);
  }
}

/**
 * @brief Utility function to read from archive.
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Data that could be constructed as `torch::IValue`.
53
 */
54
55
56
57
58
59
60
61
62
63
64
65
template <typename DataT = torch::IValue>
void read_from_archive(
    torch::serialize::InputArchive& archive, const std::string& key,
    DataT& data) {
  archive.read(key, data);
}

/**
 * @brief Specialization utility function to read from archive.
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Data that is `bool`.
66
 */
67
68
69
70
71
72
73
74
75
76
77
78
79
80
template <>
void read_from_archive<bool>(
    torch::serialize::InputArchive& archive, const std::string& key,
    bool& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toBool();
}

/**
 * @brief Specialization utility function to read from archive.
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Data that is `int64_t`.
81
 */
82
83
84
85
86
87
88
89
90
91
92
93
94
95
template <>
void read_from_archive<int64_t>(
    torch::serialize::InputArchive& archive, const std::string& key,
    int64_t& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toInt();
}

/**
 * @brief Specialization utility function to read from archive.
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Data that is `std::string`.
96
 */
97
98
99
100
101
102
103
104
105
106
107
108
109
110
template <>
void read_from_archive<std::string>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::string& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toString();
}

/**
 * @brief Specialization utility function to read from archive.
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Data that is `torch::Tensor`.
111
 */
112
113
114
115
116
117
118
119
120
121
122
123
124
125
template <>
void read_from_archive<torch::Tensor>(
    torch::serialize::InputArchive& archive, const std::string& key,
    torch::Tensor& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toTensor();
}

/**
 * @brief Specialization utility function to read to string vector.
 * @param archive Output archive.
 * @param key Key name used in saving.
 * @param data Vector of string.
126
 */
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
template <>
void read_from_archive<std::vector<std::string>>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::vector<std::string>& data) {
  int64_t size = 0;
  read_from_archive<int64_t>(archive, key + "/size", size);
  data.resize(static_cast<size_t>(size));
  std::string element;
  for (int64_t index = 0; index < size; ++index) {
    read_from_archive<std::string>(
        archive, key + "/" + std::to_string(index), element);
    data[index] = element;
  }
}

}  // namespace utils
}  // namespace graphbolt

#endif  // GRAPHBOLT_INCLUDE_SERIALIZE_H_