immutable_graph.h 31.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 *  Copyright (c) 2018 by Contributors
 * \file dgl/immutable_graph.h
 * \brief DGL immutable graph index class.
 */
#ifndef DGL_IMMUTABLE_GRAPH_H_
#define DGL_IMMUTABLE_GRAPH_H_

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "runtime/ndarray.h"
#include "graph_interface.h"
#include "lazy.h"
18
19
20

namespace dgl {

21
22
23
24
25
class CSR;
class COO;
typedef std::shared_ptr<CSR> CSRPtr;
typedef std::shared_ptr<COO> COOPtr;

26
27
28
class ImmutableGraph;
typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr;

29
/*!
30
 * \brief Graph class stored using CSR structure.
31
 */
32
class CSR : public GraphInterface {
33
 public:
34
35
36
37
38
39
  // Create a csr graph that has the given number of verts and edges.
  CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph);
  // Create a csr graph whose memory is stored in the shared memory
  //   that has the given number of verts and edges.
  CSR(const std::string &shared_mem_name,
      int64_t num_vertices, int64_t num_edges, bool is_multigraph);
40

41
42
43
  // Create a csr graph that shares the given indptr and indices.
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph);
44
45
46
47
48
49
50

  // Create a csr graph by data iterator
  template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
  CSR(int64_t num_vertices, int64_t num_edges,
      IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
      bool is_multigraph);

51
52
53
54
55
56
57
58
59
60
  // Create a csr graph whose memory is stored in the shared memory
  //   and the structure is given by the indptr and indices.
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
      const std::string &shared_mem_name);
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
      const std::string &shared_mem_name);

  void AddVertices(uint64_t num_vertices) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
61

62
63
64
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
65

66
67
68
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
69

70
71
72
  void Clear() override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
73

74
  DLContext Context() const override {
75
    return adj_.indptr->ctx;
76
77
78
  }

  uint8_t NumBits() const override {
79
    return adj_.indices->dtype.bits;
80
81
  }

82
  bool IsMultigraph() const override;
83

84
85
86
  bool IsReadonly() const override {
    return true;
  }
87

88
  uint64_t NumVertices() const override {
89
    return adj_.indptr->shape[0] - 1;
90
  }
91

92
  uint64_t NumEdges() const override {
93
94
95
96
97
98
    return adj_.indices->shape[0];
  }

  BoolArray HasVertices(IdArray vids) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return {};
99
  }
100

101
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
102

103
104
  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;

105
106
107
108
109
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "CSR graph does not support efficient predecessor query."
      << " Please use successors on the reverse CSR graph.";
    return {};
  }
110

111
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
112

113
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;
114

115
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override;
116

117
118
119
120
121
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdge."
      << " Please use COO graph.";
    return {};
  }
122

123
124
125
126
127
  EdgeArray FindEdges(IdArray eids) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdges."
      << " Please use COO graph.";
    return {};
  }
128

129
130
131
132
133
  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
134

135
136
137
138
139
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
140

141
  EdgeArray OutEdges(dgl_id_t vid) const override;
142

143
  EdgeArray OutEdges(IdArray vids) const override;
144

145
  EdgeArray Edges(const std::string &order = "") const override;
146

147
148
149
150
151
  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return 0;
  }
152

153
154
155
156
157
  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return {};
  }
Da Zheng's avatar
Da Zheng committed
158

159
  uint64_t OutDegree(dgl_id_t vid) const override {
160
    return aten::CSRGetRowNNZ(adj_, vid);
161
162
163
164
165
166
  }

  DegreeArray OutDegrees(IdArray vids) const override;

  Subgraph VertexSubgraph(IdArray vids) const override;

167
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override {
168
169
170
171
172
    LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
      << " Please use COO graph instead.";
    return {};
  }

173
  DGLIdIters SuccVec(dgl_id_t vid) const override;
174

175
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override;
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190

  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient PredVec."
      << " Please use SuccVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient InEdgeVec."
      << " Please use OutEdgeVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
191
    return {adj_.indptr, adj_.indices, adj_.data};
192
193
  }

194
195
196
197
198
  /*! \brief Indicate whether this uses shared memory. */
  bool IsSharedMem() const {
    return !shared_mem_name_.empty();
  }

199
200
201
202
203
204
205
206
207
208
209
  /*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
  CSRPtr Transpose() const;

  /*! \brief Convert this CSR to COO */
  COOPtr ToCOO() const;

  /*!
   * \return the csr matrix that represents this graph.
   * \note The csr matrix shares the storage with this graph.
   *       The data field of the CSR matrix stores the edge ids.
   */
210
211
  aten::CSRMatrix ToCSRMatrix() const {
    return adj_;
212
213
  }

214
215
216
217
218
219
220
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  CSR CopyTo(const DLContext& ctx) const;

221
222
223
224
225
226
227
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
  CSR CopyToSharedMem(const std::string &name) const;

228
229
230
231
232
233
234
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  CSR AsNumBits(uint8_t bits) const;

235
236
  // member getters

237
  IdArray indptr() const { return adj_.indptr; }
238

239
  IdArray indices() const { return adj_.indices; }
240

241
  IdArray edge_ids() const { return adj_.data; }
242

Da Zheng's avatar
Da Zheng committed
243
244
245
246
247
248
249
  /*!
   * \brief Run aten::CSRSort on the internal adjacency matrix and flag it as sorted.
   * \note Does nothing when the matrix is already flagged as sorted.
   */
  void SortCSR() {
    if (!adj_.sorted) {
      aten::CSRSort(adj_);
      adj_.sorted = true;
    }
  }

250
251
 private:
  /*! \brief private default constructor */
Da Zheng's avatar
Da Zheng committed
252
  CSR() {adj_.sorted = false;}
253

254
255
256
  // The internal CSR adjacency matrix.
  // The data field stores edge ids.
  aten::CSRMatrix adj_;
257
258

  // whether the graph is a multi-graph
259
  Lazy<bool> is_multigraph_;
260
261
262
263

  // The name of the shared memory to store data.
  // If it's empty, data isn't stored in shared memory.
  std::string shared_mem_name_;
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
};

class COO : public GraphInterface {
 public:
  // Create a coo graph that shares the given src and dst
  COO(int64_t num_vertices, IdArray src, IdArray dst);
  COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph);

  // TODO(da): add constructor for creating COO from shared memory

  /*!
   * \brief Unsupported on COO graphs; always logs a fatal error.
   * \param num_vertices Ignored.
   */
  void AddVertices(uint64_t num_vertices) override {
    // The message names COO (not CSR) so the failure points at the right class.
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

  /*!
   * \brief Unsupported on COO graphs; always logs a fatal error.
   * \param src Ignored.
   * \param dst Ignored.
   */
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    // The message names COO (not CSR) so the failure points at the right class.
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

  /*!
   * \brief Unsupported on COO graphs; always logs a fatal error.
   * \param src_ids Ignored.
   * \param dst_ids Ignored.
   */
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    // The message names COO (not CSR) so the failure points at the right class.
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

  /*! \brief Unsupported on COO graphs; always logs a fatal error. */
  void Clear() override {
    // The message names COO (not CSR) so the failure points at the right class.
    LOG(FATAL) << "COO graph does not allow mutation.";
  }

290
  DLContext Context() const override {
291
    return adj_.row->ctx;
292
293
294
  }

  uint8_t NumBits() const override {
295
    return adj_.row->dtype.bits;
296
297
  }

298
299
300
301
302
303
304
  bool IsMultigraph() const override;

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices() const override {
305
    return adj_.num_rows;
306
307
308
  }

  uint64_t NumEdges() const override {
309
    return adj_.row->shape[0];
310
311
312
  }

  bool HasVertex(dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
313
    return vid < NumVertices();
314
315
  }

316
317
318
319
320
  BoolArray HasVertices(IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

321
322
323
324
325
326
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return false;
  }

327
328
329
330
331
332
  /*!
   * \brief Unsupported on COO graphs; always logs a fatal error.
   * \param src_ids Ignored.
   * \param dst_ids Ignored.
   * \return an empty array (only present to satisfy the signature).
   */
  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
    // Report the batched method name so the log matches the call that failed.
    LOG(FATAL) << "COO graph does not support efficient HasEdgesBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Predecessors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Successors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient EdgeId."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  /*!
   * \brief Unsupported on COO graphs; always logs a fatal error.
   * \param src Ignored.
   * \param dst Ignored.
   * \return an empty EdgeArray (only present to satisfy the signature).
   */
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    // Report the batched method name so the log matches the call that failed.
    LOG(FATAL) << "COO graph does not support efficient EdgeIds."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

357
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override;
358
359
360
361
362
363
364
365

  EdgeArray FindEdges(IdArray eids) const override;

  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }
366

367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray OutEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray OutEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray Edges(const std::string &order = "") const override;

  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  uint64_t OutDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  DegreeArray OutDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  Subgraph VertexSubgraph(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient VertexSubgraph."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

417
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445

  DGLIdIters SuccVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient SuccVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient PredVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
446
      return {aten::HStack(adj_.col, adj_.row)};
447
    } else {
448
      return {aten::HStack(adj_.row, adj_.col)};
449
    }
450
  }
451

452
453
  /*!
   * \brief Return the transpose of this COO.
   * \return a new COO with the same number of vertices whose src and dst
   *         arrays are this graph's dst and src arrays respectively.
   */
  COOPtr Transpose() const {
    // make_shared avoids a naked new and allocates object and control block together.
    return std::make_shared<COO>(adj_.num_rows, adj_.col, adj_.row);
  }
456

457
458
  /*! \brief Convert this COO to CSR */
  CSRPtr ToCSR() const;
459

460
461
462
463
464
  /*!
   * \brief Get the coo matrix that represents this graph.
   * \note The coo matrix shares the storage with this graph.
   *       The data field of the coo matrix is none.
   */
465
466
  aten::COOMatrix ToCOOMatrix() const {
    return adj_;
467
468
  }

469
470
471
472
473
474
475
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  COO CopyTo(const DLContext& ctx) const;

476
477
478
479
480
481
482
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
  COO CopyToSharedMem(const std::string &name) const;

483
484
485
486
487
488
489
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  COO AsNumBits(uint8_t bits) const;

490
491
492
493
494
  /*! \brief Indicate whether this uses shared memory. */
  bool IsSharedMem() const {
    return false;
  }

495
496
  // member getters

497
  IdArray src() const { return adj_.row; }
498

499
  IdArray dst() const { return adj_.col; }
500
501
502
503
504

 private:
  /*! \brief private default constructor */
  COO() {}

505
506
507
508
  // The internal COO adjacency matrix.
  // The data field is empty
  aten::COOMatrix adj_;

509
  /*! \brief whether the graph is a multi-graph */
510
  Lazy<bool> is_multigraph_;
511
512
513
514
515
516
517
518
519
520
521
};

/*!
 * \brief DGL immutable graph index class.
 *
 * DGL's graph is directed. Vertices are integers enumerated from zero.
 */
class ImmutableGraph: public GraphInterface {
 public:
  /*!
   * \brief Construct an immutable graph from the COO format.
   * \param coo The COO structure backing this graph; must not be null.
   */
  explicit ImmutableGraph(COOPtr coo): coo_(std::move(coo)) {
    // Every query goes through one of the stored structures, so fail here
    // rather than on a later null dereference.
    CHECK(coo_) << "The COO is missing.";
  }
522

523
524
525
526
527
528
529
530
531
532
533
534
535
  /*!
   * \brief Construct an immutable graph from the CSR format.
   *
   * For a single graph, we need two CSRs, one stores the in-edges of vertices and
   * the other stores the out-edges of vertices. These two CSRs stores the same edges.
   * The reason we need both is that some operators are faster on in-edge CSR and
   * the other operators are faster on out-edge CSR.
   *
   * However, not both CSRs are required. Technically, one CSR contains all information.
   * Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only
   * construct one of the CSRs that runs fast for some operations we expect and construct
   * the other CSR on demand.
   */
536
537
538
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
    : in_csr_(in_csr), out_csr_(out_csr) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
539
540
  }

541
542
  /*!
   * \brief Construct an immutable graph from one CSR.
   * \param csr The CSR stored as the out-CSR of this graph; must not be null.
   */
  explicit ImmutableGraph(CSRPtr csr): out_csr_(std::move(csr)) {
    // Same invariant as the two-CSR constructor: at least one structure must exist.
    CHECK(out_csr_) << "The CSR is missing.";
  }
543
544
545
546
547
548
549
550
551
552
553

  /*! \brief default copy constructor */
  ImmutableGraph(const ImmutableGraph& other) = default;

#ifndef _MSC_VER
  /*! \brief default move constructor */
  ImmutableGraph(ImmutableGraph&& other) = default;
#else
  /*!
   * \brief hand-written move constructor used when building with MSVC
   *
   * Moves every member, including the shared memory name, so that the result
   * matches the defaulted move constructor used by other compilers. The
   * moved-from graph is left with null structure pointers.
   */
  ImmutableGraph(ImmutableGraph&& other) noexcept
    : in_csr_(std::move(other.in_csr_)),
      out_csr_(std::move(other.out_csr_)),
      coo_(std::move(other.coo_)),
      shared_mem_name_(std::move(other.shared_mem_name_)) {}
#endif  // _MSC_VER

  /*! \brief default copy assignment operator */
  ImmutableGraph& operator=(const ImmutableGraph& other) = default;

  /*! \brief default destructor */
  ~ImmutableGraph() = default;

567
  void AddVertices(uint64_t num_vertices) override {
568
569
570
    LOG(FATAL) << "AddVertices isn't supported in ImmutableGraph";
  }

571
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
572
573
574
    LOG(FATAL) << "AddEdge isn't supported in ImmutableGraph";
  }

575
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
576
577
578
    LOG(FATAL) << "AddEdges isn't supported in ImmutableGraph";
  }

579
  void Clear() override {
580
581
582
    LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
  }

583
584
585
586
587
588
589
590
  DLContext Context() const override {
    return AnyGraph()->Context();
  }

  uint8_t NumBits() const override {
    return AnyGraph()->NumBits();
  }

591
592
593
594
  /*!
   * \brief Check whether the graph is a multigraph.
   * \note The query is forwarded to whichever internal structure is available.
   * \return whether the graph is a multigraph
   */
595
596
  bool IsMultigraph() const override {
    return AnyGraph()->IsMultigraph();
597
598
599
600
601
  }

  /*!
   * \return whether the graph is read-only
   */
602
  bool IsReadonly() const override {
603
604
605
606
    return true;
  }

  /*! \return the number of vertices in the graph.*/
607
608
  uint64_t NumVertices() const override {
    return AnyGraph()->NumVertices();
609
610
611
  }

  /*! \return the number of edges in the graph.*/
612
613
  uint64_t NumEdges() const override {
    return AnyGraph()->NumEdges();
614
615
616
  }

  /*! \return true if the given vertex is in the graph.*/
617
  bool HasVertex(dgl_id_t vid) const override {
618
619
620
    return vid < NumVertices();
  }

621
622
  BoolArray HasVertices(IdArray vids) const override;

623
  /*! \return true if the given edge is in the graph.*/
624
625
626
627
628
629
630
  /*!
   * \brief Check whether an edge from src to dst exists.
   * \note When the in-CSR is available it is queried with the endpoints
   *       swapped; otherwise the out-CSR is used.
   * \param src The source vertex.
   * \param dst The destination vertex.
   * \return true if the given edge is in the graph.
   */
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    return in_csr_ ? in_csr_->HasEdgeBetween(dst, src)
                   : GetOutCSR()->HasEdgeBetween(src, dst);
  }
631

632
633
634
635
636
637
638
639
  /*!
   * \brief Batched version of HasEdgeBetween.
   * \note When the in-CSR is available it is queried with the endpoint arrays
   *       swapped; otherwise the out-CSR is used.
   * \param src The source vertex ids.
   * \param dst The destination vertex ids.
   * \return the array returned by the underlying CSR query.
   */
  BoolArray HasEdgesBetween(IdArray src, IdArray dst) const override {
    return in_csr_ ? in_csr_->HasEdgesBetween(dst, src)
                   : GetOutCSR()->HasEdgesBetween(src, dst);
  }

640
641
642
643
644
645
  /*!
   * \brief Find the predecessors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the predecessor id array.
   */
646
647
648
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetInCSR()->Successors(vid, radius);
  }
649
650
651
652
653
654
655

  /*!
   * \brief Find the successors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the successor id array.
   */
656
657
658
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetOutCSR()->Successors(vid, radius);
  }
659
660
661
662
663
664
665
666
667

  /*!
   * \brief Get all edge ids between the two given endpoints
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   * \param src The source vertex.
   * \param dst The destination vertex.
   * \return the edge id array.
   */
668
669
670
671
672
673
674
  /*!
   * \brief Look up the edge ids between src and dst.
   * \note When the in-CSR is available it is queried with the endpoints
   *       swapped; otherwise the out-CSR is used.
   */
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    return in_csr_ ? in_csr_->EdgeId(dst, src)
                   : GetOutCSR()->EdgeId(src, dst);
  }
675
676
677
678
679
680
681
682
683
684

  /*!
   * \brief Get all edge ids between the given endpoint pairs.
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   *       If duplicate pairs exist, the returned edge IDs will also duplicate.
   *       The order of returned edge IDs will follow the order of src-dst pairs
   *       first, and ties are broken by the order of edge ID.
   * \return EdgeArray containing all edges between all pairs.
   */
685
686
  /*!
   * \brief Look up the edge ids between the given endpoint pairs.
   * \param src The source vertex ids.
   * \param dst The destination vertex ids.
   * \return EdgeArray with src, dst and id fields in the caller's orientation.
   */
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    if (!in_csr_) {
      return GetOutCSR()->EdgeIds(src, dst);
    }
    // The in-CSR is the reverse CSR: query it with swapped endpoints, then
    // swap the endpoint arrays of the result back.
    const EdgeArray reversed = in_csr_->EdgeIds(dst, src);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }
693
694
695
696
697
698

  /*!
   * \brief Find the edge ID and return the pair of endpoints
   * \param eid The edge ID
   * \return a pair whose first element is the source and the second the destination.
   */
699
700
701
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    return GetCOO()->FindEdge(eid);
  }
702
703
704
705
706
707

  /*!
   * \brief Find the edge IDs and return their source and target node IDs.
   * \param eids The edge ID array.
   * \return EdgeArray containing all edges with id in eid.  The order is preserved.
   */
708
709
710
  EdgeArray FindEdges(IdArray eids) const override {
    return GetCOO()->FindEdges(eids);
  }
711
712
713
714
715
716
717

  /*!
   * \brief Get the in edges of the vertex.
   * \note The returned dst id array is filled with vid.
   * \param vid The vertex id.
   * \return the edges
   */
718
719
720
  /*!
   * \brief Get the in edges of the vertex.
   * \note Implemented as an out-edge query on the in-CSR with the endpoint
   *       arrays of the result swapped.
   */
  EdgeArray InEdges(dgl_id_t vid) const override {
    const EdgeArray reversed = GetInCSR()->OutEdges(vid);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }

  /*!
   * \brief Get the in edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
728
729
730
  /*!
   * \brief Get the in edges of the vertices.
   * \note Implemented as an out-edge query on the in-CSR with the endpoint
   *       arrays of the result swapped.
   */
  EdgeArray InEdges(IdArray vids) const override {
    const EdgeArray reversed = GetInCSR()->OutEdges(vids);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }

  /*!
   * \brief Get the out edges of the vertex.
   * \note The returned src id array is filled with vid.
   * \param vid The vertex id.
   * \return the id arrays of the two endpoints of the edges.
   */
739
740
  EdgeArray OutEdges(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdges(vid);
741
742
743
744
745
746
747
  }

  /*!
   * \brief Get the out edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
748
749
  EdgeArray OutEdges(IdArray vids) const override {
    return GetOutCSR()->OutEdges(vids);
750
751
752
753
754
755
756
757
758
  }

  /*!
   * \brief Get all the edges in the graph.
   * \note Depending on \a order, the returned edge list is either sorted by
   *       src and dst ids or kept in edge id order.
   * \param order The requested ordering of the returned edge list (defaults to
   *        the empty string; TODO confirm the accepted values against the definition).
   * \return the id arrays of the two endpoints of the edges.
   */
759
  EdgeArray Edges(const std::string &order = "") const override;
760
761
762
763
764
765

  /*!
   * \brief Get the in degree of the given vertex.
   * \param vid The vertex id.
   * \return the in degree
   */
766
767
  uint64_t InDegree(dgl_id_t vid) const override {
    return GetInCSR()->OutDegree(vid);
768
769
770
771
772
773
774
  }

  /*!
   * \brief Get the in degrees of the given vertices.
   * \param vid The vertex id array.
   * \return the in degree array
   */
775
776
  DegreeArray InDegrees(IdArray vids) const override {
    return GetInCSR()->OutDegrees(vids);
777
778
779
780
781
782
783
  }

  /*!
   * \brief Get the out degree of the given vertex.
   * \param vid The vertex id.
   * \return the out degree
   */
784
785
  uint64_t OutDegree(dgl_id_t vid) const override {
    return GetOutCSR()->OutDegree(vid);
786
787
788
789
790
791
792
  }

  /*!
   * \brief Get the out degrees of the given vertices.
   * \param vid The vertex id array.
   * \return the out degree array
   */
793
794
  DegreeArray OutDegrees(IdArray vids) const override {
    return GetOutCSR()->OutDegrees(vids);
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
  }

  /*!
   * \brief Construct the induced subgraph of the given vertices.
   *
   * The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
   * selecting all of the edges from the original graph that connect two vertices in V'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the vertices preserve the order of the given id array, while the local index
   * of the edges preserve the index order in the original graph. Vertices not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param vids The vertices in the subgraph.
   * \return the induced subgraph
   */
813
  Subgraph VertexSubgraph(IdArray vids) const override;
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830

  /*!
   * \brief Construct the induced edge subgraph of the given edges.
   *
   * The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
   * selecting all of the nodes from the original graph that are endpoints in E'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the edges preserve the order of the given id array, while the local index
   * of the vertices preserve the index order in the original graph. Edges not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param eids The edges in the subgraph.
   * \return the induced edge subgraph
   */
831
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
832
833
834
835
836
837

  /*!
   * \brief Return the successor vector
   * \param vid The vertex id.
   * \return the successor vector
   */
838
839
  DGLIdIters SuccVec(dgl_id_t vid) const override {
    return GetOutCSR()->SuccVec(vid);
840
841
842
843
844
845
846
  }

  /*!
   * \brief Return the out edge id vector
   * \param vid The vertex id.
   * \return the out edge id vector
   */
847
848
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdgeVec(vid);
849
850
851
852
853
854
855
  }

  /*!
   * \brief Return the predecessor vector
   * \param vid The vertex id.
   * \return the predecessor vector
   */
856
857
  DGLIdIters PredVec(dgl_id_t vid) const override {
    return GetInCSR()->SuccVec(vid);
858
859
860
861
862
863
864
  }

  /*!
   * \brief Return the in edge id vector
   * \param vid The vertex id.
   * \return the in edge id vector
   */
865
866
  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    return GetInCSR()->OutEdgeVec(vid);
867
868
869
870
871
872
873
874
875
876
877
  }

  /*!
   * \brief Get the adjacency matrix of the graph.
   *
   * By default, a row of returned adjacency matrix represents the destination
   * of an edge and the column represents the source.
   * \param transpose A flag to transpose the returned adjacency matrix.
   * \param fmt the format of the returned adjacency matrix.
   * \return a vector of three IdArray.
   */
878
  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
879

880
  /*! \brief Return in csr. If not exist, transpose the other one. */
881
  CSRPtr GetInCSR() const;
882
883

  /*! \brief Return out csr. If not exist, transpose the other one. */
884
  CSRPtr GetOutCSR() const;
885

886
  /*! \brief Return coo. If not exist, create from csr. */
887
  COOPtr GetCOO() const;
888

889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
  /*! \brief Create an immutable graph from CSR. */
  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);

  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids,
      bool multigraph, const std::string &edge_dir);

  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids,
      const std::string &edge_dir, const std::string &shared_mem_name);

  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids,
      bool multigraph, const std::string &edge_dir,
      const std::string &shared_mem_name);

  static ImmutableGraphPtr CreateFromCSR(
      const std::string &shared_mem_name, size_t num_vertices,
      size_t num_edges, bool multigraph,
      const std::string &edge_dir);

  /*! \brief Create an immutable graph from COO. */
  static ImmutableGraphPtr CreateFromCOO(
      int64_t num_vertices, IdArray src, IdArray dst);

  static ImmutableGraphPtr CreateFromCOO(
      int64_t num_vertices, IdArray src, IdArray dst, bool multigraph);

918
919
920
921
922
923
924
925
926
  /*!
   * \brief Convert the given graph to an immutable graph.
   *
   * If the graph is already an immutable graph. The result graph will share
   * the storage with the given one.
   *
   * \param graph The input graph.
   * \return an immutable graph object.
   */
927
  static ImmutableGraphPtr ToImmutable(GraphPtr graph);
928
929
930
931
932
933

  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
934
  static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx);
935

936
937
938
939
940
941
  /*!
   * \brief Copy data to shared memory.
   * \param edge_dir the graph of the specific edge direction to be copied.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
942
943
  static ImmutableGraphPtr CopyToSharedMem(
      ImmutableGraphPtr g, const std::string &edge_dir, const std::string &name);
944

945
946
947
948
949
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
950
  static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);
951

952
953
954
955
956
957
958
959
  /*!
   * \brief Return a new graph with all the edges reversed.
   *
   * The returned graph preserves the vertex and edge index in the original graph.
   *
   * \return the reversed graph
   */
  ImmutableGraphPtr Reverse() const;
960

961
962
963
964
965
966
  /*! \brief Load the ImmutableGraph from stream, using CSRMatrix */
  bool Load(dmlc::Stream* fs);

  /*! \brief Save the ImmutableGraph to stream, using CSRMatrix */
  void Save(dmlc::Stream* fs) const;

Da Zheng's avatar
Da Zheng committed
967
968
969
970
971
  /*!
   * \brief Sort both the in-CSR and the out-CSR of this graph.
   * \note Both structures are obtained through GetInCSR()/GetOutCSR(), so a
   *       missing one is created before it is sorted.
   */
  void SortCSR() {
    GetInCSR()->SortCSR();
    GetOutCSR()->SortCSR();
  }

972
973
974
 protected:
  /*! \brief internal default constructor */
  ImmutableGraph() {}
975

976
977
978
979
980
  /*!
   * \brief internal constructor for all the members
   * \note At least one of the three structures must be non-null.
   */
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
    : in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
    CHECK(AnyGraph()) << "At least one graph structure should exist.";
  }
981

982
983
984
985
986
987
  /*!
   * \brief internal constructor for a CSR-backed graph with a shared memory name
   * \param in_csr The in-CSR (may be null if out_csr is given).
   * \param out_csr The out-CSR (may be null if in_csr is given).
   * \param shared_mem_name The name of the shared memory for this graph.
   */
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string& shared_mem_name)
    : in_csr_(in_csr), out_csr_(out_csr), shared_mem_name_(shared_mem_name) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
  }

988
989
990
991
992
993
994
995
996
997
  /*!
   * \brief return pointer to any available graph structure
   * \note Preference order is in-CSR, then out-CSR, then COO; the result is
   *       null only when none of them is set.
   */
  GraphPtr AnyGraph() const {
    if (in_csr_)
      return in_csr_;
    if (out_csr_)
      return out_csr_;
    return coo_;
  }
998

999
1000
1001
1002
1003
1004
  // Store the in csr (i.e, the reverse csr)
  CSRPtr in_csr_;
  // Store the out csr (i.e, the normal csr)
  CSRPtr out_csr_;
  // Store the edge list indexed by edge id (COO)
  COOPtr coo_;
1005
1006
1007
1008

  // The name of shared memory for this graph.
  // If it's empty, the graph isn't stored in shared memory.
  std::string shared_mem_name_;
1009
1010
};

1011
1012
1013
1014
1015
1016
// inline implementations

/*!
 * \brief Create a csr graph by copying its structure from data iterators.
 * \param num_vertices The number of vertices; num_vertices + 1 values are read
 *        from indptr_begin.
 * \param num_edges The number of edges; that many values are read from each of
 *        indices_begin and edge_ids_begin.
 * \param indptr_begin Iterator to the first indptr value.
 * \param indices_begin Iterator to the first indices value.
 * \param edge_ids_begin Iterator to the first edge id value.
 * \param is_multigraph Whether the graph is a multigraph.
 */
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
    IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
    bool is_multigraph): is_multigraph_(is_multigraph) {
  // TODO(minjie): this should be changed to a device-agnostic implementation
  //   in the future
  adj_.num_rows = num_vertices;
  adj_.num_cols = num_vertices;
  adj_.indptr = aten::NewIdArray(num_vertices + 1);
  adj_.indices = aten::NewIdArray(num_edges);
  adj_.data = aten::NewIdArray(num_edges);

  // Fill the freshly allocated arrays through their raw data pointers.
  dgl_id_t* const indptr_out = static_cast<dgl_id_t*>(adj_.indptr->data);
  dgl_id_t* const indices_out = static_cast<dgl_id_t*>(adj_.indices->data);
  dgl_id_t* const edge_ids_out = static_cast<dgl_id_t*>(adj_.data->data);

  for (int64_t v = 0; v < num_vertices + 1; ++v, ++indptr_begin)
    indptr_out[v] = *indptr_begin;
  for (int64_t e = 0; e < num_edges; ++e, ++indices_begin, ++edge_ids_begin) {
    indices_out[e] = *indices_begin;
    edge_ids_out[e] = *edge_ids_begin;
  }
}

1035
1036
}  // namespace dgl

1037
1038
1039
1040
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
}  // namespace dmlc

1041
#endif  // DGL_IMMUTABLE_GRAPH_H_