/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/csc_sampling_graph.h
 * @brief Header file of the CSC sampling graph.
 */

#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_

#include <graphbolt/sampled_subgraph.h>
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>

#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

namespace graphbolt {
namespace sampling {

/**
20
 * @brief A sampling oriented csc format graph.
21
22
23
 *
 * Example usage:
 *
24
 * Suppose the graph has 3 node types, 3 edge types and 6 edges
25
26
27
 * auto node_type_offset = {0, 2, 4, 6}
 * auto type_per_edge = {0, 1, 0, 2, 1, 2}
 * auto graph = CSCSamplingGraph(..., ..., node_type_offset, type_per_edge)
28
 *
29
30
31
32
 * The `node_type_offset` tensor represents the offset array of node type, the
 * given array indicates that node [0, 2) has type id 0, [2, 4) has type id 1,
 * and [4, 6) has type id 2. And the `type_per_edge` tensor represents the type
 * id of each edge.
33
34
35
 */
class CSCSamplingGraph : public torch::CustomClassHolder {
 public:
36
37
38
  /** @brief Default constructor. */
  CSCSamplingGraph() = default;

39
40
41
42
  /**
   * @brief Constructor for CSC with data.
   * @param indptr The CSC format index pointer array.
   * @param indices The CSC format index array.
43
44
45
46
   * @param node_type_offset A tensor representing the offset of node types, if
   * present.
   * @param type_per_edge A tensor representing the type of each edge, if
   * present.
47
48
   */
  CSCSamplingGraph(
49
      const torch::Tensor& indptr, const torch::Tensor& indices,
50
51
      const torch::optional<torch::Tensor>& node_type_offset,
      const torch::optional<torch::Tensor>& type_per_edge);
52
53
54
55
56

  /**
   * @brief Create a homogeneous CSC graph from tensors of CSC format.
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
Rhett Ying's avatar
Rhett Ying committed
57
   * @param node_type_offset A tensor representing the offset of node types, if
58
   * present.
Rhett Ying's avatar
Rhett Ying committed
59
   * @param type_per_edge A tensor representing the type of each edge, if
60
61
62
63
   * present.
   *
   * @return CSCSamplingGraph
   */
64
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
65
      const torch::Tensor& indptr, const torch::Tensor& indices,
66
67
      const torch::optional<torch::Tensor>& node_type_offset,
      const torch::optional<torch::Tensor>& type_per_edge);
68
69

  /** @brief Get the number of nodes. */
70
  int64_t NumNodes() const { return indptr_.size(0) - 1; }
71
72
73
74
75
76
77
78
79
80
81

  /** @brief Get the number of edges. */
  int64_t NumEdges() const { return indices_.size(0); }

  /** @brief Get the csc index pointer tensor. */
  const torch::Tensor CSCIndptr() const { return indptr_; }

  /** @brief Get the index tensor. */
  const torch::Tensor Indices() const { return indices_; }

  /** @brief Get the node type offset tensor for a heterogeneous graph. */
82
83
  inline const torch::optional<torch::Tensor> NodeTypeOffset() const {
    return node_type_offset_;
84
85
86
  }

  /** @brief Get the edge type tensor for a heterogeneous graph. */
87
88
  inline const torch::optional<torch::Tensor> TypePerEdge() const {
    return type_per_edge_;
89
90
  }

91
  /**
92
93
94
   * @brief Magic number to indicate graph version in serialize/deserialize
   * stage.
   */
95
96
97
  static constexpr int64_t kCSCSamplingGraphSerializeMagic = 0xDD2E60F0F6B4A128;

  /**
98
99
100
   * @brief Load graph from stream.
   * @param archive Input stream for deserializing.
   */
101
102
103
  void Load(torch::serialize::InputArchive& archive);

  /**
104
105
106
   * @brief Save graph to stream.
   * @param archive Output stream for serializing.
   */
107
108
  void Save(torch::serialize::OutputArchive& archive) const;

109
110
111
112
113
114
115
116
117
  /**
   * @brief Return the subgraph induced on the inbound edges of the given nodes.
   * @param nodes Type agnostic node IDs to form the subgraph.
   *
   * @return SampledSubgraph.
   */
  c10::intrusive_ptr<SampledSubgraph> InSubgraph(
      const torch::Tensor& nodes) const;

118
119
120
121
122
123
124
125
126
127
128
129
  /**
   * @brief Sample neighboring edges of the given nodes and return the induced
   * subgraph.
   *
   * @param nodes The nodes from which to sample neighbors.
   *
   * @return An intrusive pointer to a SampledSubgraph object containing the
   * sampled graph's information.
   */
  c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
      const torch::Tensor& nodes) const;

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
  /**
   * @brief Copy the graph to shared memory.
   * @param shared_memory_name The name of the shared memory.
   *
   * @return A new CSCSamplingGraph object on shared memory.
   */
  c10::intrusive_ptr<CSCSamplingGraph> CopyToSharedMemory(
      const std::string& shared_memory_name);

  /**
   * @brief Load the graph from shared memory.
   * @param shared_memory_name The name of the shared memory.
   *
   * @return A new CSCSamplingGraph object on shared memory.
   */
  static c10::intrusive_ptr<CSCSamplingGraph> LoadFromSharedMemory(
      const std::string& shared_memory_name);

148
 private:
149
150
151
152
153
154
155
156
157
158
159
160
161
162
  /**
   * @brief Build a CSCSamplingGraph from shared memory tensors.
   *
   * @param shared_memory_tensors A tuple of two share memory objects holding
   * tensor meta information and data respectively, and a vector of optional
   * tensors on shared memory.
   *
   * @return A new CSCSamplingGraph on shared memory.
   */
  static c10::intrusive_ptr<CSCSamplingGraph> BuildGraphFromSharedMemoryTensors(
      std::tuple<
          SharedMemoryPtr, SharedMemoryPtr,
          std::vector<torch::optional<torch::Tensor>>>&& shared_memory_tensors);

163
164
  /** @brief CSC format index pointer array. */
  torch::Tensor indptr_;
165

166
167
  /** @brief CSC format index array. */
  torch::Tensor indices_;
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

  /**
   * @brief Offset array of node type. The length of it is equal to the number
   * of node types + 1. The tensor is in ascending order as nodes of the same
   * type have continuous IDs, and larger node IDs are paired with larger node
   * type IDs. Its first value is 0 and last value is the number of nodes. And
   * nodes with ID between `node_type_offset_[i] ~ node_type_offset_[i+1]` are
   * of type id `i`.
   */
  torch::optional<torch::Tensor> node_type_offset_;

  /**
   * @brief Type id of each edge, where type id is the corresponding index of
   * edge types. The length of it is equal to the number of edges.
   */
  torch::optional<torch::Tensor> type_per_edge_;
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199

  /**
   * @brief Maximum number of bytes used to serialize the metadata of the
   * member tensors, including tensor shape and dtype. The constant is estimated
   * by multiplying the number of tensors in this class and the maximum number
   * of bytes used to serialize the metadata of a tensor (4 * 8192 for now).
   */
  static constexpr int64_t SERIALIZED_METAINFO_SIZE_MAX = 32768;

  /**
   * @brief Shared memory used to hold the tensor meta information and data of
   * this class. By storing its shared memory objects, the graph controls the
   * resources of shared memory, which will be released automatically when the
   * graph is destroyed.
   */
  SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_;
200
201
202
203
};

}  // namespace sampling
}  // namespace graphbolt
#endif  // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_