immutable_graph.h 30.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 *  Copyright (c) 2018 by Contributors
 * \file dgl/immutable_graph.h
 * \brief DGL immutable graph index class.
 */
#ifndef DGL_IMMUTABLE_GRAPH_H_
#define DGL_IMMUTABLE_GRAPH_H_

#include <vector>
#include <string>
#include <cstdint>
#include <utility>
#include <tuple>
14
#include <algorithm>
15
16
#include "runtime/ndarray.h"
#include "graph_interface.h"
17
#include "lazy.h"
18
19
20

namespace dgl {

21
22
23
24
25
// Forward declarations so the shared-pointer aliases below can be named
// before the class definitions appear.
class CSR;
class COO;
/*! \brief Shared-ownership handle to a CSR graph. */
using CSRPtr = std::shared_ptr<CSR>;
/*! \brief Shared-ownership handle to a COO graph. */
using COOPtr = std::shared_ptr<COO>;

26
/*!
27
 * \brief Graph class stored using CSR structure.
28
 */
29
class CSR : public GraphInterface {
30
 public:
31
32
33
34
35
36
  /*!
   * \brief Create a csr graph that has the given number of verts and edges.
   * \param num_vertices The number of vertices.
   * \param num_edges The number of edges.
   * \param is_multigraph Whether the graph is a multigraph.
   */
  CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph);
  /*!
   * \brief Create a csr graph whose memory is stored in the shared memory
   *        that has the given number of verts and edges.
   * \param shared_mem_name Name of the shared memory.
   * \param num_vertices The number of vertices.
   * \param num_edges The number of edges.
   * \param is_multigraph Whether the graph is a multigraph.
   */
  CSR(const std::string &shared_mem_name,
      int64_t num_vertices, int64_t num_edges, bool is_multigraph);
37

38
39
40
  /*!
   * \brief Create a csr graph that shares the given indptr and indices.
   *
   * Two overloads are provided: one with and one without an explicit
   * multigraph flag.
   * NOTE(review): presumably the overload without the flag determines it on
   * demand (the flag is stored in a LazyObject) -- confirm in the implementation.
   */
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph);
41
42
43
44
45
46
47

  /*!
   * \brief Create a csr graph by data iterator.
   *
   * The iterators are read sequentially: num_vertices + 1 values from
   * indptr_begin, and num_edges values from each of indices_begin and
   * edge_ids_begin (see the inline definition at the end of this header).
   */
  template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
  CSR(int64_t num_vertices, int64_t num_edges,
      IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
      bool is_multigraph);

48
49
50
51
52
53
54
55
56
57
  /*!
   * \brief Create a csr graph whose memory is stored in the shared memory
   *        and the structure is given by the indptr and indices.
   *
   * Two overloads are provided: one with and one without an explicit
   * multigraph flag.
   */
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
      const std::string &shared_mem_name);
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
      const std::string &shared_mem_name);

  void AddVertices(uint64_t num_vertices) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
58

59
60
61
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
62

63
64
65
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
66

67
68
69
  void Clear() override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
70

71
72
73
74
75
76
77
78
  /*! \return the context of the indptr array */
  DLContext Context() const override {
    return indptr_->ctx;
  }

  /*! \return the bit width of the dtype of the indices array */
  uint8_t NumBits() const override {
    return indices_->dtype.bits;
  }

79
  bool IsMultigraph() const override;
80

81
82
83
  bool IsReadonly() const override {
    return true;
  }
84

85
86
87
  uint64_t NumVertices() const override {
    return indptr_->shape[0] - 1;
  }
88

89
90
91
  uint64_t NumEdges() const override {
    return indices_->shape[0];
  }
92

93
  bool HasVertex(dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
94
    return vid < NumVertices();
95
  }
96

97
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
98

99
100
101
102
103
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "CSR graph does not support efficient predecessor query."
      << " Please use successors on the reverse CSR graph.";
    return {};
  }
104

105
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
106

107
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;
108

109
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override;
110

111
112
113
114
115
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdge."
      << " Please use COO graph.";
    return {};
  }
116

117
118
119
120
121
  EdgeArray FindEdges(IdArray eids) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdges."
      << " Please use COO graph.";
    return {};
  }
122

123
124
125
126
127
  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
128

129
130
131
132
133
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
134

135
  EdgeArray OutEdges(dgl_id_t vid) const override;
136

137
  EdgeArray OutEdges(IdArray vids) const override;
138

139
  EdgeArray Edges(const std::string &order = "") const override;
140

141
142
143
144
145
  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return 0;
  }
146

147
148
149
150
151
  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return {};
  }
Da Zheng's avatar
Da Zheng committed
152

153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
  /*!
   * \brief Get the out degree of the given vertex.
   *
   * The degree is the difference between two consecutive indptr entries.
   * The vertex id is validated first because indptr is indexed at vid + 1.
   *
   * \param vid The vertex id.
   * \return the out degree
   */
  uint64_t OutDegree(dgl_id_t vid) const override {
    CHECK(HasVertex(vid)) << "Invalid vertex id: " << vid;
    // Read indptr through dgl_id_t, like SuccVec and OutEdgeVec do.
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
    return indptr_data[vid + 1] - indptr_data[vid];
  }

  /*! \brief Get the out degrees of the given vertices. */
  DegreeArray OutDegrees(IdArray vids) const override;

  /*! \brief Construct the induced subgraph of the given vertices. */
  Subgraph VertexSubgraph(IdArray vids) const override;

  /*! \brief Unsupported on CSR; reports a fatal error. */
  Subgraph EdgeSubgraph(IdArray eids) const override {
    LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
      << " Please use COO graph instead.";
    return {};
  }

  /*! \brief Return the reversed graph, which is the transpose of this CSR. */
  GraphPtr Reverse() const override {
    return Transpose();
  }

  /*!
   * \brief Return the successors of a vertex as an iterator range.
   *
   * The range points into the indices array, from offset indptr[vid] to
   * offset indptr[vid + 1]. The vertex id is not validated here.
   */
  DGLIdIters SuccVec(dgl_id_t vid) const override {
    const dgl_id_t* offsets = static_cast<dgl_id_t*>(indptr_->data);
    const dgl_id_t* neighbors = static_cast<dgl_id_t*>(indices_->data);
    const dgl_id_t first = offsets[vid];
    const dgl_id_t last = offsets[vid + 1];
    return DGLIdIters(neighbors + first, neighbors + last);
  }

  /*!
   * \brief Return the out edge ids of a vertex as an iterator range.
   *
   * The range points into the edge id array, from offset indptr[vid] to
   * offset indptr[vid + 1]. The vertex id is not validated here.
   */
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    const dgl_id_t* offsets = static_cast<dgl_id_t*>(indptr_->data);
    const dgl_id_t* eids = static_cast<dgl_id_t*>(edge_ids_->data);
    const dgl_id_t first = offsets[vid];
    const dgl_id_t last = offsets[vid + 1];
    return DGLIdIters(eids + first, eids + last);
  }

  /*! \brief Unsupported on CSR; reports a fatal error. */
  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient PredVec."
      << " Please use SuccVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  /*! \brief Unsupported on CSR; reports a fatal error. */
  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient InEdgeVec."
      << " Please use OutEdgeVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  /*!
   * \brief Move-assign this graph into a newly allocated CSR and return it.
   * \return a raw pointer to the new graph object.
   *         NOTE(review): presumably the caller takes ownership -- confirm.
   */
  GraphInterface *Reset() override {
    CSR* gptr = new CSR();
    *gptr = std::move(*this);
    return gptr;
  }

  /*!
   * \brief Get the adjacency arrays of this graph.
   *
   * Only the non-transposed "csr" format is accepted; anything else fails the CHECK.
   *
   * \return {indptr, indices, edge ids}, sharing storage with this graph.
   */
  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {indptr_, indices_, edge_ids_};
  }

  /*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
  CSRPtr Transpose() const;

  /*! \brief Convert this CSR to COO */
  COOPtr ToCOO() const;

  /*!
   * \return the csr matrix that represents this graph.
   * \note The csr matrix shares the storage with this graph.
   *       The data field of the CSR matrix stores the edge ids.
   */
  CSRMatrix ToCSRMatrix() const {
    return CSRMatrix{indptr_, indices_, edge_ids_};
  }

226
227
228
229
230
231
232
233
234
235
236
237
238
239
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  CSR CopyTo(const DLContext& ctx) const;

  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  CSR AsNumBits(uint8_t bits) const;

240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
  // member getters

  /*! \return the indptr array */
  IdArray indptr() const { return indptr_; }

  /*! \return the indices array */
  IdArray indices() const { return indices_; }

  /*! \return the edge id array */
  IdArray edge_ids() const { return edge_ids_; }

 private:
  /*! \brief private default constructor */
  CSR() {}

  // The CSR arrays.
  //  - The index is 0-based.
  //  - The out edges of vertex v is stored from `indices_[indptr_[v]]` to
  //    `indices_[indptr_[v+1]]` (exclusive).
  //  - The indices are *not* necessarily sorted.
  // TODO(minjie): in the future, we should separate CSR and graph. A general CSR
  //   is not necessarily a graph, but graph operations could be implemented by
  //   CSR matrix operations. CSR matrix operations would be backed by different
  //   devices (CPU, CUDA, ...), while graph interface will not be aware of that.
  IdArray indptr_, indices_, edge_ids_;

  // whether the graph is a multi-graph
  LazyObject<bool> is_multigraph_;
};

class COO : public GraphInterface {
 public:
  /*!
   * \brief Create a coo graph that shares the given src and dst.
   *
   * Two overloads are provided: one with and one without an explicit
   * multigraph flag.
   */
  COO(int64_t num_vertices, IdArray src, IdArray dst);
  COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph);

  // TODO(da): add constructor for creating COO from shared memory

  // Mutation entry points. None of them is supported by COO: each one only
  // reports a fatal error. The messages name COO so that the log points at
  // the class that actually rejected the call.

  /*! \brief Unsupported; reports a fatal error. */
  void AddVertices(uint64_t num_vertices) override {
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

  /*! \brief Unsupported; reports a fatal error. */
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

  /*! \brief Unsupported; reports a fatal error. */
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

  /*! \brief Unsupported; reports a fatal error. */
  void Clear() override {
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

291
292
293
294
295
296
297
298
  /*! \return the context of the src array */
  DLContext Context() const override {
    return src_->ctx;
  }

  /*! \return the bit width of the dtype of the src array */
  uint8_t NumBits() const override {
    return src_->dtype.bits;
  }

299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
  bool IsMultigraph() const override;

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices() const override {
    return num_vertices_;
  }

  uint64_t NumEdges() const override {
    return src_->shape[0];
  }

  bool HasVertex(dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
314
    return vid < NumVertices();
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
  }

  // Adjacency queries are not supported by COO; each one reports a fatal
  // error that names the rejected operation and then returns an empty value.

  /*! \brief Unsupported on COO; reports a fatal error. */
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return false;
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Predecessors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Successors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient EdgeId."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    // The message names EdgeIds (not EdgeId) so the log identifies this overload.
    LOG(FATAL) << "COO graph does not support efficient EdgeIds."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*!
   * \brief Find the edge ID and return the pair of endpoints
   * \param eid The edge ID; must be smaller than the number of edges.
   * \return a pair whose first element is read from the src array and the
   *         second from the dst array, both at position eid.
   */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
    const dgl_id_t* sources = static_cast<dgl_id_t*>(src_->data);
    const dgl_id_t* targets = static_cast<dgl_id_t*>(dst_->data);
    const dgl_id_t u = sources[eid];
    const dgl_id_t v = targets[eid];
    return std::pair<dgl_id_t, dgl_id_t>(u, v);
  }

  /*! \brief Find the edge IDs and return their source and target node IDs. */
  EdgeArray FindEdges(IdArray eids) const override;

  /*! \brief Unsupported on COO; reports a fatal error. */
  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }
361

362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
  // Neighborhood and degree queries are not supported by COO; each of the
  // unsupported methods below reports a fatal error and then returns an
  // empty/zero value.

  /*! \brief Unsupported on COO; reports a fatal error. */
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  EdgeArray OutEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  EdgeArray OutEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Get all the edges in the graph. */
  EdgeArray Edges(const std::string &order = "") const override;

  /*! \brief Unsupported on COO; reports a fatal error. */
  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  uint64_t OutDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  DegreeArray OutDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  Subgraph VertexSubgraph(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient VertexSubgraph."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*! \brief Construct the induced edge subgraph of the given edges. */
  Subgraph EdgeSubgraph(IdArray eids) const override;

  /*! \brief Return the reversed graph, which is the transpose of this COO. */
  GraphPtr Reverse() const override {
    return Transpose();
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  DGLIdIters SuccVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient SuccVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient PredVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  /*! \brief Unsupported on COO; reports a fatal error. */
  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  /*!
   * \brief Move-assign this graph into a newly allocated COO and return it.
   * \return a raw pointer to the new graph object.
   *         NOTE(review): presumably the caller takes ownership -- confirm.
   */
  GraphInterface *Reset() override {
    COO* gptr = new COO();
    *gptr = std::move(*this);
    return gptr;
  }

  /*!
   * \brief Get the adjacency arrays of this graph.
   *
   * Only the "coo" format is accepted; anything else fails the CHECK.
   *
   * \param transpose If true, dst is stacked before src; otherwise src before dst.
   * \param fmt the format of the returned adjacency matrix.
   * \return a vector holding the single array produced by HStack.
   */
  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (!transpose) {
      return {HStack(src_, dst_)};
    }
    return {HStack(dst_, src_)};
  }
456

457
458
459
460
  /*! \brief Return the transpose of this COO */
  COOPtr Transpose() const {
    return COOPtr(new COO(num_vertices_, dst_, src_));
  }
461

462
463
  /*! \brief Convert this COO to CSR */
  CSRPtr ToCSR() const;
464

465
466
467
468
469
470
471
472
473
  /*!
   * \brief Get the coo matrix that represents this graph.
   * \note The coo matrix shares the storage with this graph.
   *       The data field of the coo matrix is none.
   */
  COOMatrix ToCOOMatrix() const {
    return COOMatrix{src_, dst_, {}};
  }

474
475
476
477
478
479
480
481
482
483
484
485
486
487
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  COO CopyTo(const DLContext& ctx) const;

  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  COO AsNumBits(uint8_t bits) const;

488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
  // member getters

  /*! \return the source id array */
  IdArray src() const { return src_; }

  /*! \return the destination id array */
  IdArray dst() const { return dst_; }

 private:
  /*!
   * \brief private default constructor
   *
   * The vertex count is zero-initialized so that a default-constructed COO
   * never exposes an indeterminate value through NumVertices().
   */
  COO() : num_vertices_(0) {}

  /*! \brief number of vertices */
  int64_t num_vertices_;
  /*! \brief coordinate arrays */
  IdArray src_, dst_;
  /*! \brief whether the graph is a multi-graph */
  LazyObject<bool> is_multigraph_;
};

/*!
 * \brief DGL immutable graph index class.
 *
 * DGL's graph is directed. Vertices are integers enumerated from zero.
 */
class ImmutableGraph: public GraphInterface {
 public:
  /*!
   * \brief Construct an immutable graph from the COO format.
   *
   * The pointer must not be null: it is the only structure this graph starts
   * with, and every query dereferences an available structure. This matches
   * the checks performed by the other multi-argument constructors.
   */
  explicit ImmutableGraph(COOPtr coo): coo_(coo) {
    CHECK(coo_) << "COO is missing.";
  }
515
516
517
518
519
520
521
522
523
524
525
526
527
  /*!
   * \brief Construct an immutable graph from the CSR format.
   *
   * For a single graph, we need two CSRs, one stores the in-edges of vertices and
   * the other stores the out-edges of vertices. These two CSRs stores the same edges.
   * The reason we need both is that some operators are faster on in-edge CSR and
   * the other operators are faster on out-edge CSR.
   *
   * However, not both CSRs are required. Technically, one CSR contains all information.
   * Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only
   * construct one of the CSRs that runs fast for some operations we expect and construct
   * the other CSR on demand.
   */
528
529
530
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
    : in_csr_(in_csr), out_csr_(out_csr) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
531
532
  }

533
534
  /*!
   * \brief Construct an immutable graph from one CSR.
   *
   * The CSR is stored as the out-edge CSR. The pointer must not be null: it
   * is the only structure this graph starts with, and every query
   * dereferences an available structure.
   */
  explicit ImmutableGraph(CSRPtr csr): out_csr_(csr) {
    CHECK(out_csr_) << "CSR is missing.";
  }
535
536
537
538
539
540
541
542
543
544
545

  /*! \brief default copy constructor */
  ImmutableGraph(const ImmutableGraph& other) = default;

#ifndef _MSC_VER
  /*! \brief default move constructor */
  ImmutableGraph(ImmutableGraph&& other) = default;
#else
  /*!
   * \brief hand-written move constructor used when building with MSVC
   *
   * Each pointer is moved from the source, which leaves the source pointers
   * empty without touching the reference counts.
   * NOTE(review): presumably needed because the targeted MSVC version cannot
   * default a move constructor -- confirm against the supported toolchains.
   */
  ImmutableGraph(ImmutableGraph&& other)
    : in_csr_(std::move(other.in_csr_)),
      out_csr_(std::move(other.out_csr_)),
      coo_(std::move(other.coo_)) {
  }
#endif  // _MSC_VER

  /*! \brief default copy assignment operator */
  ImmutableGraph& operator=(const ImmutableGraph& other) = default;

  /*! \brief default destructor */
  ~ImmutableGraph() = default;

559
  void AddVertices(uint64_t num_vertices) override {
560
561
562
    LOG(FATAL) << "AddVertices isn't supported in ImmutableGraph";
  }

563
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
564
565
566
    LOG(FATAL) << "AddEdge isn't supported in ImmutableGraph";
  }

567
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
568
569
570
    LOG(FATAL) << "AddEdges isn't supported in ImmutableGraph";
  }

571
  void Clear() override {
572
573
574
    LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
  }

575
576
577
578
579
580
581
582
  /*! \return the context reported by whichever graph structure is available */
  DLContext Context() const override {
    return AnyGraph()->Context();
  }

  /*! \return the number of bits reported by whichever graph structure is available */
  uint8_t NumBits() const override {
    return AnyGraph()->NumBits();
  }

583
584
585
586
  /*!
   * \note The answer is obtained from whichever graph structure is currently available.
   * \return whether the graph is a multigraph
   */
587
588
  bool IsMultigraph() const override {
    return AnyGraph()->IsMultigraph();
589
590
591
592
593
  }

  /*!
   * \return whether the graph is read-only
   */
594
  bool IsReadonly() const override {
595
596
597
598
    return true;
  }

  /*! \return the number of vertices in the graph.*/
599
600
  uint64_t NumVertices() const override {
    return AnyGraph()->NumVertices();
601
602
603
  }

  /*! \return the number of edges in the graph.*/
604
605
  uint64_t NumEdges() const override {
    return AnyGraph()->NumEdges();
606
607
608
  }

  /*! \return true if the given vertex is in the graph.*/
609
  bool HasVertex(dgl_id_t vid) const override {
610
611
612
613
    return vid < NumVertices();
  }

  /*! \return true if the given edge is in the graph.*/
614
615
616
617
618
619
620
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    if (in_csr_) {
      return in_csr_->HasEdgeBetween(dst, src);
    } else {
      return GetOutCSR()->HasEdgeBetween(src, dst);
    }
  }
621
622
623
624
625
626
627

  /*!
   * \brief Find the predecessors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the predecessor id array.
   */
628
629
630
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetInCSR()->Successors(vid, radius);
  }
631
632
633
634
635
636
637

  /*!
   * \brief Find the successors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the successor id array.
   */
638
639
640
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetOutCSR()->Successors(vid, radius);
  }
641
642
643
644
645
646
647
648
649

  /*!
   * \brief Get all edge ids between the two given endpoints
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   * \param src The source vertex.
   * \param dst The destination vertex.
   * \return the edge id array.
   */
650
651
652
653
654
655
656
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    if (in_csr_) {
      return in_csr_->EdgeId(dst, src);
    } else {
      return GetOutCSR()->EdgeId(src, dst);
    }
  }
657
658
659
660
661
662
663
664
665
666

  /*!
   * \brief Get all edge ids between the given endpoint pairs.
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   *       If duplicate pairs exist, the returned edge IDs will also duplicate.
   *       The order of returned edge IDs will follow the order of src-dst pairs
   *       first, and ties are broken by the order of edge ID.
   * \return EdgeArray containing all edges between all pairs.
   */
667
668
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    if (in_csr_) {
669
670
      EdgeArray edges = in_csr_->EdgeIds(dst, src);
      return EdgeArray{edges.dst, edges.src, edges.id};
671
672
673
674
    } else {
      return GetOutCSR()->EdgeIds(src, dst);
    }
  }
675
676
677
678
679
680

  /*!
   * \brief Find the edge ID and return the pair of endpoints
   * \param eid The edge ID
   * \return a pair whose first element is the source and the second the destination.
   */
681
682
683
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    return GetCOO()->FindEdge(eid);
  }
684
685
686
687
688
689

  /*!
   * \brief Find the edge IDs and return their source and target node IDs.
   * \param eids The edge ID array.
   * \return EdgeArray containing all edges with id in eid.  The order is preserved.
   */
690
691
692
  EdgeArray FindEdges(IdArray eids) const override {
    return GetCOO()->FindEdges(eids);
  }
693
694
695
696
697
698
699

  /*!
   * \brief Get the in edges of the vertex.
   * \note The returned dst id array is filled with vid.
   * \param vid The vertex id.
   * \return the edges
   */
700
701
702
  EdgeArray InEdges(dgl_id_t vid) const override {
    const EdgeArray& ret = GetInCSR()->OutEdges(vid);
    return {ret.dst, ret.src, ret.id};
703
704
705
706
707
708
709
  }

  /*!
   * \brief Get the in edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
710
711
712
  EdgeArray InEdges(IdArray vids) const override {
    const EdgeArray& ret = GetInCSR()->OutEdges(vids);
    return {ret.dst, ret.src, ret.id};
713
714
715
716
717
718
719
720
  }

  /*!
   * \brief Get the out edges of the vertex.
   * \note The returned src id array is filled with vid.
   * \param vid The vertex id.
   * \return the id arrays of the two endpoints of the edges.
   */
721
722
  EdgeArray OutEdges(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdges(vid);
723
724
725
726
727
728
729
  }

  /*!
   * \brief Get the out edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
730
731
  EdgeArray OutEdges(IdArray vids) const override {
    return GetOutCSR()->OutEdges(vids);
732
733
734
735
736
737
738
739
740
  }

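  /*!
   * \brief Get all the edges in the graph.
   * \note The order of the returned edge list is selected by \a order.
   *       Earlier documentation described two orders: sorted by src and dst
   *       ids, or edge id order.
   *       NOTE(review): the accepted order strings are defined by the
   *       implementation, which is not in this header -- confirm there.
   * \param order A string selecting the order of the returned edge list
   *        (defaults to the empty string).
   * \return the id arrays of the two endpoints of the edges.
   */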
741
  EdgeArray Edges(const std::string &order = "") const override;
742
743
744
745
746
747

  /*!
   * \brief Get the in degree of the given vertex.
   * \param vid The vertex id.
   * \return the in degree
   */
748
749
  uint64_t InDegree(dgl_id_t vid) const override {
    return GetInCSR()->OutDegree(vid);
750
751
752
753
754
755
756
  }

  /*!
   * \brief Get the in degrees of the given vertices.
   * \param vids The vertex id array.
   * \return the in degree array
   */
757
758
  DegreeArray InDegrees(IdArray vids) const override {
    return GetInCSR()->OutDegrees(vids);
759
760
761
762
763
764
765
  }

  /*!
   * \brief Get the out degree of the given vertex.
   * \param vid The vertex id.
   * \return the out degree
   */
766
767
  uint64_t OutDegree(dgl_id_t vid) const override {
    return GetOutCSR()->OutDegree(vid);
768
769
770
771
772
773
774
  }

  /*!
   * \brief Get the out degrees of the given vertices.
   * \param vids The vertex id array.
   * \return the out degree array
   */
775
776
  DegreeArray OutDegrees(IdArray vids) const override {
    return GetOutCSR()->OutDegrees(vids);
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
  }

  /*!
   * \brief Construct the induced subgraph of the given vertices.
   *
   * The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
   * selecting all of the edges from the original graph that connect two vertices in V'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the vertices preserve the order of the given id array, while the local index
   * of the edges preserve the index order in the original graph. Vertices not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param vids The vertices in the subgraph.
   * \return the induced subgraph
   */
795
  Subgraph VertexSubgraph(IdArray vids) const override;
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812

  /*!
   * \brief Construct the induced edge subgraph of the given edges.
   *
   * The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
   * selecting all of the nodes from the original graph that are endpoints in E'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the edges preserve the order of the given id array, while the local index
   * of the vertices preserve the index order in the original graph. Edges not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param eids The edges in the subgraph.
   * \return the induced edge subgraph
   */
813
  Subgraph EdgeSubgraph(IdArray eids) const override;
814
815
816
817
818
819
820
821

  /*!
   * \brief Return a new graph with all the edges reversed.
   *
   * The returned graph preserves the vertex and edge index in the original graph.
   *
   * \return the reversed graph
   */
822
823
824
825
826
827
  GraphPtr Reverse() const override {
    if (coo_) {
      return GraphPtr(new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
    } else {
      return GraphPtr(new ImmutableGraph(out_csr_, in_csr_));
    }
828
829
830
831
832
833
834
  }

  /*!
   * \brief Return the successor vector
   * \param vid The vertex id.
   * \return the successor vector
   */
835
836
  DGLIdIters SuccVec(dgl_id_t vid) const override {
    return GetOutCSR()->SuccVec(vid);
837
838
839
840
841
842
843
  }

  /*!
   * \brief Return the out edge id vector
   * \param vid The vertex id.
   * \return the out edge id vector
   */
844
845
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdgeVec(vid);
846
847
848
849
850
851
852
  }

  /*!
   * \brief Return the predecessor vector
   * \param vid The vertex id.
   * \return the predecessor vector
   */
853
854
  DGLIdIters PredVec(dgl_id_t vid) const override {
    return GetInCSR()->SuccVec(vid);
855
856
857
858
859
860
861
  }

  /*!
   * \brief Return the in edge id vector
   * \param vid The vertex id.
   * \return the in edge id vector
   */
862
863
  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    return GetInCSR()->OutEdgeVec(vid);
864
865
866
867
868
869
  }

  /*!
   * \brief Reset the data in the graph and move its data to the returned graph object.
   * \return a raw pointer to the graph object.
   */
870
  GraphInterface *Reset() override {
871
872
873
874
875
876
877
878
879
880
881
882
883
884
    ImmutableGraph* gptr = new ImmutableGraph();
    *gptr = std::move(*this);
    return gptr;
  }

  /*!
   * \brief Get the adjacency matrix of the graph.
   *
   * By default, a row of returned adjacency matrix represents the destination
   * of an edge and the column represents the source.
   * \param transpose A flag to transpose the returned adjacency matrix.
   * \param fmt the format of the returned adjacency matrix.
   * \return a vector of three IdArray.
   */
885
  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
886

887
888
889
890
891
892
893
894
895
  /*!
   * \brief Return in csr. If not exist, transpose the other one.
   *
   * The result is stored in in_csr_ through a const_cast, so later calls
   * return the stored pointer. If no out csr exists either, the in csr is
   * built from the transposed COO.
   */
  CSRPtr GetInCSR() const {
    if (in_csr_) {
      return in_csr_;
    }
    ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
    if (out_csr_) {
      self->in_csr_ = out_csr_->Transpose();
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      self->in_csr_ = coo_->Transpose()->ToCSR();
    }
    return in_csr_;
  }
899
900
901
902
903
904
905
906
907
908

  /*!
   * \brief Return out csr. If not exist, transpose the other one.
   *
   * The result is stored in out_csr_ through a const_cast, so later calls
   * return the stored pointer. If no in csr exists either, the out csr is
   * built from the COO.
   */
  CSRPtr GetOutCSR() const {
    if (out_csr_) {
      return out_csr_;
    }
    ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
    if (in_csr_) {
      self->out_csr_ = in_csr_->Transpose();
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      self->out_csr_ = coo_->ToCSR();
    }
    return out_csr_;
  }

913
914
915
916
917
918
919
920
921
  /*!
   * \brief Return coo. If not exist, create from csr.
   *
   * The result is stored in coo_ through a const_cast, so later calls return
   * the stored pointer. The in csr is preferred as the source (its COO is
   * transposed); otherwise the out csr is converted directly.
   */
  COOPtr GetCOO() const {
    if (coo_) {
      return coo_;
    }
    ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
    if (in_csr_) {
      self->coo_ = in_csr_->ToCOO()->Transpose();
    } else {
      CHECK(out_csr_) << "Both CSR are missing.";
      self->coo_ = out_csr_->ToCOO();
    }
    return coo_;
  }

926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
  /*!
   * \brief Convert the given graph to an immutable graph.
   *
   * If the graph is already an immutable graph. The result graph will share
   * the storage with the given one.
   *
   * \param graph The input graph.
   * \return an immutable graph object.
   */
  static ImmutableGraph ToImmutable(const GraphInterface* graph);

  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  ImmutableGraph CopyTo(const DLContext& ctx) const;

  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  ImmutableGraph AsNumBits(uint8_t bits) const;

951
952
953
 protected:
  /*! \brief internal default constructor */
  ImmutableGraph() {}
954

955
956
957
958
959
  /*!
   * \brief internal constructor for all the members
   *
   * Any of the three pointers may be null, but at least one must be set;
   * this is enforced by the check on AnyGraph().
   */
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
    : in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
    CHECK(AnyGraph()) << "At least one graph structure should exist.";
  }
960

961
962
963
964
965
966
967
968
969
970
  /*!
   * \brief return pointer to any available graph structure
   *
   * The in csr is preferred, then the out csr, then the coo. The result is
   * null only if none of the three exists.
   */
  GraphPtr AnyGraph() const {
    if (in_csr_) {
      return in_csr_;
    }
    if (out_csr_) {
      return out_csr_;
    }
    return coo_;
  }
971

972
973
974
975
976
977
  // Store the in csr (i.e, the reverse csr)
  CSRPtr in_csr_;
  // Store the out csr (i.e, the normal csr)
  CSRPtr out_csr_;
  // Store the edge list indexed by edge id (COO)
  COOPtr coo_;
978
979
};

980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
// inline implementations

/*!
 * \brief Create a csr graph by copying from data iterators.
 *
 * Allocates the three id arrays, then reads num_vertices + 1 values from
 * indptr_begin and num_edges values from each of indices_begin and
 * edge_ids_begin. The two per-edge iterators are advanced in lockstep.
 */
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
    IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
    bool is_multigraph): is_multigraph_(is_multigraph) {
  indptr_ = NewIdArray(num_vertices + 1);
  indices_ = NewIdArray(num_edges);
  edge_ids_ = NewIdArray(num_edges);

  std::copy_n(indptr_begin, num_vertices + 1, static_cast<dgl_id_t*>(indptr_->data));

  dgl_id_t* const neighbors = static_cast<dgl_id_t*>(indices_->data);
  dgl_id_t* const eids = static_cast<dgl_id_t*>(edge_ids_->data);
  for (int64_t e = 0; e < num_edges; ++e) {
    neighbors[e] = *(indices_begin++);
    eids[e] = *(edge_ids_begin++);
  }
}

1000
1001
1002
}  // namespace dgl

#endif  // DGL_IMMUTABLE_GRAPH_H_