"docs/source/vscode:/vscode.git/clone" did not exist on "9c790b114326829044bb6fc8ac73869151210910"
unit_graph.h 12.4 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2019 by Contributors
Minjie Wang's avatar
Minjie Wang committed
3
4
 * \file graph/unit_graph.h
 * \brief UnitGraph graph
5
6
 */

Minjie Wang's avatar
Minjie Wang committed
7
8
#ifndef DGL_GRAPH_UNIT_GRAPH_H_
#define DGL_GRAPH_UNIT_GRAPH_H_
9
10

#include <dgl/base_heterograph.h>
11
12
#include <dgl/lazy.h>
#include <dgl/array.h>
13
14
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
15
#include <utility>
16
17
#include <string>
#include <vector>
18
#include <memory>
19
#include <tuple>
20
21

#include "../c_api_common.h"
22
23
24

namespace dgl {

25
class HeteroGraph;
26
27
28
class UnitGraph;
typedef std::shared_ptr<UnitGraph> UnitGraphPtr;

29
/*!
Minjie Wang's avatar
Minjie Wang committed
30
 * \brief UnitGraph graph
31
 *
Minjie Wang's avatar
Minjie Wang committed
32
33
34
35
36
37
38
 * A UnitGraph is a special type of heterograph that either
 * (1) has two types of nodes: "Src" and "Dst". All the edges are
 *     from "Src" type nodes to "Dst" type nodes, so there is no edge among
 *     nodes of the same type. Thus, its metagraph has two nodes and one edge
 *     between them; or
 * (2) has only one type of nodes and edges. Thus, its metagraph has one node
 *     and one self-loop edge.
39
 */
Minjie Wang's avatar
Minjie Wang committed
40
class UnitGraph : public BaseHeteroGraph {
41
 public:
42
43
44
45
46
47
  // internal data structure
  class COO;
  class CSR;
  typedef std::shared_ptr<COO> COOPtr;
  typedef std::shared_ptr<CSR> CSRPtr;

Minjie Wang's avatar
Minjie Wang committed
48
49
50
51
52
53
  /*! \return the type id used for source ("Src") vertices, which is always 0 */
  inline dgl_type_t SrcType() const {
    const dgl_type_t src_vtype = 0;
    return src_vtype;
  }

  /*!
   * \return the type id of destination ("Dst") vertices: 0 when the graph has a
   *         single vertex type (shared by sources and destinations), 1 otherwise.
   */
  inline dgl_type_t DstType() const {
    if (NumVertexTypes() == 1) {
      return 0;
    }
    return 1;
  }

Minjie Wang's avatar
Minjie Wang committed
56
57
  /*! \return the type id of the graph's single edge type, which is always 0 */
  inline dgl_type_t EdgeType() const {
    const dgl_type_t only_etype = 0;
    return only_etype;
  }

  /*!
   * \brief Unsupported on UnitGraph; logs a fatal error when called.
   * \param etype edge type (unused)
   * \return an empty HeteroGraphPtr if execution continues past the fatal log
   */
  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    // NOTE(review): presumably unreachable because LOG(FATAL) aborts or throws;
    // the return only satisfies the declared return type -- confirm.
    return {};
  }

  /*!
   * \brief Unsupported: a UnitGraph is not mutable, so this logs a fatal error.
   * \param vtype vertex type (unused)
   * \param num_vertices requested number of vertices (unused)
   */
  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*!
   * \brief Unsupported: a UnitGraph is not mutable, so this logs a fatal error.
   * \param etype edge type (unused)
   * \param src source vertex id (unused)
   * \param dst destination vertex id (unused)
   */
  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*!
   * \brief Unsupported: a UnitGraph is not mutable, so this logs a fatal error.
   * \param etype edge type (unused)
   * \param src_ids source vertex ids (unused)
   * \param dst_ids destination vertex ids (unused)
   */
  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
79
    LOG(FATAL) << "UnitGraph graph is not mutable.";
80
81
  }

82
83
  DLDataType DataType() const override;

84
85
  DLContext Context() const override;

86
87
  bool IsPinned() const override;

88
89
90
91
92
93
94
95
96
97
  uint8_t NumBits() const override;

  bool IsMultigraph() const override;

  /*! \return always true: the mutating methods of UnitGraph reject modification */
  bool IsReadonly() const override { return true; }

  uint64_t NumVertices(dgl_type_t vtype) const override;

98
  /*!
   * \brief Collect the vertex count of every vertex type.
   * \return a vector whose i-th entry is NumVertices(i), for each vertex type id
   *         from 0 up to (but excluding) NumVertexTypes().
   */
  inline std::vector<int64_t> NumVerticesPerType() const override {
    // Query the type count once so the vector can be sized up front.
    const auto num_vtypes = NumVertexTypes();
    std::vector<int64_t> counts;
    counts.reserve(num_vtypes);
    for (dgl_type_t vtype = 0; vtype < num_vtypes; ++vtype) {
      // NumVertices returns uint64_t; make the conversion to int64_t explicit.
      counts.push_back(static_cast<int64_t>(NumVertices(vtype)));
    }
    return counts;
  }

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  uint64_t NumEdges(dgl_type_t etype) const override;

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override;

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

121
122
123
  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override;

  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override;
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override;

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override;

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override;

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override;

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override;

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override;

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override;

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override;

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;

149
150
151
  // 32bit version functions, patch for SuccVec
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) const;

152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override;

  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;

  // creators
167
168
169
170
171
172
173
174
175
  /*!
   * \brief Create a graph that has no edges.
   *
   * Builds two zero-length id arrays with the given dtype and context and
   * passes them to CreateFromCOO as the row and column arrays.
   * \param num_vtypes number of vertex types
   * \param num_src number of source vertices
   * \param num_dst number of destination vertices
   * \param dtype data type of the id arrays
   * \param ctx context on which the id arrays are created
   * \return the graph returned by CreateFromCOO
   */
  static HeteroGraphPtr Empty(
      int64_t num_vtypes, int64_t num_src, int64_t num_dst,
      DLDataType dtype, DLContext ctx) {
    const IdArray no_src_ids = IdArray::Empty({0}, dtype, ctx);
    const IdArray no_dst_ids = IdArray::Empty({0}, dtype, ctx);
    return CreateFromCOO(num_vtypes, num_src, num_dst, no_src_ids, no_dst_ids);
  }

Minjie Wang's avatar
Minjie Wang committed
176
177
178
  /*! \brief Create a graph from COO arrays */
  static HeteroGraphPtr CreateFromCOO(
      int64_t num_vtypes, int64_t num_src, int64_t num_dst,
179
180
      IdArray row, IdArray col, bool row_sorted = false,
      bool col_sorted = false, dgl_format_code_t formats = ALL_CODE);
181

182
183
  static HeteroGraphPtr CreateFromCOO(
      int64_t num_vtypes, const aten::COOMatrix& mat,
184
      dgl_format_code_t formats = ALL_CODE);
185

Minjie Wang's avatar
Minjie Wang committed
186
  /*! \brief Create a graph from (out) CSR arrays */
187
  static HeteroGraphPtr CreateFromCSR(
Minjie Wang's avatar
Minjie Wang committed
188
      int64_t num_vtypes, int64_t num_src, int64_t num_dst,
189
      IdArray indptr, IdArray indices, IdArray edge_ids,
190
      dgl_format_code_t formats = ALL_CODE);
191

192
193
  static HeteroGraphPtr CreateFromCSR(
      int64_t num_vtypes, const aten::CSRMatrix& mat,
194
      dgl_format_code_t formats = ALL_CODE);
195
196
197
198
199

  /*! \brief Create a graph from (in) CSC arrays */
  static HeteroGraphPtr CreateFromCSC(
      int64_t num_vtypes, int64_t num_src, int64_t num_dst,
      IdArray indptr, IdArray indices, IdArray edge_ids,
200
      dgl_format_code_t formats = ALL_CODE);
201
202
203

  static HeteroGraphPtr CreateFromCSC(
      int64_t num_vtypes, const aten::CSRMatrix& mat,
204
      dgl_format_code_t formats = ALL_CODE);
205

206
207
  /*! \brief Convert the graph to use the given number of bits for storage */
  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
208

209
  /*! \brief Copy the data to another context */
210
211
  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
                               const DGLStreamHandle &stream = nullptr);
212

213
214
215
216
  /*!
  * \brief Pin the in_csr_, out_csr_ and coo_ of the current graph.
  * \note The graph will be pinned inplace. Behavior depends on the current context,
  *       kDLCPU: will be pinned;
217
  *       IsPinned: directly return;
218
219
220
221
222
223
224
225
  *       kDLGPU: invalid, will throw an error.
  *       The context check is deferred to pinning the NDArray.
  */
  void PinMemory_();

  /*!
  * \brief Unpin the in_csr_, out_csr_ and coo_ of the current graph.
  * \note The graph will be unpinned inplace. Behavior depends on the current context,
226
  *       IsPinned: will be unpinned;
227
228
229
230
231
  *       others: directly return.
  *       The context check is deferred to unpinning the NDArray.
  */
  void UnpinMemory_();

232
233
234
235
236
237
238
  /*! 
   * \brief Create in-edge CSR format of the unit graph.
   * \param inplace if true and the in-edge CSR format does not exist, the created
   *                format will be cached in this object unless the format is restricted.
   * \return Return the in-edge CSR format. Create from other format if not exist.
   */
  CSRPtr GetInCSR(bool inplace = true) const;
239

240
241
242
243
244
245
246
  /*! 
   * \brief Create out-edge CSR format of the unit graph.
   * \param inplace if true and the out-edge CSR format does not exist, the created
   *                format will be cached in this object unless the format is restricted.
   * \return Return the out-edge CSR format. Create from other format if not exist.
   */
  CSRPtr GetOutCSR(bool inplace = true) const;
247

248
249
250
251
252
253
254
  /*!
   * \brief Create COO format of the unit graph.
   * \param inplace if true and the COO format does not exist, the created
   *                format will be cached in this object unless the format is restricted.
   * \return Return the COO format. Create from other format if not exist.
   */
  COOPtr GetCOO(bool inplace = true) const;
255

256
257
258
259
260
  /*! \return Return the COO matrix form */
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;

  /*! \return Return the in-edge CSC in the matrix form */
  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;
261
262

  /*! \return Return the out-edge CSR in the matrix form */
263
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
264

265
266
  /*!
   * \brief Select a sparse format given the caller's preference.
   *
   * Forwards to the single-argument SelectFormat overload; the edge type is
   * not used.
   * \param etype edge type (ignored)
   * \param preferred_formats format code of the preferred formats
   * \return the format chosen by SelectFormat(preferred_formats)
   */
  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    return SelectFormat(preferred_formats);
  }

269
270
271
272
273
274
275
276
  /*!
   * \brief Return the graph in the given format. Perform format conversion if the
   * requested format does not exist.
   *
   * \return A graph in the requested format.
   */
  HeteroGraphPtr GetFormat(SparseFormat format) const;

277
278
279
  dgl_format_code_t GetCreatedFormats() const override;

  dgl_format_code_t GetAllowedFormats() const override;
280

281
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
282

283
284
285
286
287
288
  /*! \brief Load UnitGraph from stream, using CSRMatrix */
  bool Load(dmlc::Stream* fs);

  /*! \brief Save UnitGraph to stream, using CSRMatrix */
  void Save(dmlc::Stream* fs) const;

289
290
291
  /*! \brief Create a LineGraph of self */
  HeteroGraphPtr LineGraph(bool backtracking) const;

292
293
294
  /*! \return the reversed graph */
  UnitGraphPtr Reverse() const;

295
296
297
298
299
300
301
  /*! \return a tuple of
   *          the simple (no-multi-edge) graph,
   *          the count recording the number of duplicated edges from the original graph, and
   *          the edge mapping from the edge IDs of original graph to those of the
   *          returned graph.
   */
  std::tuple<UnitGraphPtr, IdArray, IdArray>ToSimple() const;

302
303
304
305
306
307
  void InvalidateCSR();

  void InvalidateCSC();

  void InvalidateCOO();

308
 private:
309
  friend class Serializer;
310
  friend class HeteroGraph;
311
  friend class ImmutableGraph;
312
  friend HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
313

314
315
316
  // private empty constructor
  UnitGraph() {}

Minjie Wang's avatar
Minjie Wang committed
317
318
319
320
321
322
323
  /*!
   * \brief constructor
   * \param metagraph metagraph
   * \param in_csr in edge csr
   * \param out_csr out edge csr
   * \param coo coo
   */
324
  UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
325
            dgl_format_code_t formats = ALL_CODE);
326

327
328
  /*!
   * \brief factory that creates a unit graph from the given sparse matrices
329
   * \param num_vtypes number of vertex types (1 or 2)
330
331
332
333
334
335
336
337
   * \param in_csr in edge csr
   * \param out_csr out edge csr
   * \param coo coo
   * \param has_in_csr whether in_csr is valid
   * \param has_out_csr whether out_csr is valid
   * \param has_coo whether coo is valid
   */
338
339
  static HeteroGraphPtr CreateUnitGraphFrom(
      int num_vtypes,
340
341
342
343
344
345
      const aten::CSRMatrix &in_csr,
      const aten::CSRMatrix &out_csr,
      const aten::COOMatrix &coo,
      bool has_in_csr,
      bool has_out_csr,
      bool has_coo,
346
      dgl_format_code_t formats = ALL_CODE);
347

348
349
350
  /*! \return Return any existing format. */
  HeteroGraphPtr GetAny() const;

351
352
353
354
355
356
357
358
359
  /*!
   * \brief Determine which format to use with a preference.
   *
   * If the storage of unit graph is "locked", i.e. no conversion is allowed, then
   * it will return the locked format.
   *
   * Otherwise, it will return whatever DGL thinks is the most appropriate given
   * the arguments.
   */
360
  SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const;
361
362
363
364

  /*! \return Whether the graph is hypersparse */
  bool IsHypersparse() const;

365
366
  GraphPtr AsImmutableGraph() const override;

367
368
369
370
371
372
373
374
  // Graphs stored in different formats. We use an on-demand strategy: a format is
  // only materialized when an operation that is suitable for it is invoked.
  /*! \brief CSR graph that stores reverse edges */
  CSRPtr in_csr_;
  /*! \brief CSR representation */
  CSRPtr out_csr_;
  /*! \brief COO representation */
  COOPtr coo_;
375
376
377
  /*!
   * \brief Storage format restriction.
   */
378
  dgl_format_code_t formats_;
379
380
381
382
};

};  // namespace dgl

383
384
namespace dmlc {
// Declare the dmlc `has_saveload` trait as true for UnitGraph and for its
// nested CSR and COO storage classes.
// NOTE(review): presumably this makes dmlc serialization call the types' own
// Save/Load member functions -- confirm against the dmlc-core documentation.
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::CSR, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::COO, true);
}  // namespace dmlc

Minjie Wang's avatar
Minjie Wang committed
389
#endif  // DGL_GRAPH_UNIT_GRAPH_H_