/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/csc_sampling_graph.h
 * @brief Header file of csc sampling graph.
 */
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_

#include <torch/custom_class.h>
#include <torch/torch.h>

#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

namespace graphbolt {
namespace sampling {

/** @brief List of node type names; each node type is a string. */
using NodeTypeList = std::vector<std::string>;

/**
 * @brief List of edge types; each edge type is a triplet of strings.
 *
 * NOTE(review): judging by examples such as ("n1", "e1", "n2"), the triplet is
 * presumably (source node type, edge name, destination node type) -- confirm
 * against the callers that build this list.
 */
using EdgeTypeList =
    std::vector<std::tuple<std::string, std::string, std::string>>;

/**
 * @brief Structure representing heterogeneous information about a graph.
 *
 * Example usage:
 *
27
28
29
30
31
32
 * Suppose the graph has 3 node types, 3 edge types and 6 edges
 * ntypes = {"n1", "n2", "n3"}
 * etypes = {("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2", "e3", "n3")}
 * node_type_offset = [0, 2, 4]
 * type_per_edge = [0, 1, 0, 2, 1, 2]
 * HeteroInfo info(ntypes, etypes, node_type_offset, type_per_edge);
33
 *
34
35
36
37
38
39
40
 * This example creates a `HeteroInfo` object with three node types ("n1", "n2",
 * "n3") and three edge types (("n1", "e1", "n2"), ("n1", "e2", "n3"), ("n2",
 * "e3", "n3")). The `node_type_offset` tensor represents the offset array of
 * node type, the given array indicates that node [0, 2) has type "n1", [2, 4)
 * has type "n2", and [4, 6) has type "n3". And the `type_per_edge` tensor
 * represents the type id of each edge, which is the index in the `etypes`
 * tensor.
41
42
43
44
 */
struct HeteroInfo {
  /**
   * @brief Constructs a new `HeteroInfo` object.
45
46
47
48
   * @param ntypes List of node types in the graph, where each node type is a
   * string.
   * @param etypes List of edge types in the graph, where each edge type is a
   * string triplet `(str, str, str)`.
49
50
51
52
53
54
   * @param node_type_offset Offset array of node type. It is assumed that nodes
   * of same type have consecutive ids.
   * @param type_per_edge Type id of each edge, where type id is the
   * corresponding index of `edge_types`.
   */
  HeteroInfo(
55
      const NodeTypeList& ntypes, const EdgeTypeList& etypes,
56
57
58
59
60
61
      torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
      : node_types(ntypes),
        edge_types(etypes),
        node_type_offset(node_type_offset),
        type_per_edge(type_per_edge) {}

62
63
64
  /** @brief Default constructor. */
  HeteroInfo() = default;

65
  /** @brief List of node types in the graph.*/
66
  NodeTypeList node_types;
67
68

  /** @brief List of edge types in the graph. */
69
  EdgeTypeList edge_types;
70
71
72
73
74
75
76
77
78
79
80
81

  /**
   * @brief Offset array of node type. The length of it is equal to number of
   * node types.
   */
  torch::Tensor node_type_offset;

  /**
   * @brief Type id of each edge, where type id is the corresponding index of
   * edge_types. The length of it is equal to the number of edges.
   */
  torch::Tensor type_per_edge;
82
83

  /**
84
85
86
   * @brief Magic number to indicate Hetero info version in serialize/
   * deserialize stages.
   */
87
88
89
  static constexpr int64_t kHeteroInfoSerializeMagic = 0xDD2E60F0F6B4A129;

  /**
90
91
92
   * @brief Load hetero info from stream.
   * @param archive Input stream for deserializing.
   */
93
94
95
  void Load(torch::serialize::InputArchive& archive);

  /**
96
97
98
   * @brief Save hetero info to stream.
   * @param archive Output stream for serializing.
   */
99
  void Save(torch::serialize::OutputArchive& archive) const;
100
101
102
103
104
105
106
};

/**
 * @brief A sampling oriented csc format graph.
 */
class CSCSamplingGraph : public torch::CustomClassHolder {
 public:
107
108
109
  /** @brief Default constructor. */
  CSCSamplingGraph() = default;

110
111
  /**
   * @brief Constructor for CSC with data.
peizhou001's avatar
peizhou001 committed
112
   * @param num_nodes The number of nodes in the graph.
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
   * @param indptr The CSC format index pointer array.
   * @param indices The CSC format index array.
   * @param hetero_info Heterogeneous graph information, if present. Nullptr
   * means it is a homogeneous graph.
   */
  CSCSamplingGraph(
      int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
      const std::shared_ptr<HeteroInfo>& hetero_info);

  /**
   * @brief Create a homogeneous CSC graph from tensors of CSC format.
   * @param num_nodes The number of nodes in the graph.
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
   *
   * @return CSCSamplingGraph
   */
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
      int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices);

  /**
   * @brief Create a heterogeneous CSC graph from tensors of CSC format.
peizhou001's avatar
peizhou001 committed
135
   * @param num_nodes The number of nodes in the graph.
136
137
138
139
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
   * @param ntypes A list of node types, if present.
   * @param etypes A list of edge types, if present.
Rhett Ying's avatar
Rhett Ying committed
140
   * @param node_type_offset A tensor representing the offset of node types, if
141
   * present.
Rhett Ying's avatar
Rhett Ying committed
142
   * @param type_per_edge A tensor representing the type of each edge, if
143
144
145
146
147
148
   * present.
   *
   * @return CSCSamplingGraph
   */
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo(
      int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
149
      const NodeTypeList& ntypes, const EdgeTypeList& etypes,
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
      torch::Tensor node_type_offset, torch::Tensor type_per_edge);

  /** @brief Get the number of nodes. */
  int64_t NumNodes() const { return num_nodes_; }

  /** @brief Get the number of edges. */
  int64_t NumEdges() const { return indices_.size(0); }

  /** @brief Get the csc index pointer tensor. */
  const torch::Tensor CSCIndptr() const { return indptr_; }

  /** @brief Get the index tensor. */
  const torch::Tensor Indices() const { return indices_; }

  /** @brief Check if the graph is heterogeneous. */
  inline bool IsHeterogeneous() const { return hetero_info_ != nullptr; }

  /** @brief Get the node type offset tensor for a heterogeneous graph. */
  inline const torch::Tensor NodeTypeOffset() const {
    return hetero_info_->node_type_offset;
  }

  /** @brief Get the list of node types for a heterogeneous graph. */
173
  inline NodeTypeList& NodeTypes() const { return hetero_info_->node_types; }
174
175

  /** @brief Get the list of edge types for a heterogeneous graph. */
176
  inline const EdgeTypeList& EdgeTypes() const {
177
178
179
180
181
182
183
184
    return hetero_info_->edge_types;
  }

  /** @brief Get the edge type tensor for a heterogeneous graph. */
  inline const torch::Tensor TypePerEdge() const {
    return hetero_info_->type_per_edge;
  }

185
  /**
186
187
188
   * @brief Magic number to indicate graph version in serialize/deserialize
   * stage.
   */
189
190
191
  static constexpr int64_t kCSCSamplingGraphSerializeMagic = 0xDD2E60F0F6B4A128;

  /**
192
193
194
   * @brief Load graph from stream.
   * @param archive Input stream for deserializing.
   */
195
196
197
  void Load(torch::serialize::InputArchive& archive);

  /**
198
199
200
   * @brief Save graph to stream.
   * @param archive Output stream for serializing.
   */
201
202
  void Save(torch::serialize::OutputArchive& archive) const;

203
204
205
206
207
208
209
210
211
212
213
214
215
 private:
  /** @brief The number of nodes of the graph. */
  int64_t num_nodes_ = 0;
  /** @brief CSC format index pointer array. */
  torch::Tensor indptr_;
  /** @brief CSC format index array. */
  torch::Tensor indices_;
  /** @brief Heterogeneous graph information, if present. */
  std::shared_ptr<HeteroInfo> hetero_info_;
};

}  // namespace sampling
}  // namespace graphbolt

#endif  // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_