/*!
 *  Copyright (c) 2018 by Contributors
 * \file dgl/immutable_graph.h
 * \brief DGL immutable graph index class.
 */
#ifndef DGL_IMMUTABLE_GRAPH_H_
#define DGL_IMMUTABLE_GRAPH_H_

#include <vector>
#include <string>
#include <cstdint>
#include <utility>
#include <tuple>
14
#include <algorithm>
15
#include <memory>
16
17
#include "runtime/ndarray.h"
#include "graph_interface.h"
18
#include "lazy.h"
19
#include "base_heterograph.h"
20
21
22

namespace dgl {

23
24
25
26
27
// Forward declarations of the two sparse storage formats, together with the
// shared-ownership handles used to pass them around.
class CSR;
class COO;
using CSRPtr = std::shared_ptr<CSR>;
using COOPtr = std::shared_ptr<COO>;

28
29
30
// Forward declaration of the immutable graph and its shared-ownership handle.
class ImmutableGraph;
using ImmutableGraphPtr = std::shared_ptr<ImmutableGraph>;

31
/*!
32
 * \brief Graph class stored using CSR structure.
33
 */
34
class CSR : public GraphInterface {
35
 public:
36
  // Create a csr graph that has the given number of verts and edges.
37
  CSR(int64_t num_vertices, int64_t num_edges);
38
39
40
  // Create a csr graph whose memory is stored in the shared memory
  //   that has the given number of verts and edges.
  CSR(const std::string &shared_mem_name,
41
      int64_t num_vertices, int64_t num_edges);
42

43
44
  // Create a csr graph that shares the given indptr and indices.
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
45
46
47
48

  // Create a csr graph by data iterator
  template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
  CSR(int64_t num_vertices, int64_t num_edges,
49
      IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin);
50

51
52
53
54
55
56
57
58
  // Create a csr graph whose memory is stored in the shared memory
  //   and the structure is given by the indptr and indcies.
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
      const std::string &shared_mem_name);

  void AddVertices(uint64_t num_vertices) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
59

60
61
62
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
63

64
65
66
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
67

68
69
70
  void Clear() override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
71

72
  DGLContext Context() const override {
73
    return adj_.indptr->ctx;
74
75
76
  }

  uint8_t NumBits() const override {
77
    return adj_.indices->dtype.bits;
78
79
  }

80
  bool IsMultigraph() const override;
81

82
83
84
  bool IsReadonly() const override {
    return true;
  }
85

86
  uint64_t NumVertices() const override {
87
    return adj_.indptr->shape[0] - 1;
88
  }
89

90
  uint64_t NumEdges() const override {
91
92
93
94
95
96
    return adj_.indices->shape[0];
  }

  BoolArray HasVertices(IdArray vids) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return {};
97
  }
98

99
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
100

101
102
  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;

103
104
105
106
107
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "CSR graph does not support efficient predecessor query."
      << " Please use successors on the reverse CSR graph.";
    return {};
  }
108

109
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
110

111
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;
112

113
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override;
114

115
116
117
118
119
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdge."
      << " Please use COO graph.";
    return {};
  }
120

121
122
123
124
125
  EdgeArray FindEdges(IdArray eids) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdges."
      << " Please use COO graph.";
    return {};
  }
126

127
128
129
130
131
  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
132

133
134
135
136
137
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
138

139
  EdgeArray OutEdges(dgl_id_t vid) const override;
140

141
  EdgeArray OutEdges(IdArray vids) const override;
142

143
  EdgeArray Edges(const std::string &order = "") const override;
144

145
146
147
148
149
  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return 0;
  }
150

151
152
153
154
155
  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return {};
  }
Da Zheng's avatar
Da Zheng committed
156

157
  uint64_t OutDegree(dgl_id_t vid) const override {
158
    return aten::CSRGetRowNNZ(adj_, vid);
159
160
161
162
163
164
  }

  DegreeArray OutDegrees(IdArray vids) const override;

  Subgraph VertexSubgraph(IdArray vids) const override;

165
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override {
166
167
168
169
170
    LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
      << " Please use COO graph instead.";
    return {};
  }

171
  DGLIdIters SuccVec(dgl_id_t vid) const override;
172

173
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override;
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient PredVec."
      << " Please use SuccVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient InEdgeVec."
      << " Please use OutEdgeVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
189
    return {adj_.indptr, adj_.indices, adj_.data};
190
191
  }

192
193
194
195
196
  /*! \brief Indicate whether this uses shared memory. */
  bool IsSharedMem() const {
    return !shared_mem_name_.empty();
  }

197
198
199
200
201
202
203
204
205
206
207
  /*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
  CSRPtr Transpose() const;

  /*! \brief Convert this CSR to COO */
  COOPtr ToCOO() const;

  /*!
   * \return the csr matrix that represents this graph.
   * \note The csr matrix shares the storage with this graph.
   *       The data field of the CSR matrix stores the edge ids.
   */
208
209
  aten::CSRMatrix ToCSRMatrix() const {
    return adj_;
210
211
  }

212
213
214
215
216
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
217
  CSR CopyTo(const DGLContext& ctx) const;
218

219
220
221
222
223
224
225
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
  CSR CopyToSharedMem(const std::string &name) const;

226
227
228
229
230
231
232
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  CSR AsNumBits(uint8_t bits) const;

233
234
  // member getters

235
  IdArray indptr() const { return adj_.indptr; }
236

237
  IdArray indices() const { return adj_.indices; }
238

239
  IdArray edge_ids() const { return adj_.data; }
240

241
242
243
244
245
246
  /*! \return Load CSR from stream */
  bool Load(dmlc::Stream *fs);

  /*! \return Save CSR to stream */
  void Save(dmlc::Stream* fs) const;

247
  void SortCSR() override {
Da Zheng's avatar
Da Zheng committed
248
249
    if (adj_.sorted)
      return;
250
    aten::CSRSort_(&adj_);
Da Zheng's avatar
Da Zheng committed
251
252
  }

253
 private:
254
  friend class Serializer;
255

256
257
  /*! \brief private default constructor */
  CSR() {adj_.sorted = false;}
258
259
260
  // The internal CSR adjacency matrix.
  // The data field stores edge ids.
  aten::CSRMatrix adj_;
261

262
263
264
  // The name of the shared memory to store data.
  // If it's empty, data isn't stored in shared memory.
  std::string shared_mem_name_;
265
266
267
268
269
};

class COO : public GraphInterface {
 public:
  // Create a coo graph that shares the given src and dst
270
271
  COO(int64_t num_vertices, IdArray src, IdArray dst,
      bool row_sorted = false, bool col_sorted = false);
272
273
274
275

  // TODO(da): add constructor for creating COO from shared memory

  void AddVertices(uint64_t num_vertices) override {
276
    LOG(FATAL) << "COO graph does not allow mutation.";
277
278
279
  }

  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
280
    LOG(FATAL) << "COO graph does not allow mutation.";
281
282
283
  }

  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
284
    LOG(FATAL) << "COO graph does not allow mutation.";
285
286
287
  }

  void Clear() override {
288
    LOG(FATAL) << "COO graph does not allow mutation.";
289
290
  }

291
  DGLContext Context() const override {
292
    return adj_.row->ctx;
293
294
295
  }

  uint8_t NumBits() const override {
296
    return adj_.row->dtype.bits;
297
298
  }

299
300
301
302
303
304
305
  bool IsMultigraph() const override;

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices() const override {
306
    return adj_.num_rows;
307
308
309
  }

  uint64_t NumEdges() const override {
310
    return adj_.row->shape[0];
311
312
313
  }

  bool HasVertex(dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
314
    return vid < NumVertices();
315
316
  }

317
318
319
320
321
  BoolArray HasVertices(IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

322
323
324
325
326
327
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return false;
  }

328
329
330
331
332
333
  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Predecessors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Successors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient EdgeId."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    LOG(FATAL) << "COO graph does not support efficient EdgeId."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

358
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override;
359
360
361
362
363
364
365
366

  EdgeArray FindEdges(IdArray eids) const override;

  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray OutEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray OutEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray Edges(const std::string &order = "") const override;

  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  uint64_t OutDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  DegreeArray OutDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  Subgraph VertexSubgraph(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient VertexSubgraph."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

418
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446

  DGLIdIters SuccVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient SuccVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient PredVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
447
      return {aten::HStack(adj_.col, adj_.row)};
448
    } else {
449
      return {aten::HStack(adj_.row, adj_.col)};
450
    }
451
  }
452

453
454
  /*! \brief Return the transpose of this COO */
  COOPtr Transpose() const {
455
    return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));
456
  }
457

458
459
  /*! \brief Convert this COO to CSR */
  CSRPtr ToCSR() const;
460

461
462
463
464
465
  /*!
   * \brief Get the coo matrix that represents this graph.
   * \note The coo matrix shares the storage with this graph.
   *       The data field of the coo matrix is none.
   */
466
467
  aten::COOMatrix ToCOOMatrix() const {
    return adj_;
468
469
  }

470
471
472
473
474
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
475
  COO CopyTo(const DGLContext& ctx) const;
476

477
478
479
480
481
482
483
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
  COO CopyToSharedMem(const std::string &name) const;

484
485
486
487
488
489
490
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  COO AsNumBits(uint8_t bits) const;

491
492
493
494
495
  /*! \brief Indicate whether this uses shared memory. */
  bool IsSharedMem() const {
    return false;
  }

496
497
  // member getters

498
  IdArray src() const { return adj_.row; }
499

500
  IdArray dst() const { return adj_.col; }
501
502
503
504
505

 private:
  /* !\brief private default constructor */
  COO() {}

506
507
508
  // The internal COO adjacency matrix.
  // The data field is empty
  aten::COOMatrix adj_;
509
510
511
512
513
514
515
516
517
518
519
};

/*!
 * \brief DGL immutable graph index class.
 *
 * DGL's graph is directed. Vertices are integers enumerated from zero.
 */
class ImmutableGraph: public GraphInterface {
 public:
  /*! \brief Construct an immutable graph from the COO format. */
  explicit ImmutableGraph(COOPtr coo): coo_(coo) { }
520

521
522
523
524
525
526
527
528
529
530
531
532
533
  /*!
   * \brief Construct an immutable graph from the CSR format.
   *
   * For a single graph, we need two CSRs, one stores the in-edges of vertices and
   * the other stores the out-edges of vertices. These two CSRs stores the same edges.
   * The reason we need both is that some operators are faster on in-edge CSR and
   * the other operators are faster on out-edge CSR.
   *
   * However, not both CSRs are required. Technically, one CSR contains all information.
   * Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only
   * construct one of the CSRs that runs fast for some operations we expect and construct
   * the other CSR on demand.
   */
534
535
536
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
    : in_csr_(in_csr), out_csr_(out_csr) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
537
538
  }

539
540
  /*! \brief Construct an immutable graph from one CSR. */
  explicit ImmutableGraph(CSRPtr csr): out_csr_(csr) { }
541
542
543
544
545
546
547
548
549
550
551

  /*! \brief default copy constructor */
  ImmutableGraph(const ImmutableGraph& other) = default;

#ifndef _MSC_VER
  /*! \brief default move constructor */
  ImmutableGraph(ImmutableGraph&& other) = default;
#else
  ImmutableGraph(ImmutableGraph&& other) {
    this->in_csr_ = other.in_csr_;
    this->out_csr_ = other.out_csr_;
552
    this->coo_ = other.coo_;
553
554
    other.in_csr_ = nullptr;
    other.out_csr_ = nullptr;
555
    other.coo_ = nullptr;
556
557
558
559
560
561
562
563
564
  }
#endif  // _MSC_VER

  /*! \brief default assign constructor */
  ImmutableGraph& operator=(const ImmutableGraph& other) = default;

  /*! \brief default destructor */
  ~ImmutableGraph() = default;

565
  void AddVertices(uint64_t num_vertices) override {
566
567
568
    LOG(FATAL) << "AddVertices isn't supported in ImmutableGraph";
  }

569
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
570
571
572
    LOG(FATAL) << "AddEdge isn't supported in ImmutableGraph";
  }

573
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
574
575
576
    LOG(FATAL) << "AddEdges isn't supported in ImmutableGraph";
  }

577
  void Clear() override {
578
579
580
    LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
  }

581
  DGLContext Context() const override {
582
583
584
585
586
587
588
    return AnyGraph()->Context();
  }

  uint8_t NumBits() const override {
    return AnyGraph()->NumBits();
  }

589
590
591
592
  /*!
   * \note not const since we have caches
   * \return whether the graph is a multigraph
   */
593
594
  bool IsMultigraph() const override {
    return AnyGraph()->IsMultigraph();
595
596
597
598
599
  }

  /*!
   * \return whether the graph is read-only
   */
600
  bool IsReadonly() const override {
601
602
603
604
    return true;
  }

  /*! \return the number of vertices in the graph.*/
605
606
  uint64_t NumVertices() const override {
    return AnyGraph()->NumVertices();
607
608
609
  }

  /*! \return the number of edges in the graph.*/
610
611
  uint64_t NumEdges() const override {
    return AnyGraph()->NumEdges();
612
613
614
  }

  /*! \return true if the given vertex is in the graph.*/
615
  bool HasVertex(dgl_id_t vid) const override {
616
617
618
    return vid < NumVertices();
  }

619
620
  BoolArray HasVertices(IdArray vids) const override;

621
  /*! \return true if the given edge is in the graph.*/
622
623
624
625
626
627
628
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    if (in_csr_) {
      return in_csr_->HasEdgeBetween(dst, src);
    } else {
      return GetOutCSR()->HasEdgeBetween(src, dst);
    }
  }
629

630
631
632
633
634
635
636
637
  BoolArray HasEdgesBetween(IdArray src, IdArray dst) const override {
    if (in_csr_) {
      return in_csr_->HasEdgesBetween(dst, src);
    } else {
      return GetOutCSR()->HasEdgesBetween(src, dst);
    }
  }

638
639
640
641
642
643
  /*!
   * \brief Find the predecessors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the predecessor id array.
   */
644
645
646
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetInCSR()->Successors(vid, radius);
  }
647
648
649
650
651
652
653

  /*!
   * \brief Find the successors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the successor id array.
   */
654
655
656
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetOutCSR()->Successors(vid, radius);
  }
657
658
659
660
661
662
663
664
665

  /*!
   * \brief Get all edge ids between the two given endpoints
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   * \param src The source vertex.
   * \param dst The destination vertex.
   * \return the edge id array.
   */
666
667
668
669
670
671
672
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    if (in_csr_) {
      return in_csr_->EdgeId(dst, src);
    } else {
      return GetOutCSR()->EdgeId(src, dst);
    }
  }
673
674
675
676
677
678
679
680
681
682

  /*!
   * \brief Get all edge ids between the given endpoint pairs.
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   *       If duplicate pairs exist, the returned edge IDs will also duplicate.
   *       The order of returned edge IDs will follow the order of src-dst pairs
   *       first, and ties are broken by the order of edge ID.
   * \return EdgeArray containing all edges between all pairs.
   */
683
684
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    if (in_csr_) {
685
686
      EdgeArray edges = in_csr_->EdgeIds(dst, src);
      return EdgeArray{edges.dst, edges.src, edges.id};
687
688
689
690
    } else {
      return GetOutCSR()->EdgeIds(src, dst);
    }
  }
691
692
693
694
695
696

  /*!
   * \brief Find the edge ID and return the pair of endpoints
   * \param eid The edge ID
   * \return a pair whose first element is the source and the second the destination.
   */
697
698
699
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    return GetCOO()->FindEdge(eid);
  }
700
701
702
703
704
705

  /*!
   * \brief Find the edge IDs and return their source and target node IDs.
   * \param eids The edge ID array.
   * \return EdgeArray containing all edges with id in eid.  The order is preserved.
   */
706
707
708
  EdgeArray FindEdges(IdArray eids) const override {
    return GetCOO()->FindEdges(eids);
  }
709
710
711
712
713
714
715

  /*!
   * \brief Get the in edges of the vertex.
   * \note The returned dst id array is filled with vid.
   * \param vid The vertex id.
   * \return the edges
   */
716
717
718
  EdgeArray InEdges(dgl_id_t vid) const override {
    const EdgeArray& ret = GetInCSR()->OutEdges(vid);
    return {ret.dst, ret.src, ret.id};
719
720
721
722
723
724
725
  }

  /*!
   * \brief Get the in edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
726
727
728
  EdgeArray InEdges(IdArray vids) const override {
    const EdgeArray& ret = GetInCSR()->OutEdges(vids);
    return {ret.dst, ret.src, ret.id};
729
730
731
732
733
734
735
736
  }

  /*!
   * \brief Get the out edges of the vertex.
   * \note The returned src id array is filled with vid.
   * \param vid The vertex id.
   * \return the id arrays of the two endpoints of the edges.
   */
737
738
  EdgeArray OutEdges(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdges(vid);
739
740
741
742
743
744
745
  }

  /*!
   * \brief Get the out edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
746
747
  EdgeArray OutEdges(IdArray vids) const override {
    return GetOutCSR()->OutEdges(vids);
748
749
750
751
752
753
754
755
756
  }

  /*!
   * \brief Get all the edges in the graph.
   * \note If sorted is true, the returned edges list is sorted by their src and
   *       dst ids. Otherwise, they are in their edge id order.
   * \param sorted Whether the returned edge list is sorted by their src and dst ids
   * \return the id arrays of the two endpoints of the edges.
   */
757
  EdgeArray Edges(const std::string &order = "") const override;
758
759
760
761
762
763

  /*!
   * \brief Get the in degree of the given vertex.
   * \param vid The vertex id.
   * \return the in degree
   */
764
765
  uint64_t InDegree(dgl_id_t vid) const override {
    return GetInCSR()->OutDegree(vid);
766
767
768
769
770
771
772
  }

  /*!
   * \brief Get the in degrees of the given vertices.
   * \param vid The vertex id array.
   * \return the in degree array
   */
773
774
  DegreeArray InDegrees(IdArray vids) const override {
    return GetInCSR()->OutDegrees(vids);
775
776
777
778
779
780
781
  }

  /*!
   * \brief Get the out degree of the given vertex.
   * \param vid The vertex id.
   * \return the out degree
   */
782
783
  uint64_t OutDegree(dgl_id_t vid) const override {
    return GetOutCSR()->OutDegree(vid);
784
785
786
787
788
789
790
  }

  /*!
   * \brief Get the out degrees of the given vertices.
   * \param vid The vertex id array.
   * \return the out degree array
   */
791
792
  DegreeArray OutDegrees(IdArray vids) const override {
    return GetOutCSR()->OutDegrees(vids);
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
  }

  /*!
   * \brief Construct the induced subgraph of the given vertices.
   *
   * The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
   * selecting all of the edges from the original graph that connect two vertices in V'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the vertices preserve the order of the given id array, while the local index
   * of the edges preserve the index order in the original graph. Vertices not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param vids The vertices in the subgraph.
   * \return the induced subgraph
   */
811
  Subgraph VertexSubgraph(IdArray vids) const override;
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828

  /*!
   * \brief Construct the induced edge subgraph of the given edges.
   *
   * The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
   * selecting all of the nodes from the original graph that are endpoints in E'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the edges preserve the order of the given id array, while the local index
   * of the vertices preserve the index order in the original graph. Edges not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param eids The edges in the subgraph.
   * \return the induced edge subgraph
   */
829
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
830
831
832
833
834
835

  /*!
   * \brief Return the successor vector
   * \param vid The vertex id.
   * \return the successor vector
   */
836
837
  DGLIdIters SuccVec(dgl_id_t vid) const override {
    return GetOutCSR()->SuccVec(vid);
838
839
840
841
842
843
844
  }

  /*!
   * \brief Return the out edge id vector
   * \param vid The vertex id.
   * \return the out edge id vector
   */
845
846
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdgeVec(vid);
847
848
849
850
851
852
853
  }

  /*!
   * \brief Return the predecessor vector
   * \param vid The vertex id.
   * \return the predecessor vector
   */
854
855
  DGLIdIters PredVec(dgl_id_t vid) const override {
    return GetInCSR()->SuccVec(vid);
856
857
858
859
860
861
862
  }

  /*!
   * \brief Return the in edge id vector
   * \param vid The vertex id.
   * \return the in edge id vector
   */
863
864
  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    return GetInCSR()->OutEdgeVec(vid);
865
866
867
868
869
870
871
872
873
874
875
  }

  /*!
   * \brief Get the adjacency matrix of the graph.
   *
   * By default, a row of returned adjacency matrix represents the destination
   * of an edge and the column represents the source.
   * \param transpose A flag to transpose the returned adjacency matrix.
   * \param fmt the format of the returned adjacency matrix.
   * \return a vector of three IdArray.
   */
876
  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
877

878
  /* !\brief Return in csr. If not exist, transpose the other one.*/
879
  CSRPtr GetInCSR() const;
880
881

  /* !\brief Return out csr. If not exist, transpose the other one.*/
882
  CSRPtr GetOutCSR() const;
883

884
  /* !\brief Return coo. If not exist, create from csr.*/
885
  COOPtr GetCOO() const;
886

887
888
889
890
  /*! \brief Create an immutable graph from CSR. */
  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);

891
  static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);
892
893
894

  /*! \brief Create an immutable graph from COO. */
  static ImmutableGraphPtr CreateFromCOO(
895
896
      int64_t num_vertices, IdArray src, IdArray dst,
      bool row_osrted = false, bool col_sorted = false);
897

898
899
900
901
902
903
904
905
906
  /*!
   * \brief Convert the given graph to an immutable graph.
   *
   * If the graph is already an immutable graph. The result graph will share
   * the storage with the given one.
   *
   * \param graph The input graph.
   * \return an immutable graph object.
   */
907
  static ImmutableGraphPtr ToImmutable(GraphPtr graph);
908
909
910
911
912
913

  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
914
  static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext& ctx);
915

916
917
918
919
920
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
921
  static ImmutableGraphPtr CopyToSharedMem(ImmutableGraphPtr g, const std::string &name);
922

923
924
925
926
927
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
928
  static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);
929

930
931
932
933
934
935
936
937
  /*!
   * \brief Return a new graph with all the edges reversed.
   *
   * The returned graph preserves the vertex and edge index in the original graph.
   *
   * \return the reversed graph
   */
  ImmutableGraphPtr Reverse() const;
938

939
940
  /*! \return Load ImmutableGraph from stream, using out csr */
  bool Load(dmlc::Stream *fs);
941

942
  /*! \return Save ImmutableGraph to stream, using out csr */
943
944
  void Save(dmlc::Stream* fs) const;

945
  void SortCSR() override {
Da Zheng's avatar
Da Zheng committed
946
947
948
949
    GetInCSR()->SortCSR();
    GetOutCSR()->SortCSR();
  }

950
951
952
953
954
955
956
957
  bool HasInCSR() const {
    return in_csr_ != NULL;
  }

  bool HasOutCSR() const {
    return out_csr_ != NULL;
  }

958
959
960
  /*! \brief Cast this graph to a heterograph */
  HeteroGraphPtr AsHeteroGraph() const;

961
 protected:
962
  friend class Serializer;
963
  friend class UnitGraph;
964

965
966
  /* !\brief internal default constructor */
  ImmutableGraph() {}
967

968
969
970
971
972
  /* !\brief internal constructor for all the members */
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
    : in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
    CHECK(AnyGraph()) << "At least one graph structure should exist.";
  }
973

974
975
976
977
978
979
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
    : in_csr_(in_csr), out_csr_(out_csr) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
    this->shared_mem_name_ = shared_mem_name;
  }

980
981
982
983
984
985
986
987
988
989
  /* !\brief return pointer to any available graph structure */
  GraphPtr AnyGraph() const {
    if (in_csr_) {
      return in_csr_;
    } else if (out_csr_) {
      return out_csr_;
    } else {
      return coo_;
    }
  }
990

991
992
993
994
995
996
  // Store the in csr (i.e, the reverse csr)
  CSRPtr in_csr_;
  // Store the out csr (i.e, the normal csr)
  CSRPtr out_csr_;
  // Store the edge list indexed by edge id (COO)
  COOPtr coo_;
997
998
999
1000

  // The name of shared memory for this graph.
  // If it's empty, the graph isn't stored in shared memory.
  std::string shared_mem_name_;
1001
1002
  // We serialize the metadata of the graph index here for shared memory.
  NDArray serialized_shared_meta_;
1003
1004
};

1005
1006
1007
1008
// inline implementations

/*!
 * \brief Build a CSR graph by copying its three arrays from iterators.
 *
 * Allocates fresh id arrays and fills them element by element, so the result
 * does not share storage with the input ranges.
 *
 * \param num_vertices Number of vertices; indptr receives num_vertices + 1 entries.
 * \param num_edges Number of edges; indices and edge ids receive num_edges entries each.
 * \param indptr_begin Start of the row-pointer sequence.
 * \param indices_begin Start of the column-index sequence.
 * \param edge_ids_begin Start of the edge-id sequence.
 */
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
    IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {
  // TODO(minjie): this should be changed to a device-agnostic implementation
  //   in the future
  adj_.num_rows = num_vertices;
  adj_.num_cols = num_vertices;
  // Nothing is known about the ordering of the input, so flag the matrix as
  // unsorted explicitly rather than leaving the flag unset for SortCSR().
  adj_.sorted = false;
  adj_.indptr = aten::NewIdArray(num_vertices + 1);
  adj_.indices = aten::NewIdArray(num_edges);
  adj_.data = aten::NewIdArray(num_edges);
  // NOTE(review): the casts assume aten::NewIdArray yields buffers whose
  // element type is dgl_id_t - confirm against its declaration.
  std::copy_n(indptr_begin, num_vertices + 1,
              static_cast<dgl_id_t*>(adj_.indptr->data));
  std::copy_n(indices_begin, num_edges,
              static_cast<dgl_id_t*>(adj_.indices->data));
  std::copy_n(edge_ids_begin, num_edges,
              static_cast<dgl_id_t*>(adj_.data->data));
}

1028
1029
}  // namespace dgl

1030
namespace dmlc {
1031
DMLC_DECLARE_TRAITS(has_saveload, dgl::CSR, true);
1032
1033
1034
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
}  // namespace dmlc

1035
#endif  // DGL_IMMUTABLE_GRAPH_H_