/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/include/serialize.h
 * @brief Utility functions for serialization and deserialization.
 */

#ifndef GRAPHBOLT_INCLUDE_SERIALIZE_H_
#define GRAPHBOLT_INCLUDE_SERIALIZE_H_

#include <torch/torch.h>

#include <string>
#include <vector>

namespace graphbolt {
namespace utils {

/**
 * @brief Store a single value in an output archive under the given key.
 *
 * Generic fallback: the value is forwarded unchanged to
 * `OutputArchive::write`, so `DataT` must be a type that call accepts.
 *
 * @tparam DataT Type of the value being stored.
 * @param archive Archive that receives the value.
 * @param key Name under which the value is stored.
 * @param data Value to store.
 **/
template <typename DataT>
void write_to_archive(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const DataT& data) {
  // No conversion happens here; overload resolution is left to the archive.
  archive.write(key, data);
}

/**
 * @brief Specialization utility function to save a vector of strings.
 *
 * The element count is written under `key + "/size"` as a scalar int64
 * tensor, and the element at position `i` is written under
 * `key + "/" + std::to_string(i)`.
 *
 * @param archive Output archive.
 * @param key Key prefix used in saving.
 * @param data Vector of strings to save.
 **/
// `inline` is required: a full specialization of a function template is an
// ordinary function, so a non-inline definition in a header is emitted by
// every translation unit that includes it and breaks the one-definition rule.
template <>
inline void write_to_archive<std::vector<std::string>>(
    torch::serialize::OutputArchive& archive, const std::string& key,
    const std::vector<std::string>& data) {
  archive.write(
      key + "/size", torch::tensor(static_cast<int64_t>(data.size())));
  for (const auto index : c10::irange(data.size())) {
    archive.write(key + "/" + std::to_string(index), data[index]);
  }
}

/**
 * @brief Load the value stored under the given key of an input archive.
 *
 * Generic fallback: the request is forwarded unchanged to
 * `InputArchive::read`, so `DataT` must be a type that call can fill in.
 *
 * @tparam DataT Type of the destination object (defaults to `torch::IValue`).
 * @param archive Archive to read from.
 * @param key Name of the entry to load.
 * @param data Output parameter that receives the loaded value.
 **/
template <typename DataT = torch::IValue>
void read_from_archive(
    torch::serialize::InputArchive& archive,
    const std::string& key,
    DataT& data) {
  // No conversion happens here; the archive writes straight into `data`.
  archive.read(key, data);
}

/**
 * @brief Specialization utility function to read a `bool` from archive.
 *
 * The entry is read as a `torch::IValue` and then converted with `toBool()`.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Output parameter that receives the `bool` value.
 **/
// `inline` is required: a full specialization of a function template is an
// ordinary function, so a non-inline definition in a header is emitted by
// every translation unit that includes it and breaks the one-definition rule.
template <>
inline void read_from_archive<bool>(
    torch::serialize::InputArchive& archive, const std::string& key,
    bool& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toBool();
}

/**
 * @brief Specialization utility function to read an `int64_t` from archive.
 *
 * The entry is read as a `torch::IValue` and then converted with `toInt()`.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Output parameter that receives the `int64_t` value.
 **/
// `inline` is required: a full specialization of a function template is an
// ordinary function, so a non-inline definition in a header is emitted by
// every translation unit that includes it and breaks the one-definition rule.
template <>
inline void read_from_archive<int64_t>(
    torch::serialize::InputArchive& archive, const std::string& key,
    int64_t& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toInt();
}

/**
 * @brief Specialization utility function to read a `std::string` from archive.
 *
 * The entry is read as a `torch::IValue` and its string payload is copied
 * into `data`.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Output parameter that receives the `std::string` value.
 **/
// `inline` is required: a full specialization of a function template is an
// ordinary function, so a non-inline definition in a header is emitted by
// every translation unit that includes it and breaks the one-definition rule.
template <>
inline void read_from_archive<std::string>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::string& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  // toStringRef() yields the stored std::string itself; toString() returns a
  // reference-counted handle to the string holder, not the text.
  data = iv_data.toStringRef();
}

/**
 * @brief Specialization utility function to read a `torch::Tensor` from
 * archive.
 *
 * The entry is read as a `torch::IValue` and then converted with
 * `toTensor()`.
 *
 * @param archive Input archive.
 * @param key Key name used in reading.
 * @param data Output parameter that receives the `torch::Tensor` value.
 **/
// `inline` is required: a full specialization of a function template is an
// ordinary function, so a non-inline definition in a header is emitted by
// every translation unit that includes it and breaks the one-definition rule.
template <>
inline void read_from_archive<torch::Tensor>(
    torch::serialize::InputArchive& archive, const std::string& key,
    torch::Tensor& data) {
  torch::IValue iv_data;
  archive.read(key, iv_data);
  data = iv_data.toTensor();
}

/**
 * @brief Specialization utility function to read a vector of strings from
 * archive.
 *
 * The element count is read from `key + "/size"` and the element at position
 * `i` from `key + "/" + std::to_string(i)`. Any previous content of `data` is
 * replaced.
 *
 * @param archive Input archive.
 * @param key Key prefix used in reading.
 * @param data Output parameter that receives the vector of strings.
 **/
// `inline` is required: a full specialization of a function template is an
// ordinary function, so a non-inline definition in a header is emitted by
// every translation unit that includes it and breaks the one-definition rule.
template <>
inline void read_from_archive<std::vector<std::string>>(
    torch::serialize::InputArchive& archive, const std::string& key,
    std::vector<std::string>& data) {
  torch::IValue iv_size;
  archive.read(key + "/size", iv_size);
  // The size entry may have been stored as a scalar int64 tensor rather than
  // as a plain integer, so accept both encodings instead of assuming an int.
  const int64_t size = iv_size.isTensor()
                           ? iv_size.toTensor().item<int64_t>()
                           : iv_size.toInt();
  // A negative count would turn into a huge unsigned value in resize().
  TORCH_CHECK(
      size >= 0, "Invalid string vector size ", size, " stored under key '",
      key, "/size'.");
  const auto count = static_cast<size_t>(size);
  data.resize(count);
  for (const auto index : c10::irange(count)) {
    torch::IValue iv_element;
    archive.read(key + "/" + std::to_string(index), iv_element);
    // toStringRef() yields the stored std::string; copy it straight into the
    // destination slot without going through a temporary.
    data[index] = iv_element.toStringRef();
  }
}

}  // namespace utils
}  // namespace graphbolt

#endif  // GRAPHBOLT_INCLUDE_SERIALIZE_H_