/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/unit_graph.h
 * @brief UnitGraph: a heterograph with exactly one relation (edge type).
 */

#ifndef DGL_GRAPH_UNIT_GRAPH_H_
#define DGL_GRAPH_UNIT_GRAPH_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>

#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "../c_api_common.h"
namespace dgl {

26
class HeteroGraph;
27
28
29
class UnitGraph;
typedef std::shared_ptr<UnitGraph> UnitGraphPtr;

30
/**
31
 * @brief UnitGraph graph
32
 *
Minjie Wang's avatar
Minjie Wang committed
33
34
35
36
37
38
39
 * UnitGraph graph is a special type of heterograph which
 * (1) Have two types of nodes: "Src" and "Dst". All the edges are
 *     from "Src" type nodes to "Dst" type nodes, so there is no edge among
 *     nodes of the same type. Thus, its metagraph has two nodes and one edge
 *     between them.
 * (2) Have only one type of nodes and edges. Thus, its metagraph has one node
 *     and one self-loop edge.
40
 */
Minjie Wang's avatar
Minjie Wang committed
41
class UnitGraph : public BaseHeteroGraph {
42
 public:
43
44
45
46
47
48
  // internal data structure
  class COO;
  class CSR;
  typedef std::shared_ptr<COO> COOPtr;
  typedef std::shared_ptr<CSR> CSRPtr;

49
  inline dgl_type_t SrcType() const { return 0; }
Minjie Wang's avatar
Minjie Wang committed
50

51
  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
52

53
  inline dgl_type_t EdgeType() const { return 0; }
54
55

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
56
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
57
               << "The relation graph is simply this graph itself.";
58
59
60
61
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
62
    LOG(FATAL) << "UnitGraph graph is not mutable.";
63
64
65
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
66
    LOG(FATAL) << "UnitGraph graph is not mutable.";
67
68
69
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
70
    LOG(FATAL) << "UnitGraph graph is not mutable.";
71
72
  }

73
  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
74

75
  DGLDataType DataType() const override;
76

77
  DGLContext Context() const override;
78

79
80
  bool IsPinned() const override;

81
82
83
84
  uint8_t NumBits() const override;

  bool IsMultigraph() const override;

85
  bool IsReadonly() const override { return true; }
86
87
88

  uint64_t NumVertices(dgl_type_t vtype) const override;

89
  inline std::vector<int64_t> NumVerticesPerType() const override {
90
91
92
93
    std::vector<int64_t> num_nodes_per_type;
    for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)
      num_nodes_per_type.push_back(NumVertices(vtype));
    return num_nodes_per_type;
94
95
  }

96
97
98
99
100
101
  uint64_t NumEdges(dgl_type_t etype) const override;

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

102
103
  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;
104

105
106
  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;
107
108
109
110
111
112
113

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override;

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

114
115
  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override;
116
117

  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override;
118

119
120
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override;
121
122
123
124
125
126
127
128
129
130
131

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override;

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override;

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override;

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;

132
133
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override;
134
135
136
137
138
139
140
141
142
143
144

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override;

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override;

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override;

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;

145
146
147
  // 32bit version functions, patch for SuccVec
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) const;

148
149
150
151
152
153
154
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

  std::vector<IdArray> GetAdj(
155
      dgl_type_t etype, bool transpose, const std::string& fmt) const override;
156

157
158
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override;
159
160

  HeteroSubgraph EdgeSubgraph(
161
162
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override;
163
164

  // creators
165
  /** @brief Create a graph with no edges */
166
  static HeteroGraphPtr Empty(
167
168
      int64_t num_vtypes, int64_t num_src, int64_t num_dst, DGLDataType dtype,
      DGLContext ctx) {
169
170
171
172
173
    IdArray row = IdArray::Empty({0}, dtype, ctx);
    IdArray col = IdArray::Empty({0}, dtype, ctx);
    return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
  }

174
  /** @brief Create a graph from COO arrays */
Minjie Wang's avatar
Minjie Wang committed
175
  static HeteroGraphPtr CreateFromCOO(
176
177
178
      int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
      IdArray col, bool row_sorted = false, bool col_sorted = false,
      dgl_format_code_t formats = ALL_CODE);
179

180
181
  static HeteroGraphPtr CreateFromCOO(
      int64_t num_vtypes, const aten::COOMatrix& mat,
182
      dgl_format_code_t formats = ALL_CODE);
183

184
  /** @brief Create a graph from (out) CSR arrays */
185
  static HeteroGraphPtr CreateFromCSR(
186
187
      int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
      IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
188

189
190
  static HeteroGraphPtr CreateFromCSR(
      int64_t num_vtypes, const aten::CSRMatrix& mat,
191
      dgl_format_code_t formats = ALL_CODE);
192

193
  /** @brief Create a graph from (in) CSC arrays */
194
  static HeteroGraphPtr CreateFromCSC(
195
196
      int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
      IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
197
198
199

  static HeteroGraphPtr CreateFromCSC(
      int64_t num_vtypes, const aten::CSRMatrix& mat,
200
      dgl_format_code_t formats = ALL_CODE);
201

202
  /** @brief Convert the graph to use the given number of bits for storage */
203
  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
204

205
  /** @brief Copy the data to another context */
206
  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);
207

208
  /**
209
210
211
212
213
214
   * @brief Pin the in_csr_, out_scr_ and coo_ of the current graph.
   * @note The graph will be pinned inplace. Behavior depends on the current
   * context, kDGLCPU: will be pinned; IsPinned: directly return; kDGLCUDA:
   * invalid, will throw an error. The context check is deferred to pinning the
   * NDArray.
   */
215
  void PinMemory_() override;
216

217
  /**
218
219
220
221
222
   * @brief Unpin the in_csr_, out_scr_ and coo_ of the current graph.
   * @note The graph will be unpinned inplace. Behavior depends on the current
   * context, IsPinned: will be unpinned; others: directly return. The context
   * check is deferred to unpinning the NDArray.
   */
223
224
  void UnpinMemory_();

225
226
227
228
229
230
231
232
233
  /**
   * @brief Create a copy of the current graph in pinned memory.
   * @note The graph will be pinned outplace through PyTorch
   *     CachingHostAllocator, if available. Otherwise, an error will be thrown.
   *     If any of the underlying structures (incsr, outcsr, coo) are already
   *     pinned, the function will simply use its original copy.
   */
  HeteroGraphPtr PinMemory();

234
  /**
235
236
   * @brief Record stream for this graph.
   * @param stream The stream that is using the graph
237
238
239
   */
  void RecordStream(DGLStreamHandle stream) override;

240
  /**
241
   * @brief Create in-edge CSR format of the unit graph.
242
243
244
245
246
   * @param inplace if true and the in-edge CSR format does not exist, the
   * created format will be cached in this object unless the format is
   * restricted.
   * @return Return the in-edge CSR format. Create from other format if not
   * exist.
247
248
   */
  CSRPtr GetInCSR(bool inplace = true) const;
249

250
  /**
251
   * @brief Create out-edge CSR format of the unit graph.
252
253
254
255
256
   * @param inplace if true and the out-edge CSR format does not exist, the
   * created format will be cached in this object unless the format is
   * restricted.
   * @return Return the out-edge CSR format. Create from other format if not
   * exist.
257
258
   */
  CSRPtr GetOutCSR(bool inplace = true) const;
259

260
  /**
261
262
   * @brief Create COO format of the unit graph.
   * @param inplace if true and the COO format does not exist, the created
263
264
   *                format will be cached in this object unless the format is
   * restricted.
265
   * @return Return the COO format. Create from other format if not exist.
266
267
   */
  COOPtr GetCOO(bool inplace = true) const;
268

269
  /** @return Return the COO matrix form */
270
271
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;

272
  /** @return Return the in-edge CSC in the matrix form */
273
  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;
274

275
  /** @return Return the out-edge CSR in the matrix form */
276
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
277

278
279
  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
280
    return SelectFormat(preferred_formats);
281
282
  }

283
  /**
284
285
   * @brief Return the graph in the given format. Perform format conversion if
   * the requested format does not exist.
286
   *
287
   * @return A graph in the requested format.
288
289
290
   */
  HeteroGraphPtr GetFormat(SparseFormat format) const;

291
292
293
  dgl_format_code_t GetCreatedFormats() const override;

  dgl_format_code_t GetAllowedFormats() const override;
294

295
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
296

297
  /** @return Load UnitGraph from stream, using CSRMatrix*/
298
299
  bool Load(dmlc::Stream* fs);

300
  /** @return Save UnitGraph to stream, using CSRMatrix */
301
302
  void Save(dmlc::Stream* fs) const;

303
  /** @brief Creat a LineGraph of self */
304
305
  HeteroGraphPtr LineGraph(bool backtracking) const;

306
  /** @return the reversed graph */
307
308
  UnitGraphPtr Reverse() const;

309
  /** @return the simpled (no-multi-edge) graph
310
311
312
   *          the count recording the number of duplicated edges from the
   * original graph. the edge mapping from the edge IDs of original graph to
   * those of the returned graph.
313
   */
314
  std::tuple<UnitGraphPtr, IdArray, IdArray> ToSimple() const;
315

316
317
318
319
320
321
  void InvalidateCSR();

  void InvalidateCSC();

  void InvalidateCOO();

322
 private:
323
  friend class Serializer;
324
  friend class HeteroGraph;
325
  friend class ImmutableGraph;
326
  friend HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
327

328
329
330
  // private empty constructor
  UnitGraph() {}

331
  /**
332
333
334
335
336
   * @brief constructor
   * @param metagraph metagraph
   * @param in_csr in edge csr
   * @param out_csr out edge csr
   * @param coo coo
Minjie Wang's avatar
Minjie Wang committed
337
   */
338
339
340
  UnitGraph(
      GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
      dgl_format_code_t formats = ALL_CODE);
341

342
  /**
343
344
345
346
347
348
349
350
351
   * @brief constructor
   * @param num_vtypes number of vertex types (1 or 2)
   * @param metagraph metagraph
   * @param in_csr in edge csr
   * @param out_csr out edge csr
   * @param coo coo
   * @param has_in_csr whether in_csr is valid
   * @param has_out_csr whether out_csr is valid
   * @param has_coo whether coo is valid
352
   */
353
  static HeteroGraphPtr CreateUnitGraphFrom(
354
355
356
      int num_vtypes, const aten::CSRMatrix& in_csr,
      const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo,
      bool has_in_csr, bool has_out_csr, bool has_coo,
357
      dgl_format_code_t formats = ALL_CODE);
358

359
  /** @return Return any existing format. */
360
361
  HeteroGraphPtr GetAny() const;

362
  /**
363
   * @brief Determine which format to use with a preference.
364
   *
365
366
   * If the storage of unit graph is "locked", i.e. no conversion is allowed,
   * then it will return the locked format.
367
368
369
370
   *
   * Otherwise, it will return whatever DGL thinks is the most appropriate given
   * the arguments.
   */
371
  SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const;
372

373
  /** @return Whether the graph is hypersparse */
374
375
  bool IsHypersparse() const;

376
377
  GraphPtr AsImmutableGraph() const override;

378
379
  // Graph stored in different format. We use an on-demand strategy: the format
  // is only materialized if the operation that suitable for it is invoked.
380
  /** @brief CSR graph that stores reverse edges */
381
  CSRPtr in_csr_;
382
  /** @brief CSR representation */
383
  CSRPtr out_csr_;
384
  /** @brief COO representation */
385
  COOPtr coo_;
386
  /**
387
   * @brief Storage format restriction.
388
   */
389
  dgl_format_code_t formats_;
390
  /** @brief which streams have recorded the graph */
391
  std::vector<DGLStreamHandle> recorded_streams;
392
393
394
395
};

};  // namespace dgl

396
397
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph, true);
398
399
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::CSR, true);
DMLC_DECLARE_TRAITS(has_saveload, dgl::UnitGraph::COO, true);
400
401
}  // namespace dmlc

Minjie Wang's avatar
Minjie Wang committed
402
#endif  // DGL_GRAPH_UNIT_GRAPH_H_