csc_sampling_graph.h 5.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
 *  Copyright (c) 2023 by Contributors
 * @file graphbolt/include/csc_sampling_graph.h
 * @brief Header file of the CSC sampling graph.
 */

#pragma once

#include <torch/custom_class.h>
#include <torch/torch.h>

#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace graphbolt {
namespace sampling {

using StringList = std::vector<std::string>;

/**
 * @brief Structure representing heterogeneous information about a graph.
 *
 * Example usage:
 *
 * Suppose the graph has 6 edges
 * node_offset = [0, 2, 4]
 * edge_types = [0, 1, 0, 2, 1, 2]
 * HeteroInfo info({"A", "B", "C"}, {"X", "Y", "Z"}, node_offset, edge_types);
 *
 * This example creates a `HeteroInfo` object with three node types ("A", "B",
 * "C") and three edge types ("X", "Y", "Z"). The `node_offset` tensor
 * represents the offset array of node type, the given array indicates that node
 * [0, 2) has type "A", [2, 4) has type "B", and [4, 6) has type "C". And the
 * `edge_types` tensor represents the type id of each edge.
 */
struct HeteroInfo {
  /**
   * @brief Constructs a new `HeteroInfo` object.
   * @param ntypes List of node types in the graph.
   * @param etypes List of edge types in the graph.
   * @param node_type_offset Offset array of node type. It is assumed that nodes
   * of same type have consecutive ids.
   * @param type_per_edge Type id of each edge, where type id is the
   * corresponding index of `edge_types`.
   */
  HeteroInfo(
      const StringList& ntypes, const StringList& etypes,
      torch::Tensor& node_type_offset, torch::Tensor& type_per_edge)
      : node_types(ntypes),
        edge_types(etypes),
        node_type_offset(node_type_offset),
        type_per_edge(type_per_edge) {}

  /** @brief List of node types in the graph.*/
  StringList node_types;

  /** @brief List of edge types in the graph. */
  StringList edge_types;

  /**
   * @brief Offset array of node type. The length of it is equal to number of
   * node types.
   */
  torch::Tensor node_type_offset;

  /**
   * @brief Type id of each edge, where type id is the corresponding index of
   * edge_types. The length of it is equal to the number of edges.
   */
  torch::Tensor type_per_edge;
};

/**
 * @brief A sampling oriented csc format graph.
 *
 * The topology is stored as CSC index pointer and index tensors. Optional
 * heterogeneous information is held in `hetero_info_`; a null pointer means
 * the graph is homogeneous.
 */
class CSCSamplingGraph : public torch::CustomClassHolder {
 public:
  /**
   * @brief Constructor for CSC with data.
   * @param num_nodes The number of nodes in the graph.
   * @param indptr The CSC format index pointer array.
   * @param indices The CSC format index array.
   * @param hetero_info Heterogeneous graph information, if present. Nullptr
   * means it is a homogeneous graph.
   */
  CSCSamplingGraph(
      int64_t num_nodes, torch::Tensor& indptr, torch::Tensor& indices,
      const std::shared_ptr<HeteroInfo>& hetero_info);

  /**
   * @brief Create a homogeneous CSC graph from tensors of CSC format.
   * @param num_nodes The number of nodes in the graph.
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
   *
   * @return CSCSamplingGraph
   */
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSC(
      int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices);

  /**
   * @brief Create a heterogeneous CSC graph from tensors of CSC format.
   * @param num_nodes The number of nodes in the graph.
   * @param indptr Index pointer array of the CSC.
   * @param indices Indices array of the CSC.
   * @param ntypes A list of node types.
   * @param etypes A list of edge types.
   * @param node_type_offset A tensor representing the offset of node types.
   * @param type_per_edge A tensor representing the type of each edge.
   *
   * @return CSCSamplingGraph
   */
  static c10::intrusive_ptr<CSCSamplingGraph> FromCSCWithHeteroInfo(
      int64_t num_nodes, torch::Tensor indptr, torch::Tensor indices,
      const StringList& ntypes, const StringList& etypes,
      torch::Tensor node_type_offset, torch::Tensor type_per_edge);

  /** @brief Get the number of nodes. */
  int64_t NumNodes() const { return num_nodes_; }

  /** @brief Get the number of edges, i.e. the length of the index tensor. */
  int64_t NumEdges() const { return indices_.size(0); }

  /** @brief Get the csc index pointer tensor. */
  torch::Tensor CSCIndptr() const { return indptr_; }

  /** @brief Get the index tensor. */
  torch::Tensor Indices() const { return indices_; }

  /** @brief Check if the graph is heterogeneous. */
  bool IsHeterogeneous() const { return hetero_info_ != nullptr; }

  /**
   * @brief Get the node type offset tensor for a heterogeneous graph.
   * Fails a `TORCH_CHECK` if the graph is homogeneous.
   */
  torch::Tensor NodeTypeOffset() const {
    return CheckedHeteroInfo().node_type_offset;
  }

  /**
   * @brief Get the list of node types for a heterogeneous graph.
   * Fails a `TORCH_CHECK` if the graph is homogeneous.
   */
  const StringList& NodeTypes() const {
    return CheckedHeteroInfo().node_types;
  }

  /**
   * @brief Get the list of edge types for a heterogeneous graph.
   * Fails a `TORCH_CHECK` if the graph is homogeneous.
   */
  const StringList& EdgeTypes() const {
    return CheckedHeteroInfo().edge_types;
  }

  /**
   * @brief Get the edge type tensor for a heterogeneous graph.
   * Fails a `TORCH_CHECK` if the graph is homogeneous.
   */
  torch::Tensor TypePerEdge() const {
    return CheckedHeteroInfo().type_per_edge;
  }

 private:
  /**
   * @brief Access the heterogeneous information of the graph.
   *
   * Raises an error through `TORCH_CHECK` for a homogeneous graph instead of
   * dereferencing a null `hetero_info_`.
   */
  const HeteroInfo& CheckedHeteroInfo() const {
    TORCH_CHECK(
        IsHeterogeneous(),
        "Heterogeneous information is not available for a homogeneous graph.");
    return *hetero_info_;
  }

  /** @brief The number of nodes of the graph. */
  int64_t num_nodes_ = 0;
  /** @brief CSC format index pointer array. */
  torch::Tensor indptr_;
  /** @brief CSC format index array. */
  torch::Tensor indices_;
  /** @brief Heterogeneous graph information, if present. */
  std::shared_ptr<HeteroInfo> hetero_info_;
};

}  // namespace sampling
}  // namespace graphbolt