/*!
 *  Copyright (c) 2018 by Contributors
 * \file dgl/immutable_graph.h
 * \brief DGL immutable graph index class.
 */
#ifndef DGL_IMMUTABLE_GRAPH_H_
#define DGL_IMMUTABLE_GRAPH_H_

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "runtime/ndarray.h"
#include "graph_interface.h"
#include "lazy.h"
#include "base_heterograph.h"

namespace dgl {

// Forward declarations of the graph storage classes defined below, together
// with the shared-ownership pointer aliases used throughout this header.
class CSR;
class COO;
using CSRPtr = std::shared_ptr<CSR>;
using COOPtr = std::shared_ptr<COO>;

class ImmutableGraph;
using ImmutableGraphPtr = std::shared_ptr<ImmutableGraph>;

30
/*!
31
 * \brief Graph class stored using CSR structure.
32
 */
33
class CSR : public GraphInterface {
34
 public:
35
36
37
38
39
40
  // Create a csr graph that has the given number of verts and edges.
  CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph);
  // Create a csr graph whose memory is stored in the shared memory
  //   that has the given number of verts and edges.
  CSR(const std::string &shared_mem_name,
      int64_t num_vertices, int64_t num_edges, bool is_multigraph);
41

42
43
44
  // Create a csr graph that shares the given indptr and indices.
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph);
45
46
47
48
49
50
51

  // Create a csr graph by data iterator
  template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
  CSR(int64_t num_vertices, int64_t num_edges,
      IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
      bool is_multigraph);

52
53
54
55
56
57
58
59
60
61
  // Create a csr graph whose memory is stored in the shared memory
  //   and the structure is given by the indptr and indcies.
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
      const std::string &shared_mem_name);
  CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
      const std::string &shared_mem_name);

  void AddVertices(uint64_t num_vertices) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
62

63
64
65
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
66

67
68
69
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
70

71
72
73
  void Clear() override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }
74

75
  DLContext Context() const override {
76
    return adj_.indptr->ctx;
77
78
79
  }

  uint8_t NumBits() const override {
80
    return adj_.indices->dtype.bits;
81
82
  }

83
  bool IsMultigraph() const override;
84

85
86
87
  bool IsReadonly() const override {
    return true;
  }
88

89
  uint64_t NumVertices() const override {
90
    return adj_.indptr->shape[0] - 1;
91
  }
92

93
  uint64_t NumEdges() const override {
94
95
96
97
98
99
    return adj_.indices->shape[0];
  }

  BoolArray HasVertices(IdArray vids) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return {};
100
  }
101

102
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
103

104
105
  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;

106
107
108
109
110
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "CSR graph does not support efficient predecessor query."
      << " Please use successors on the reverse CSR graph.";
    return {};
  }
111

112
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
113

114
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override;
115

116
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override;
117

118
119
120
121
122
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdge."
      << " Please use COO graph.";
    return {};
  }
123

124
125
126
127
128
  EdgeArray FindEdges(IdArray eids) const override {
    LOG(FATAL) << "CSR graph does not support efficient FindEdges."
      << " Please use COO graph.";
    return {};
  }
129

130
131
132
133
134
  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
135

136
137
138
139
140
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient inedges query."
      << " Please use outedges on the reverse CSR graph.";
    return {};
  }
141

142
  EdgeArray OutEdges(dgl_id_t vid) const override;
143

144
  EdgeArray OutEdges(IdArray vids) const override;
145

146
  EdgeArray Edges(const std::string &order = "") const override;
147

148
149
150
151
152
  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return 0;
  }
153

154
155
156
157
158
  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "CSR graph does not support efficient indegree query."
      << " Please use outdegree on the reverse CSR graph.";
    return {};
  }
Da Zheng's avatar
Da Zheng committed
159

160
  uint64_t OutDegree(dgl_id_t vid) const override {
161
    return aten::CSRGetRowNNZ(adj_, vid);
162
163
164
165
166
167
  }

  DegreeArray OutDegrees(IdArray vids) const override;

  Subgraph VertexSubgraph(IdArray vids) const override;

168
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override {
169
170
171
172
173
    LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
      << " Please use COO graph instead.";
    return {};
  }

174
  DGLIdIters SuccVec(dgl_id_t vid) const override;
175

176
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override;
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191

  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient PredVec."
      << " Please use SuccVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "CSR graph does not support efficient InEdgeVec."
      << " Please use OutEdgeVec on the reverse CSR graph.";
    return DGLIdIters(nullptr, nullptr);
  }

  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
192
    return {adj_.indptr, adj_.indices, adj_.data};
193
194
  }

195
196
197
198
199
  /*! \brief Indicate whether this uses shared memory. */
  bool IsSharedMem() const {
    return !shared_mem_name_.empty();
  }

200
201
202
203
204
205
206
207
208
209
210
  /*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
  CSRPtr Transpose() const;

  /*! \brief Convert this CSR to COO */
  COOPtr ToCOO() const;

  /*!
   * \return the csr matrix that represents this graph.
   * \note The csr matrix shares the storage with this graph.
   *       The data field of the CSR matrix stores the edge ids.
   */
211
212
  aten::CSRMatrix ToCSRMatrix() const {
    return adj_;
213
214
  }

215
216
217
218
219
220
221
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  CSR CopyTo(const DLContext& ctx) const;

222
223
224
225
226
227
228
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
  CSR CopyToSharedMem(const std::string &name) const;

229
230
231
232
233
234
235
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  CSR AsNumBits(uint8_t bits) const;

236
237
  // member getters

238
  IdArray indptr() const { return adj_.indptr; }
239

240
  IdArray indices() const { return adj_.indices; }
241

242
  IdArray edge_ids() const { return adj_.data; }
243

244
245
246
247
248
249
  /*! \return Load CSR from stream */
  bool Load(dmlc::Stream *fs);

  /*! \return Save CSR to stream */
  void Save(dmlc::Stream* fs) const;

250
  void SortCSR() override {
Da Zheng's avatar
Da Zheng committed
251
252
    if (adj_.sorted)
      return;
253
    aten::CSRSort_(&adj_);
Da Zheng's avatar
Da Zheng committed
254
255
  }

256
 private:
257
  friend class Serializer;
258

259
260
  /*! \brief private default constructor */
  CSR() {adj_.sorted = false;}
261
262
263
  // The internal CSR adjacency matrix.
  // The data field stores edge ids.
  aten::CSRMatrix adj_;
264
265

  // whether the graph is a multi-graph
266
  Lazy<bool> is_multigraph_;
267
268
269
270

  // The name of the shared memory to store data.
  // If it's empty, data isn't stored in shared memory.
  std::string shared_mem_name_;
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
};

class COO : public GraphInterface {
 public:
  // Create a coo graph that shares the given src and dst
  COO(int64_t num_vertices, IdArray src, IdArray dst);
  COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph);

  // TODO(da): add constructor for creating COO from shared memory

  void AddVertices(uint64_t num_vertices) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }

  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }

  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }

  void Clear() override {
    LOG(FATAL) << "CSR graph does not allow mutation.";
  }

297
  DLContext Context() const override {
298
    return adj_.row->ctx;
299
300
301
  }

  uint8_t NumBits() const override {
302
    return adj_.row->dtype.bits;
303
304
  }

305
306
307
308
309
310
311
  bool IsMultigraph() const override;

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices() const override {
312
    return adj_.num_rows;
313
314
315
  }

  uint64_t NumEdges() const override {
316
    return adj_.row->shape[0];
317
318
319
  }

  bool HasVertex(dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
320
    return vid < NumVertices();
321
322
  }

323
324
325
326
327
  BoolArray HasVertices(IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

328
329
330
331
332
333
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return false;
  }

334
335
336
337
338
339
  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Predecessors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    LOG(FATAL) << "COO graph does not support efficient Successors."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    LOG(FATAL) << "COO graph does not support efficient EdgeId."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    LOG(FATAL) << "COO graph does not support efficient EdgeId."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

364
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override;
365
366
367
368
369
370
371
372

  EdgeArray FindEdges(IdArray eids) const override;

  EdgeArray InEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }
373

374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
  EdgeArray InEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray OutEdges(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray OutEdges(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdges."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  EdgeArray Edges(const std::string &order = "") const override;

  uint64_t InDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  DegreeArray InDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient InDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  uint64_t OutDegree(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegree."
      << " Please use CSR graph or AdjList graph instead.";
    return 0;
  }

  DegreeArray OutDegrees(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient OutDegrees."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

  Subgraph VertexSubgraph(IdArray vids) const override {
    LOG(FATAL) << "COO graph does not support efficient VertexSubgraph."
      << " Please use CSR graph or AdjList graph instead.";
    return {};
  }

424
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452

  DGLIdIters SuccVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient SuccVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient OutEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters PredVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient PredVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    LOG(FATAL) << "COO graph does not support efficient InEdgeVec."
      << " Please use CSR graph or AdjList graph instead.";
    return DGLIdIters(nullptr, nullptr);
  }

  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
453
      return {aten::HStack(adj_.col, adj_.row)};
454
    } else {
455
      return {aten::HStack(adj_.row, adj_.col)};
456
    }
457
  }
458

459
460
  /*! \brief Return the transpose of this COO */
  COOPtr Transpose() const {
461
    return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));
462
  }
463

464
465
  /*! \brief Convert this COO to CSR */
  CSRPtr ToCSR() const;
466

467
468
469
470
471
  /*!
   * \brief Get the coo matrix that represents this graph.
   * \note The coo matrix shares the storage with this graph.
   *       The data field of the coo matrix is none.
   */
472
473
  aten::COOMatrix ToCOOMatrix() const {
    return adj_;
474
475
  }

476
477
478
479
480
481
482
  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
  COO CopyTo(const DLContext& ctx) const;

483
484
485
486
487
488
489
  /*!
   * \brief Copy data to shared memory.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
  COO CopyToSharedMem(const std::string &name) const;

490
491
492
493
494
495
496
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
  COO AsNumBits(uint8_t bits) const;

497
498
499
500
501
  /*! \brief Indicate whether this uses shared memory. */
  bool IsSharedMem() const {
    return false;
  }

502
503
  // member getters

504
  IdArray src() const { return adj_.row; }
505

506
  IdArray dst() const { return adj_.col; }
507
508
509
510
511

 private:
  /* !\brief private default constructor */
  COO() {}

512
513
514
515
  // The internal COO adjacency matrix.
  // The data field is empty
  aten::COOMatrix adj_;

516
  /*! \brief whether the graph is a multi-graph */
517
  Lazy<bool> is_multigraph_;
518
519
520
521
522
523
524
525
526
527
528
};

/*!
 * \brief DGL immutable graph index class.
 *
 * DGL's graph is directed. Vertices are integers enumerated from zero.
 */
class ImmutableGraph: public GraphInterface {
 public:
  /*! \brief Construct an immutable graph from the COO format. */
  explicit ImmutableGraph(COOPtr coo): coo_(coo) { }
529

530
531
532
533
534
535
536
537
538
539
540
541
542
  /*!
   * \brief Construct an immutable graph from the CSR format.
   *
   * For a single graph, we need two CSRs, one stores the in-edges of vertices and
   * the other stores the out-edges of vertices. These two CSRs stores the same edges.
   * The reason we need both is that some operators are faster on in-edge CSR and
   * the other operators are faster on out-edge CSR.
   *
   * However, not both CSRs are required. Technically, one CSR contains all information.
   * Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only
   * construct one of the CSRs that runs fast for some operations we expect and construct
   * the other CSR on demand.
   */
543
544
545
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
    : in_csr_(in_csr), out_csr_(out_csr) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
546
547
  }

548
549
  /*! \brief Construct an immutable graph from one CSR. */
  explicit ImmutableGraph(CSRPtr csr): out_csr_(csr) { }
550
551
552
553
554
555
556
557
558
559
560

  /*! \brief default copy constructor */
  ImmutableGraph(const ImmutableGraph& other) = default;

#ifndef _MSC_VER
  /*! \brief default move constructor */
  ImmutableGraph(ImmutableGraph&& other) = default;
#else
  ImmutableGraph(ImmutableGraph&& other) {
    this->in_csr_ = other.in_csr_;
    this->out_csr_ = other.out_csr_;
561
    this->coo_ = other.coo_;
562
563
    other.in_csr_ = nullptr;
    other.out_csr_ = nullptr;
564
    other.coo_ = nullptr;
565
566
567
568
569
570
571
572
573
  }
#endif  // _MSC_VER

  /*! \brief default assign constructor */
  ImmutableGraph& operator=(const ImmutableGraph& other) = default;

  /*! \brief default destructor */
  ~ImmutableGraph() = default;

574
  void AddVertices(uint64_t num_vertices) override {
575
576
577
    LOG(FATAL) << "AddVertices isn't supported in ImmutableGraph";
  }

578
  void AddEdge(dgl_id_t src, dgl_id_t dst) override {
579
580
581
    LOG(FATAL) << "AddEdge isn't supported in ImmutableGraph";
  }

582
  void AddEdges(IdArray src_ids, IdArray dst_ids) override {
583
584
585
    LOG(FATAL) << "AddEdges isn't supported in ImmutableGraph";
  }

586
  void Clear() override {
587
588
589
    LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
  }

590
591
592
593
594
595
596
597
  DLContext Context() const override {
    return AnyGraph()->Context();
  }

  uint8_t NumBits() const override {
    return AnyGraph()->NumBits();
  }

598
599
600
601
  /*!
   * \note not const since we have caches
   * \return whether the graph is a multigraph
   */
602
603
  bool IsMultigraph() const override {
    return AnyGraph()->IsMultigraph();
604
605
606
607
608
  }

  /*!
   * \return whether the graph is read-only
   */
609
  bool IsReadonly() const override {
610
611
612
613
    return true;
  }

  /*! \return the number of vertices in the graph.*/
614
615
  uint64_t NumVertices() const override {
    return AnyGraph()->NumVertices();
616
617
618
  }

  /*! \return the number of edges in the graph.*/
619
620
  uint64_t NumEdges() const override {
    return AnyGraph()->NumEdges();
621
622
623
  }

  /*! \return true if the given vertex is in the graph.*/
624
  bool HasVertex(dgl_id_t vid) const override {
625
626
627
    return vid < NumVertices();
  }

628
629
  BoolArray HasVertices(IdArray vids) const override;

630
  /*! \return true if the given edge is in the graph.*/
631
632
633
634
635
636
637
  bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
    if (in_csr_) {
      return in_csr_->HasEdgeBetween(dst, src);
    } else {
      return GetOutCSR()->HasEdgeBetween(src, dst);
    }
  }
638

639
640
641
642
643
644
645
646
  BoolArray HasEdgesBetween(IdArray src, IdArray dst) const override {
    if (in_csr_) {
      return in_csr_->HasEdgesBetween(dst, src);
    } else {
      return GetOutCSR()->HasEdgesBetween(src, dst);
    }
  }

647
648
649
650
651
652
  /*!
   * \brief Find the predecessors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the predecessor id array.
   */
653
654
655
  IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetInCSR()->Successors(vid, radius);
  }
656
657
658
659
660
661
662

  /*!
   * \brief Find the successors of a vertex.
   * \param vid The vertex id.
   * \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
   * \return the successor id array.
   */
663
664
665
  IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
    return GetOutCSR()->Successors(vid, radius);
  }
666
667
668
669
670
671
672
673
674

  /*!
   * \brief Get all edge ids between the two given endpoints
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   * \param src The source vertex.
   * \param dst The destination vertex.
   * \return the edge id array.
   */
675
676
677
678
679
680
681
  IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
    if (in_csr_) {
      return in_csr_->EdgeId(dst, src);
    } else {
      return GetOutCSR()->EdgeId(src, dst);
    }
  }
682
683
684
685
686
687
688
689
690
691

  /*!
   * \brief Get all edge ids between the given endpoint pairs.
   * \note Edges are associated with an integer id start from zero.
   *       The id is assigned when the edge is being added to the graph.
   *       If duplicate pairs exist, the returned edge IDs will also duplicate.
   *       The order of returned edge IDs will follow the order of src-dst pairs
   *       first, and ties are broken by the order of edge ID.
   * \return EdgeArray containing all edges between all pairs.
   */
692
693
  EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
    if (in_csr_) {
694
695
      EdgeArray edges = in_csr_->EdgeIds(dst, src);
      return EdgeArray{edges.dst, edges.src, edges.id};
696
697
698
699
    } else {
      return GetOutCSR()->EdgeIds(src, dst);
    }
  }
700
701
702
703
704
705

  /*!
   * \brief Find the edge ID and return the pair of endpoints
   * \param eid The edge ID
   * \return a pair whose first element is the source and the second the destination.
   */
706
707
708
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
    return GetCOO()->FindEdge(eid);
  }
709
710
711
712
713
714

  /*!
   * \brief Find the edge IDs and return their source and target node IDs.
   * \param eids The edge ID array.
   * \return EdgeArray containing all edges with id in eid.  The order is preserved.
   */
715
716
717
  EdgeArray FindEdges(IdArray eids) const override {
    return GetCOO()->FindEdges(eids);
  }
718
719
720
721
722
723
724

  /*!
   * \brief Get the in edges of the vertex.
   * \note The returned dst id array is filled with vid.
   * \param vid The vertex id.
   * \return the edges
   */
725
726
727
  EdgeArray InEdges(dgl_id_t vid) const override {
    const EdgeArray& ret = GetInCSR()->OutEdges(vid);
    return {ret.dst, ret.src, ret.id};
728
729
730
731
732
733
734
  }

  /*!
   * \brief Get the in edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
735
736
737
  EdgeArray InEdges(IdArray vids) const override {
    const EdgeArray& ret = GetInCSR()->OutEdges(vids);
    return {ret.dst, ret.src, ret.id};
738
739
740
741
742
743
744
745
  }

  /*!
   * \brief Get the out edges of the vertex.
   * \note The returned src id array is filled with vid.
   * \param vid The vertex id.
   * \return the id arrays of the two endpoints of the edges.
   */
746
747
  EdgeArray OutEdges(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdges(vid);
748
749
750
751
752
753
754
  }

  /*!
   * \brief Get the out edges of the vertices.
   * \param vids The vertex id array.
   * \return the id arrays of the two endpoints of the edges.
   */
755
756
  EdgeArray OutEdges(IdArray vids) const override {
    return GetOutCSR()->OutEdges(vids);
757
758
759
760
761
762
763
764
765
  }

  /*!
   * \brief Get all the edges in the graph.
   * \note If sorted is true, the returned edges list is sorted by their src and
   *       dst ids. Otherwise, they are in their edge id order.
   * \param sorted Whether the returned edge list is sorted by their src and dst ids
   * \return the id arrays of the two endpoints of the edges.
   */
766
  EdgeArray Edges(const std::string &order = "") const override;
767
768
769
770
771
772

  /*!
   * \brief Get the in degree of the given vertex.
   * \param vid The vertex id.
   * \return the in degree
   */
773
774
  uint64_t InDegree(dgl_id_t vid) const override {
    return GetInCSR()->OutDegree(vid);
775
776
777
778
779
780
781
  }

  /*!
   * \brief Get the in degrees of the given vertices.
   * \param vid The vertex id array.
   * \return the in degree array
   */
782
783
  DegreeArray InDegrees(IdArray vids) const override {
    return GetInCSR()->OutDegrees(vids);
784
785
786
787
788
789
790
  }

  /*!
   * \brief Get the out degree of the given vertex.
   * \param vid The vertex id.
   * \return the out degree
   */
791
792
  uint64_t OutDegree(dgl_id_t vid) const override {
    return GetOutCSR()->OutDegree(vid);
793
794
795
796
797
798
799
  }

  /*!
   * \brief Get the out degrees of the given vertices.
   * \param vid The vertex id array.
   * \return the out degree array
   */
800
801
  DegreeArray OutDegrees(IdArray vids) const override {
    return GetOutCSR()->OutDegrees(vids);
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
  }

  /*!
   * \brief Construct the induced subgraph of the given vertices.
   *
   * The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
   * selecting all of the edges from the original graph that connect two vertices in V'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the vertices preserve the order of the given id array, while the local index
   * of the edges preserve the index order in the original graph. Vertices not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param vids The vertices in the subgraph.
   * \return the induced subgraph
   */
820
  Subgraph VertexSubgraph(IdArray vids) const override;
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837

  /*!
   * \brief Construct the induced edge subgraph of the given edges.
   *
   * The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
   * selecting all of the nodes from the original graph that are endpoints in E'.
   *
   * Vertices and edges in the original graph will be "reindexed" to local index. The local
   * index of the edges preserve the order of the given id array, while the local index
   * of the vertices preserve the index order in the original graph. Edges not in the
   * original graph are ignored.
   *
   * The result subgraph is read-only.
   *
   * \param eids The edges in the subgraph.
   * \return the induced edge subgraph
   */
838
  Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
839
840
841
842
843
844

  /*!
   * \brief Return the successor vector
   * \param vid The vertex id.
   * \return the successor vector
   */
845
846
  DGLIdIters SuccVec(dgl_id_t vid) const override {
    return GetOutCSR()->SuccVec(vid);
847
848
849
850
851
852
853
  }

  /*!
   * \brief Return the out edge id vector
   * \param vid The vertex id.
   * \return the out edge id vector
   */
854
855
  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
    return GetOutCSR()->OutEdgeVec(vid);
856
857
858
859
860
861
862
  }

  /*!
   * \brief Return the predecessor vector
   * \param vid The vertex id.
   * \return the predecessor vector
   */
863
864
  DGLIdIters PredVec(dgl_id_t vid) const override {
    return GetInCSR()->SuccVec(vid);
865
866
867
868
869
870
871
  }

  /*!
   * \brief Return the in edge id vector
   * \param vid The vertex id.
   * \return the in edge id vector
   */
872
873
  DGLIdIters InEdgeVec(dgl_id_t vid) const override {
    return GetInCSR()->OutEdgeVec(vid);
874
875
876
877
878
879
880
881
882
883
884
  }

  /*!
   * \brief Get the adjacency matrix of the graph.
   *
   * By default, a row of returned adjacency matrix represents the destination
   * of an edge and the column represents the source.
   * \param transpose A flag to transpose the returned adjacency matrix.
   * \param fmt the format of the returned adjacency matrix.
   * \return a vector of three IdArray.
   */
885
  std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
886

887
  /* !\brief Return in csr. If not exist, transpose the other one.*/
888
  CSRPtr GetInCSR() const;
889
890

  /* !\brief Return out csr. If not exist, transpose the other one.*/
891
  CSRPtr GetOutCSR() const;
892

893
  /* !\brief Return coo. If not exist, create from csr.*/
894
  COOPtr GetCOO() const;
895

896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
  /*! \brief Create an immutable graph from CSR. */
  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);

  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids,
      bool multigraph, const std::string &edge_dir);

  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids,
      const std::string &edge_dir, const std::string &shared_mem_name);

  static ImmutableGraphPtr CreateFromCSR(
      IdArray indptr, IdArray indices, IdArray edge_ids,
      bool multigraph, const std::string &edge_dir,
      const std::string &shared_mem_name);

  static ImmutableGraphPtr CreateFromCSR(
      const std::string &shared_mem_name, size_t num_vertices,
      size_t num_edges, bool multigraph,
      const std::string &edge_dir);

  /*! \brief Create an immutable graph from COO. */
  static ImmutableGraphPtr CreateFromCOO(
      int64_t num_vertices, IdArray src, IdArray dst);

  static ImmutableGraphPtr CreateFromCOO(
      int64_t num_vertices, IdArray src, IdArray dst, bool multigraph);

925
926
927
928
929
930
931
932
933
  /*!
   * \brief Convert the given graph to an immutable graph.
   *
   * If the graph is already an immutable graph. The result graph will share
   * the storage with the given one.
   *
   * \param graph The input graph.
   * \return an immutable graph object.
   */
934
  static ImmutableGraphPtr ToImmutable(GraphPtr graph);
935
936
937
938
939
940

  /*!
   * \brief Copy the data to another context.
   * \param ctx The target context.
   * \return The graph under another context.
   */
941
  static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx);
942

943
944
945
946
947
948
  /*!
   * \brief Copy data to shared memory.
   * \param edge_dir the graph of the specific edge direction to be copied.
   * \param name The name of the shared memory.
   * \return The graph in the shared memory
   */
949
950
  static ImmutableGraphPtr CopyToSharedMem(
      ImmutableGraphPtr g, const std::string &edge_dir, const std::string &name);
951

952
953
954
955
956
  /*!
   * \brief Convert the graph to use the given number of bits for storage.
   * \param bits The new number of integer bits (32 or 64).
   * \return The graph with new bit size storage.
   */
957
  static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);
958

959
960
961
962
963
964
965
966
  /*!
   * \brief Return a new graph with all the edges reversed.
   *
   * The returned graph preserves the vertex and edge index in the original graph.
   *
   * \return the reversed graph
   */
  ImmutableGraphPtr Reverse() const;
967

968
969
  /*! \return Load ImmutableGraph from stream, using out csr */
  bool Load(dmlc::Stream *fs);
970

971
  /*! \return Save ImmutableGraph to stream, using out csr */
972
973
  void Save(dmlc::Stream* fs) const;

Da Zheng's avatar
Da Zheng committed
974
975
976
977
978
  void SortCSR() {
    GetInCSR()->SortCSR();
    GetOutCSR()->SortCSR();
  }

979
980
981
  /*! \brief Cast this graph to a heterograph */
  HeteroGraphPtr AsHeteroGraph() const;

982
 protected:
983
  friend class Serializer;
984
  friend class UnitGraph;
985

986
987
  /* !\brief internal default constructor */
  ImmutableGraph() {}
988

989
990
991
992
993
  /* !\brief internal constructor for all the members */
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
    : in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
    CHECK(AnyGraph()) << "At least one graph structure should exist.";
  }
994

995
996
997
998
999
1000
  ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
    : in_csr_(in_csr), out_csr_(out_csr) {
    CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
    this->shared_mem_name_ = shared_mem_name;
  }

1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
  /* !\brief return pointer to any available graph structure */
  GraphPtr AnyGraph() const {
    if (in_csr_) {
      return in_csr_;
    } else if (out_csr_) {
      return out_csr_;
    } else {
      return coo_;
    }
  }
1011

1012
1013
1014
1015
1016
1017
  // Store the in csr (i.e, the reverse csr)
  CSRPtr in_csr_;
  // Store the out csr (i.e, the normal csr)
  CSRPtr out_csr_;
  // Store the edge list indexed by edge id (COO)
  COOPtr coo_;
1018
1019
1020
1021

  // The name of shared memory for this graph.
  // If it's empty, the graph isn't stored in shared memory.
  std::string shared_mem_name_;
1022
1023
};

1024
1025
1026
1027
1028
1029
// inline implementations

template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
    IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
    bool is_multigraph): is_multigraph_(is_multigraph) {
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
  // TODO(minjie): this should be changed to a device-agnostic implementation
  //   in the future
  adj_.num_rows = num_vertices;
  adj_.num_cols = num_vertices;
  adj_.indptr = aten::NewIdArray(num_vertices + 1);
  adj_.indices = aten::NewIdArray(num_edges);
  adj_.data = aten::NewIdArray(num_edges);
  dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
  dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
  dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(adj_.data->data);
1040
1041
1042
1043
1044
1045
1046
1047
  for (int64_t i = 0; i < num_vertices + 1; ++i)
    *(indptr_data++) = *(indptr_begin++);
  for (int64_t i = 0; i < num_edges; ++i) {
    *(indices_data++) = *(indices_begin++);
    *(edge_ids_data++) = *(edge_ids_begin++);
  }
}

1048
1049
}  // namespace dgl

1050
namespace dmlc {
1051
DMLC_DECLARE_TRAITS(has_saveload, dgl::CSR, true);
1052
1053
1054
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
}  // namespace dmlc

1055
#endif  // DGL_IMMUTABLE_GRAPH_H_