"vscode:/vscode.git/clone" did not exist on "d2ef24335c3483b242b328b0db26d3084ac40f2f"
csc_sampling_graph.h 7.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/include/csc_sampling_graph.h
 * @brief Header file of csc sampling graph.
 */

#include <torch/custom_class.h>
#include <torch/torch.h>

#include <string>
#include <vector>

namespace graphbolt {
namespace sampling {

/** @brief List of strings; used below for node type and edge type names. */
using StringList = std::vector<std::string>;

/**
 * @brief Structure representing heterogeneous information about a graph.
 *
 * Example usage:
 *
 * Suppose the graph has 6 edges
 * node_offset = [0, 2, 4]
 * edge_types = [0, 1, 0, 2, 1, 2]
 * HeteroInfo info({"A", "B", "C"}, {"X", "Y", "Z"}, node_offset, edge_types);
 *
 * This example creates a `HeteroInfo` object with three node types ("A", "B",
 * "C") and three edge types ("X", "Y", "Z"). The `node_offset` tensor
 * represents the offset array of node type, the given array indicates that node
 * [0, 2) has type "A", [2, 4) has type "B", and [4, 6) has type "C". And the
 * `edge_types` tensor represents the type id of each edge.
 */
struct HeteroInfo {
  /**
   * @brief Constructs a new `HeteroInfo` object.
   * @param ntypes List of node types in the graph.
   * @param etypes List of edge types in the graph.
   * @param node_type_offset Offset array of node type. It is assumed that nodes
   * of same type have consecutive ids.
   * @param type_per_edge Type id of each edge, where type id is the
   * corresponding index of `edge_types`.
   */
  HeteroInfo(
      const StringList& ntypes, const StringList& etypes,
      torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
      : node_types(ntypes),
        edge_types(etypes),
        node_type_offset(node_type_offset),
        type_per_edge(type_per_edge) {}

52
53
54
  /** @brief Default constructor. */
  HeteroInfo() = default;

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
  /** @brief List of node types in the graph.*/
  StringList node_types;

  /** @brief List of edge types in the graph. */
  StringList edge_types;

  /**
   * @brief Offset array of node type. The length of it is equal to number of
   * node types.
   */
  torch::Tensor node_type_offset;

  /**
   * @brief Type id of each edge, where type id is the corresponding index of
   * edge_types. The length of it is equal to the number of edges.
   */
  torch::Tensor type_per_edge;
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

  /**
  ** @brief Magic number to indicate Hetero info version in serialize/
  ** deserialize stages.
  **/
  static constexpr int64_t kHeteroInfoSerializeMagic = 0xDD2E60F0F6B4A129;

  /**
  ** @brief Load hetero info from stream.
  ** @param archive Input stream for deserializing.
  **/
  void Load(torch::serialize::InputArchive& archive);

  /**
  ** @brief Save hetero info to stream.
  ** @param archive Output stream for serializing.
  **/
  void Save(torch::serialize::OutputArchive& archive) const;
90
91
92
93
94
95
96
97
98
};

/**
 * @brief A sampling oriented csc format graph.
 */
class CSCSamplingGraph : public torch::CustomClassHolder {
 public:
  /**
   * @brief Constructor for CSC with data.
peizhou001's avatar
peizhou001 committed
99
   * @param num_nodes The number of nodes in the graph.
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
   * @param indptr The CSC format index pointer array.
   * @param indices The CSC format index array.
   * @param hetero_info Heterogeneous graph information, if present. Nullptr
   * means it is a homogeneous graph.
   */
  CSCSamplingGraph(
      int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
      const std::shared_ptr<HeteroInfo>& hetero_info);

  /**
   * @brief Create a homogeneous CSC graph from tensors of CSC format.
   * @param num_nodes The number of nodes in the graph.
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
   *
   * @return CSCSamplingGraph
   */
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
      int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices);

  /**
   * @brief Create a heterogeneous CSC graph from tensors of CSC format.
peizhou001's avatar
peizhou001 committed
122
   * @param num_nodes The number of nodes in the graph.
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
   * @param ntypes A list of node types, if present.
   * @param etypes A list of edge types, if present.
   * @param node_type_offset_ A tensor representing the offset of node types, if
   * present.
   * @param type_per_edge_ A tensor representing the type of each edge, if
   * present.
   *
   * @return CSCSamplingGraph
   */
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo(
      int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
      const StringList& ntypes, const StringList& etypes,
      torch::Tensor node_type_offset, torch::Tensor type_per_edge);

  /** @brief Get the number of nodes. */
  int64_t NumNodes() const { return num_nodes_; }

  /** @brief Get the number of edges. */
  int64_t NumEdges() const { return indices_.size(0); }

  /** @brief Get the csc index pointer tensor. */
  const torch::Tensor CSCIndptr() const { return indptr_; }

  /** @brief Get the index tensor. */
  const torch::Tensor Indices() const { return indices_; }

  /** @brief Check if the graph is heterogeneous. */
  inline bool IsHeterogeneous() const { return hetero_info_ != nullptr; }

  /** @brief Get the node type offset tensor for a heterogeneous graph. */
  inline const torch::Tensor NodeTypeOffset() const {
    return hetero_info_->node_type_offset;
  }

  /** @brief Get the list of node types for a heterogeneous graph. */
  inline StringList& NodeTypes() const { return hetero_info_->node_types; }

  /** @brief Get the list of edge types for a heterogeneous graph. */
  inline const StringList& EdgeTypes() const {
    return hetero_info_->edge_types;
  }

  /** @brief Get the edge type tensor for a heterogeneous graph. */
  inline const torch::Tensor TypePerEdge() const {
    return hetero_info_->type_per_edge;
  }

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
  /**
  ** @brief Magic number to indicate graph version in serialize/deserialize
  ** stage.
  **/
  static constexpr int64_t kCSCSamplingGraphSerializeMagic = 0xDD2E60F0F6B4A128;

  /**
  ** @brief Load graph from stream.
  ** @param archive Input stream for deserializing.
  **/
  void Load(torch::serialize::InputArchive& archive);

  /**
  ** @brief Save graph to stream.
  ** @param archive Output stream for serializing.
  **/
  void Save(torch::serialize::OutputArchive& archive) const;

190
191
192
193
194
195
196
197
198
199
200
201
202
 private:
  /** @brief The number of nodes of the graph. */
  int64_t num_nodes_ = 0;
  /** @brief CSC format index pointer array. */
  torch::Tensor indptr_;
  /** @brief CSC format index array. */
  torch::Tensor indices_;
  /** @brief Heterogeneous graph information, if present. */
  std::shared_ptr<HeteroInfo> hetero_info_;
};

}  // namespace sampling
}  // namespace graphbolt
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232

/**
 * @brief Overload stream operators to enable `torch::save()` and
 * `torch::load()` for CSCSamplingGraph.
 */
namespace torch {

/**
 * @brief Overload input stream operator for CSCSamplingGraph deserialization.
 *
 * Declaration only; the definition is not part of this header.
 * NOTE(review): the function is declared `inline`, and an inline function
 * must be defined in every translation unit that odr-uses it. Confirm the
 * definition is visible wherever this operator is called, or drop `inline`.
 *
 * @param archive Input stream for deserializing.
 * @param graph CSCSamplingGraph.
 *
 * @return archive
 */
inline serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    graphbolt::sampling::CSCSamplingGraph& graph);

/**
 * @brief Overload output stream operator for CSCSamplingGraph serialization.
 *
 * Declaration only; the definition is not part of this header.
 * NOTE(review): the function is declared `inline`, and an inline function
 * must be defined in every translation unit that odr-uses it. Confirm the
 * definition is visible wherever this operator is called, or drop `inline`.
 *
 * @param archive Output stream for serializing.
 * @param graph CSCSamplingGraph.
 *
 * @return archive
 */
inline serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const graphbolt::sampling::CSCSamplingGraph& graph);

}  // namespace torch