/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/bipartite.h
 * \brief Bipartite graph: a heterograph with one source node type, one
 *        destination node type and a single edge type between them.
 */

#ifndef DGL_GRAPH_BIPARTITE_H_
#define DGL_GRAPH_BIPARTITE_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <dgl/array.h>

#include "../c_api_common.h"

namespace dgl {

/*!
 * \brief Bipartite graph
 *
 * Bipartite graph is a special type of heterograph which has two types
 * of nodes: "Src" and "Dst". All the edges are from "Src" type nodes to
 * "Dst" type nodes, so there is no edge among nodes of the same type.
 */
class Bipartite : public BaseHeteroGraph {
 public:
  /*! \brief source node group type */
  static constexpr dgl_type_t kSrcVType = 0;
  /*! \brief destination node group type */
  static constexpr dgl_type_t kDstVType = 1;
  /*! \brief edge group type */
  static constexpr dgl_type_t kEType = 0;

37
38
39
40
41
42
  // internal data structure
  class COO;
  class CSR;
  typedef std::shared_ptr<COO> COOPtr;
  typedef std::shared_ptr<CSR> CSRPtr;

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
  uint64_t NumVertexTypes() const override {
    return 2;
  }

  uint64_t NumEdgeTypes() const override {
    return 1;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for Bipartite graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "Bipartite graph is not mutable.";
  }

  DLContext Context() const override;

  uint8_t NumBits() const override;

  bool IsMultigraph() const override;

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override;

  uint64_t NumEdges(dgl_type_t etype) const override;

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override;

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override;

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override;

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override;

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override;

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override;

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override;

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override;

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override;

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override;

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override;

  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;

  // creators
  /*! \brief Create a bipartite graph from COO arrays */
  static HeteroGraphPtr CreateFromCOO(int64_t num_src, int64_t num_dst,
      IdArray row, IdArray col);

  /*! \brief Create a bipartite graph from (out) CSR arrays */
  static HeteroGraphPtr CreateFromCSR(
      int64_t num_src, int64_t num_dst,
      IdArray indptr, IdArray indices, IdArray edge_ids);

151
152
  /*! \brief Convert the graph to use the given number of bits for storage */
  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
153

154
155
  /*! \brief Copy the data to another context */
  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
156
157
158
159
160
161
162
163
164
165

  /*! \return Return the in-edge CSR format. Create from other format if not exist. */
  CSRPtr GetInCSR() const;

  /*! \return Return the out-edge CSR format. Create from other format if not exist. */
  CSRPtr GetOutCSR() const;

  /*! \return Return the COO format. Create from other format if not exist. */
  COOPtr GetCOO() const;

166
167
168
169
170
171
172
173
174
175
176
177
  /*! \return Return the in-edge CSR in the matrix form */
  aten::CSRMatrix GetInCSRMatrix() const;

  /*! \return Return the out-edge CSR in the matrix form */
  aten::CSRMatrix GetOutCSRMatrix() const;

  /*! \return Return the COO matrix form */
  aten::COOMatrix GetCOOMatrix() const;

 private:
  Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
  /*! \return Return any existing format. */
  HeteroGraphPtr GetAny() const;

  // Graph stored in different format. We use an on-demand strategy: the format is
  // only materialized if the operation that suitable for it is invoked.
  /*! \brief CSR graph that stores reverse edges */
  CSRPtr in_csr_;
  /*! \brief CSR representation */
  CSRPtr out_csr_;
  /*! \brief COO representation */
  COOPtr coo_;
};

};  // namespace dgl

#endif  // DGL_GRAPH_BIPARTITE_H_