"docs/source/vscode:/vscode.git/clone" did not exist on "3aef46773332f2d5e179c0abc9d0ba428a09b6b4"
unit_graph.cc 56.7 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <dgl/array.h>
7
#include <dgl/base_heterograph.h>
8
9
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
10
11

#include "../c_api_common.h"
12
#include "./unit_graph.h"
13
14

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
  // a self-loop edge 0->0
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 0);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
  return g;
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
36
37
38
39
40
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  if (num_vtypes == 1)
    return mg1;
  else if (num_vtypes == 2)
    return mg2;
  else
    LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
62
/*!
 * \brief Unit graph backed by a COO adjacency matrix.
 *
 * Rows of the matrix are source vertices and columns are destination
 * vertices. Edge i is (adj_.row[i], adj_.col[i]). The matrix never stores a
 * data array, so edge ids are implicitly 0 .. num_edges - 1.
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Build a COO graph from explicit source/destination id arrays.
   * \param metagraph Metagraph describing the vertex/edge types.
   * \param num_src Number of source vertices (rows).
   * \param num_dst Number of destination vertices (columns).
   * \param src Source vertex id of every edge.
   * \param dst Destination vertex id of every edge; same length as \p src.
   * \param row_sorted Stored as the matrix's row_sorted flag.
   * \param col_sorted Stored as the matrix's col_sorted flag.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
        NullArray(),
        row_sorted, col_sorted};
  }

  /*!
   * \brief Wrap an existing COO matrix. The matrix must not carry a data array.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /*! \brief Construct an undefined COO graph (see defined()). */
  COO() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /*! \brief Whether this graph was constructed with real dimensions. */
  bool defined() const {
    return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
  }

  inline dgl_type_t SrcType() const {
    return 0;
  }

  // With a single vertex type, source and destination share type 0.
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  DLContext Context() const override {
    return adj_.row->ctx;
  }

  bool IsPinned() const override {
    return adj_.is_pinned;
  }

  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a copy whose id arrays use \p bits-bit integers.
   *
   * Returns *this unchanged if the width already matches.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;

    // Changing the integer width does not reorder any entry, so the
    // sortedness flags of the source matrix remain valid.
    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits),
        aten::AsNumBits(adj_.col, bits),
        adj_.row_sorted, adj_.col_sorted);
    return ret;
  }

  /*! \brief Copy the graph to \p ctx, or return *this if it is already there. */
  COO CopyTo(const DLContext &ctx,
             const DGLStreamHandle &stream = nullptr) const {
    if (Context() == ctx)
      return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx, stream));
  }

  /*! \brief Pin the adj_: COOMatrix of the COO graph. */
  void PinMemory_() {
    adj_.PinMemory_();
  }

  /*! \brief Unpin the adj_: COOMatrix of the COO graph. */
  void UnpinMemory_() {
    adj_.UnpinMemory_();
  }

  bool IsMultigraph() const override {
    return aten::COOHasDuplicate(adj_);
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  // Predecessors are read from the row of the transposed matrix.
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::COOGetData(adj_, src, dst);
  }

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  // Edge ids are positions in row/col, so a plain gather recovers endpoints.
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    BUG_IF_FAIL(aten::IsNullArray(adj_.data)) <<
      "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Sliced rows are renumbered; map them back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Sliced rows are renumbered; map them back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // The iterator-based accessors below are unsupported for COO; they log and
  // return an empty result rather than aborting.
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  /*!
   * \brief Return row and col stacked into one array via HStack.
   *
   * The order is (row, col), or (col, row) when \p transpose is set.
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /*!
   * \brief Induced subgraph on the given source/destination vertex sets.
   *
   * The sliced matrix's data array is recorded as the induced edge ids. The
   * new COO graph itself renumbers its edges from 0.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    const IdArray& srcvids = vids[SrcType()];
    const IdArray& dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Subgraph made of the given edges.
   * \param eids One id array (a single edge type).
   * \param preserve_nodes If false, the endpoints are relabeled to a compact
   *        range and the kept vertex ids are recorded. If true, every vertex
   *        is kept.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      // Relabel_ rewrites new_src/new_dst in place and returns the original ids.
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      subg.induced_vertices.emplace_back(
          aten::Range(0, NumVertices(SrcType()), NumBits(), Context()));
      subg.induced_vertices.emplace_back(
          aten::Range(0, NumVertices(DstType()), NumBits(), Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /*! \brief Deserialize the metagraph followed by the adjacency matrix. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Serialize the metagraph followed by the adjacency matrix. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
465
class UnitGraph::CSR : public BaseHeteroGraph {
466
 public:
Minjie Wang's avatar
Minjie Wang committed
467
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
468
      IdArray indptr, IdArray indices, IdArray edge_ids)
Minjie Wang's avatar
Minjie Wang committed
469
    : BaseHeteroGraph(metagraph) {
470
471
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
472
473
474
    if (aten::IsValidIdArray(edge_ids))
      CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids))
        << "edge id arrays should have the same length as indices if not empty";
475
476
    CHECK_EQ(num_src, indptr->shape[0] - 1)
      << "number of nodes do not match the length of indptr minus 1.";
477

478
479
480
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

481
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
Da Zheng's avatar
Da Zheng committed
482
483
    : BaseHeteroGraph(metagraph), adj_(csr) {
  }
484

485
486
487
488
489
490
491
492
493
494
495
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

  bool defined() const {
    return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
  }

Minjie Wang's avatar
Minjie Wang committed
496
497
  inline dgl_type_t SrcType() const {
    return 0;
498
  }
Minjie Wang's avatar
Minjie Wang committed
499
500
501
502
503
504
505

  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
506
507
508
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
509
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
510
511
512
513
514
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
515
    LOG(FATAL) << "UnitGraph graph is not mutable.";
516
517
518
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
519
    LOG(FATAL) << "UnitGraph graph is not mutable.";
520
521
522
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
523
    LOG(FATAL) << "UnitGraph graph is not mutable.";
524
525
526
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
527
    LOG(FATAL) << "UnitGraph graph is not mutable.";
528
529
  }

530
531
532
533
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

534
535
536
537
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

538
  bool IsPinned() const override {
539
    return adj_.is_pinned;
540
541
  }

542
543
544
545
  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

546
547
548
549
550
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
551
          meta_graph_,
552
553
554
555
556
557
558
559
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

560
561
  CSR CopyTo(const DLContext &ctx,
             const DGLStreamHandle &stream = nullptr) const {
562
563
564
    if (Context() == ctx) {
      return *this;
    } else {
565
      return CSR(meta_graph_, adj_.CopyTo(ctx, stream));
566
567
568
    }
  }

569
570
571
572
573
574
575
576
577
578
  /*! \brief Pin the adj_: CSRMatrix of the CSR graph. */
  void PinMemory_() {
    adj_.PinMemory_();
  }

  /*! \brief Unpin the adj_: CSRMatrix of the CSR graph. */
  void UnpinMemory_() {
    adj_.UnpinMemory_();
  }

579
  bool IsMultigraph() const override {
580
    return aten::CSRHasDuplicate(adj_);
581
582
583
584
585
586
587
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
588
    if (vtype == SrcType()) {
589
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
590
    } else if (vtype == DstType()) {
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
612
613
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
614
615
616
617
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
618
619
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
620
621
622
623
624
625
626
627
628
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
629
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
630
631
632
633
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
634
635
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
636
    return aten::CSRGetAllData(adj_, src, dst);
637
638
  }

639
  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
640
641
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
642
643
644
645
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

646
647
648
649
  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::CSRGetData(adj_, src, dst);
  }

650
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
651
    LOG(FATAL) << "Not enabled for CSR graph.";
652
653
654
655
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
656
    LOG(FATAL) << "Not enabled for CSR graph.";
657
658
659
660
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
661
    LOG(FATAL) << "Not enabled for CSR graph.";
662
663
664
665
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
666
    LOG(FATAL) << "Not enabled for CSR graph.";
667
668
669
670
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
671
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
672
673
674
675
676
677
678
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
679
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
680
681
682
683
684
685
686
687
688
689
690
691
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
692
693
694
695
696
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
697
698
699
700
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
701
    LOG(FATAL) << "Not enabled for CSR graph.";
702
703
704
705
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
706
    LOG(FATAL) << "Not enabled for CSR graph.";
707
708
709
710
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
711
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
712
713
714
715
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
716
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
717
718
719
720
721
722
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
723
    CHECK_EQ(NumBits(), 64);
724
725
726
727
728
729
730
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

731
732
733
734
735
736
737
738
739
740
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

741
742
743
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
744
    CHECK_EQ(NumBits(), 64);
745
746
747
748
749
750
751
752
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
753
    LOG(FATAL) << "Not enabled for CSR graph.";
754
755
756
757
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
758
    LOG(FATAL) << "Not enabled for CSR graph.";
759
760
761
762
763
764
765
766
767
    return {};
  }

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

768
769
770
771
772
773
774
775
776
777
778
779
780
781
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return adj_;
  }

782
  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
783
    LOG(FATAL) << "Not enabled for CSR graph";
784
    return SparseFormat::kCSR;
785
786
  }

787
788
789
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
790
791
  }

792
  dgl_format_code_t GetCreatedFormats() const override {
793
794
795
796
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

797
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
Minjie Wang's avatar
Minjie Wang committed
798
799
800
801
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
802
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
803
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
804
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
Minjie Wang's avatar
Minjie Wang committed
805
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
806
807
808
809
810
811
812
813
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
814
    LOG(FATAL) << "Not enabled for CSR graph.";
815
816
817
    return {};
  }

818
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
819
820
821
822
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

823
824
825
826
  aten::CSRMatrix adj() const {
    return adj_;
  }

827
828
829
830
831
832
833
834
835
836
837
838
839
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

840
 private:
841
842
  friend class Serializer;

843
844
845
846
847
848
  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// unit graph implementation
//
//////////////////////////////////////////////////////////

853
854
855
856
/*! \brief Id data type, read from whichever format GetAny() returns. */
DLDataType UnitGraph::DataType() const {
  const auto any_format = GetAny();
  return any_format->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
857
/*! \brief Device context, read from whichever format GetAny() returns. */
DLContext UnitGraph::Context() const {
  const auto any_format = GetAny();
  return any_format->Context();
}

861
862
863
864
/*! \brief Pinned-memory status, read from whichever format GetAny() returns. */
bool UnitGraph::IsPinned() const {
  const auto any_format = GetAny();
  return any_format->IsPinned();
}

Minjie Wang's avatar
Minjie Wang committed
865
/*! \brief Id bit width, read from whichever format GetAny() returns. */
uint8_t UnitGraph::NumBits() const {
  const auto any_format = GetAny();
  return any_format->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
869
bool UnitGraph::IsMultigraph() const {
870
  const SparseFormat fmt = SelectFormat(CSC_CODE);
871
872
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
873
874
}

Minjie Wang's avatar
Minjie Wang committed
875
/*! \brief Number of vertices of the given type, answered by any available format. */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const auto graph = GetFormat(fmt);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  // A CSC is stored as the CSR of the reversed graph, so src/dst types swap roles.
  const bool reversed = (fmt == SparseFormat::kCSC);
  const dgl_type_t query_type =
      reversed ? ((vtype == SrcType()) ? DstType() : SrcType()) : vtype;
  return graph->NumVertices(query_type);
}

Minjie Wang's avatar
Minjie Wang committed
885
/*! \brief Number of edges, read from whichever format GetAny() returns. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const auto any_format = GetAny();
  return any_format->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
889
/*! \brief Whether vertex `vid` of type `vtype` exists, answered by any available format. */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const auto graph = GetFormat(fmt);
  // CSC is the reversed CSR: query the opposite vertex type.
  const dgl_type_t query_type = (fmt != SparseFormat::kCSC)
      ? vtype
      : ((vtype == SrcType()) ? DstType() : SrcType());
  return graph->HasVertex(query_type, vid);
}

Minjie Wang's avatar
Minjie Wang committed
897
/*!
 * \brief Element-wise test of whether each id in `vids` is below the vertex count.
 * NOTE(review): only the upper bound is checked; negative ids are not rejected
 * here. Confirm callers never pass negative ids.
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const uint64_t num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

Minjie Wang's avatar
Minjie Wang committed
902
/*! \brief Whether an edge src->dst exists; prefers the CSC format. */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->HasEdgeBetween(etype, src, dst);
  // CSC is the reversed CSR, so endpoints are swapped.
  return graph->HasEdgeBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
911
/*! \brief Vectorized HasEdgeBetween over paired src/dst arrays; prefers CSC. */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->HasEdgesBetween(etype, src, dst);
  // CSC is the reversed CSR, so endpoints are swapped.
  return graph->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
921
/*! \brief Source vertices of edges pointing into `dst`; prefers CSC. */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->Predecessors(etype, dst);
  // In the reversed storage, predecessors are successors.
  return graph->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
930
/*! \brief Destination vertices of edges leaving `src`; prefers CSR. */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  return GetFormat(SelectFormat(CSR_CODE))->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
936
/*! \brief Ids of all edges src->dst; prefers CSR. */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->EdgeId(etype, src, dst);
  // CSC is the reversed CSR, so endpoints are swapped.
  return graph->EdgeId(etype, dst, src);
}

945
/*! \brief All edges between paired src/dst arrays, as (src, dst, id); prefers CSR. */
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->EdgeIdsAll(etype, src, dst);
  // Query the reversed storage with swapped endpoints, then swap the result back.
  const EdgeArray reversed = graph->EdgeIdsAll(etype, dst, src);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

/*! \brief One edge id per (src, dst) pair; prefers CSR. */
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(fmt);
  // CSC is the reversed CSR, so endpoints are swapped.
  return (fmt == SparseFormat::kCSC) ? graph->EdgeIdsOne(etype, dst, src)
                                     : graph->EdgeIdsOne(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
966
/*! \brief Endpoints (src, dst) of edge `eid`; prefers COO. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
972
/*! \brief Endpoints of each edge in `eids`; prefers COO. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
978
/*! \brief Edges entering vertex `vid`, as (src, dst, id); prefers CSC. */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->InEdges(etype, vid);
  // Out-edges of the reversed storage are the in-edges; swap endpoints back.
  const EdgeArray& reversed = graph->OutEdges(etype, vid);
  return {reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
989
/*! \brief Edges entering any vertex in `vids`, as (src, dst, id); prefers CSC. */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return graph->InEdges(etype, vids);
  // Out-edges of the reversed storage are the in-edges; swap endpoints back.
  const EdgeArray& reversed = graph->OutEdges(etype, vids);
  return {reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
1000
/*! \brief Edges leaving vertex `vid`; prefers CSR. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1006
/*! \brief Edges leaving any vertex in `vids`; prefers CSR. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdges(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1012
/*!
 * \brief All edges as (src, dst, id).
 * \param order "" (any order), "eid" (prefers COO) or "srcdst" (prefers CSR);
 *        any other value is a fatal error.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  SparseFormat fmt;
  if (order.empty()) {
    fmt = SelectFormat(ALL_CODE);  // caller accepts an arbitrary order
  } else if (order == "eid") {
    fmt = SelectFormat(COO_CODE);
  } else if (order == "srcdst") {
    fmt = SelectFormat(CSR_CODE);
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const auto& edges = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC)
    return edges;
  // CSC is the reversed CSR, so swap endpoints back.
  return EdgeArray{edges.dst, edges.src, edges.id};
}

Minjie Wang's avatar
Minjie Wang committed
1033
/*! \brief In-degree of `vid`; requires CSC or COO to be the selected format. */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  const bool is_csc = (fmt == SparseFormat::kCSC);
  CHECK(is_csc || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  // In the reversed (CSC) storage, in-degree is the out-degree.
  if (is_csc)
    return graph->OutDegree(etype, vid);
  return graph->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1043
/*! \brief In-degrees of `vids`; requires CSC or COO to be the selected format. */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  const bool is_csc = (fmt == SparseFormat::kCSC);
  CHECK(is_csc || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  // In the reversed (CSC) storage, in-degrees are the out-degrees.
  if (is_csc)
    return graph->OutDegrees(etype, vids);
  return graph->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1053
/*! \brief Out-degree of `vid`; requires CSR or COO to be the selected format. */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(fmt);
  const bool supported = (fmt == SparseFormat::kCSR) || (fmt == SparseFormat::kCOO);
  CHECK(supported)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return graph->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1062
/*! \brief Out-degrees of `vids`; requires CSR or COO to be the selected format. */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(fmt);
  const bool supported = (fmt == SparseFormat::kCSR) || (fmt == SparseFormat::kCOO);
  CHECK(supported)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return graph->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1071
/*! \brief Iterator view over successors of `vid`; prefers CSR. */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->SuccVec(etype, vid);
}

1077
/*!
 * \brief 32-bit iterator view over successors of `vid`.
 * SuccVec32 is called on the CSR type, so the selected format is downcast;
 * a non-CSR selection fails the CHECK_NOTNULL.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const auto csr = std::dynamic_pointer_cast<CSR>(GetFormat(SelectFormat(CSR_CODE)));
  CHECK_NOTNULL(csr);
  return csr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1084
/*! \brief Iterator view over ids of edges leaving `vid`; prefers CSR. */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1090
/*! \brief Iterator view over predecessors of `vid`; prefers CSC. */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  // In the reversed (CSC) storage, predecessors are successors.
  return (fmt == SparseFormat::kCSC) ? graph->SuccVec(etype, vid)
                                     : graph->PredVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1099
/*! \brief Iterator view over ids of edges entering `vid`; prefers CSC. */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(fmt);
  // In the reversed (CSC) storage, in-edges are out-edges.
  return (fmt == SparseFormat::kCSC) ? graph->OutEdgeVec(etype, vid)
                                     : graph->InEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1108
std::vector<IdArray> UnitGraph::GetAdj(
1109
1110
1111
1112
1113
1114
1115
1116
1117
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
1118
    return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
1119
1120
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
1121
    return GetCOO()->GetAdj(etype, transpose, fmt);
1122
1123
1124
1125
1126
1127
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
1128
/*!
 * \brief Vertex-induced subgraph, sliced from the format chosen with CSR preferred.
 * The result is wrapped in a new UnitGraph holding only that one format.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  HeteroSubgraph sliced = GetFormat(fmt)->VertexSubgraph(vids);
  HeteroSubgraph result;

  CSRPtr sub_in_csr = nullptr;
  CSRPtr sub_out_csr = nullptr;
  COOPtr sub_coo = nullptr;
  if (fmt == SparseFormat::kCSR) {
    sub_out_csr = std::dynamic_pointer_cast<CSR>(sliced.graph);
  } else if (fmt == SparseFormat::kCSC) {
    sub_in_csr = std::dynamic_pointer_cast<CSR>(sliced.graph);
  } else if (fmt == SparseFormat::kCOO) {
    sub_coo = std::dynamic_pointer_cast<COO>(sliced.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), sub_in_csr, sub_out_csr, sub_coo));
  result.induced_vertices = std::move(sliced.induced_vertices);
  result.induced_edges = std::move(sliced.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1158
/*!
 * \brief Edge-induced subgraph, built from the format chosen with COO preferred.
 * The result is wrapped in a new UnitGraph holding only that one format.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  HeteroSubgraph sliced = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph result;

  CSRPtr sub_in_csr = nullptr;
  CSRPtr sub_out_csr = nullptr;
  COOPtr sub_coo = nullptr;
  if (fmt == SparseFormat::kCSR) {
    sub_out_csr = std::dynamic_pointer_cast<CSR>(sliced.graph);
  } else if (fmt == SparseFormat::kCSC) {
    sub_in_csr = std::dynamic_pointer_cast<CSR>(sliced.graph);
  } else if (fmt == SparseFormat::kCOO) {
    sub_coo = std::dynamic_pointer_cast<COO>(sliced.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), sub_in_csr, sub_out_csr, sub_coo));
  result.induced_vertices = std::move(sliced.induced_vertices);
  result.induced_edges = std::move(sliced.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1188
/*!
 * \brief Build a UnitGraph backed by a COO built from row/col arrays.
 * \param num_vtypes 1 (homogeneous; requires num_src == num_dst) or 2.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col,
    bool row_sorted, bool col_sorted,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  // With a single vertex type, src and dst index the same node set.
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_part(new COO(metagraph, num_src, num_dst, row, col, row_sorted, col_sorted));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, nullptr, coo_part, formats));
}

1204
1205
/*!
 * \brief Build a UnitGraph backed by the given COO matrix.
 * \param num_vtypes 1 (homogeneous; requires a square matrix) or 2.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_part(new COO(metagraph, mat));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, nullptr, coo_part, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1217
1218
/*!
 * \brief Build a UnitGraph backed by an out-CSR built from indptr/indices/edge ids.
 * \param num_vtypes 1 (homogeneous; requires num_src == num_dst) or 2.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_part(new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, out_part, nullptr, formats));
}

1228
1229
/*!
 * \brief Build a UnitGraph backed by the given matrix as its out-CSR.
 * \param num_vtypes 1 (homogeneous; requires a square matrix) or 2.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_part(new CSR(metagraph, mat));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, out_part, nullptr, formats));
}

1239
1240
/*!
 * \brief Build a UnitGraph backed by a CSC, i.e. a CSR whose rows are
 *        destination vertices (note the swapped num_dst/num_src below).
 * \param num_vtypes 1 (homogeneous; requires num_src == num_dst) or 2.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_part(new CSR(metagraph, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(metagraph, in_part, nullptr, nullptr, formats));
}

/*!
 * \brief Build a UnitGraph backed by the given matrix as its in-CSR (CSC).
 * \param num_vtypes 1 (homogeneous; requires a square matrix) or 2.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_part(new CSR(metagraph, mat));
  return HeteroGraphPtr(new UnitGraph(metagraph, in_part, nullptr, nullptr, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1261
/*!
 * \brief Return `g` with ids of the requested bit width.
 * If `g` already uses that width it is returned unchanged. Otherwise every
 * materialized format is converted, and the allowed formats are preserved.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;

  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);
  CSRPtr in_csr = nullptr;
  if (unit->in_csr_->defined())
    in_csr = CSRPtr(new CSR(unit->in_csr_->AsNumBits(bits)));
  CSRPtr out_csr = nullptr;
  if (unit->out_csr_->defined())
    out_csr = CSRPtr(new CSR(unit->out_csr_->AsNumBits(bits)));
  COOPtr coo = nullptr;
  if (unit->coo_->defined())
    coo = COOPtr(new COO(unit->coo_->AsNumBits(bits)));
  return HeteroGraphPtr(
      new UnitGraph(g->meta_graph(), in_csr, out_csr, coo, unit->formats_));
}

1278
1279
/*!
 * \brief Return `g` on device `ctx`.
 * If `g` is already there it is returned unchanged. Otherwise every
 * materialized format is copied using `stream`, and the allowed formats are
 * preserved.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
                                 const DGLStreamHandle &stream) {
  if (ctx == g->Context())
    return g;

  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);
  CSRPtr in_csr = nullptr;
  if (unit->in_csr_->defined())
    in_csr = CSRPtr(new CSR(unit->in_csr_->CopyTo(ctx, stream)));
  CSRPtr out_csr = nullptr;
  if (unit->out_csr_->defined())
    out_csr = CSRPtr(new CSR(unit->out_csr_->CopyTo(ctx, stream)));
  COOPtr coo = nullptr;
  if (unit->coo_->defined())
    coo = COOPtr(new COO(unit->coo_->CopyTo(ctx, stream)));
  return HeteroGraphPtr(
      new UnitGraph(g->meta_graph(), in_csr, out_csr, coo, unit->formats_));
}

1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
void UnitGraph::PinMemory_() {
  if (this->in_csr_->defined())
    this->in_csr_->PinMemory_();
  if (this->out_csr_->defined())
    this->out_csr_->PinMemory_();
  if (this->coo_->defined())
    this->coo_->PinMemory_();
}

void UnitGraph::UnpinMemory_() {
  if (this->in_csr_->defined())
    this->in_csr_->UnpinMemory_();
  if (this->out_csr_->defined())
    this->out_csr_->UnpinMemory_();
  if (this->coo_->defined())
    this->coo_->UnpinMemory_();
}

1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
void UnitGraph::InvalidateCSR() {
  this->out_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCSC() {
  this->in_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCOO() {
  this->coo_ = COOPtr(new COO());
}

1329
/*!
 * \brief Build a UnitGraph from any subset of the three sparse formats.
 *
 * Null format pointers are replaced by empty placeholders, so the rest of the
 * class can call ->defined() without null checks.
 *
 * \param metagraph Metagraph of the unit graph.
 * \param in_csr    In-CSR (CSC, stored as the CSR of the reversed graph), or nullptr.
 * \param out_csr   Out-CSR, or nullptr.
 * \param coo       COO, or nullptr.
 * \param formats   Bitmask of formats this graph may hold; every supplied
 *                  format must be included in it.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     dgl_format_code_t formats)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  if (!in_csr_)
    in_csr_ = CSRPtr(new CSR());
  if (!out_csr_)
    out_csr_ = CSRPtr(new CSR());
  if (!coo_)
    coo_ = COOPtr(new COO());

  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  if ((formats | created) != formats)
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) <<
      ", which is not compatible with available formats: " << CodeToStr(formats);
  // A non-null check on GetAny() is vacuous here: coo_ was just replaced by a
  // placeholder, so GetAny() always returns a pointer. Require that at least
  // one format is actually materialized instead.
  CHECK(created != 0) << "At least one graph structure should exist.";
}

1349
1350
/*!
 * \brief Assemble a UnitGraph from up to three matrices.
 * Each matrix is used only when its has_* flag is set; otherwise an empty
 * placeholder is stored for that format.
 */
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
    int num_vtypes,
    const aten::CSRMatrix &in_csr,
    const aten::CSRMatrix &out_csr,
    const aten::COOMatrix &coo,
    bool has_in_csr,
    bool has_out_csr,
    bool has_coo,
    dgl_format_code_t formats) {
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);

  CSRPtr in_part = has_in_csr ? CSRPtr(new CSR(mg, in_csr)) : CSRPtr(new CSR());
  CSRPtr out_part = has_out_csr ? CSRPtr(new CSR(mg, out_csr)) : CSRPtr(new CSR());
  COOPtr coo_part = has_coo ? COOPtr(new COO(mg, coo)) : COOPtr(new COO());

  return HeteroGraphPtr(new UnitGraph(mg, in_part, out_part, coo_part, formats));
}

1380
1381
/*!
 * \brief Return the in-CSR (CSC), building it if it is not materialized.
 *
 * It is built from the COO when available, otherwise by transposing the
 * out-CSR.
 *
 * \param inplace If true, the built matrix is cached in this graph. This is
 *        fatal when CSC is not an allowed format or when the graph is pinned.
 *        If false, a detached copy is returned.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace && !(formats_ & CSC_CODE))
    LOG(FATAL) << "The graph have restricted sparse format " <<
      CodeToStr(formats_) << ", cannot create CSC matrix.";

  if (in_csr_->defined())
    return in_csr_;

  // inplace new formats materialization is not allowed for pinned graphs
  if (inplace && IsPinned())
    LOG(FATAL) << "Cannot create new formats for pinned graphs, " <<
      "please create the CSC format before pinning.";

  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
  const aten::CSRMatrix built = [this]() {
    if (coo_->defined())
      return aten::COOToCSR(aten::COOTranspose(coo_->adj()));
    CHECK(out_csr_->defined()) << "None of CSR, COO exist";
    return aten::CSRTranspose(out_csr_->adj());
  }();

  if (!inplace)
    return std::make_shared<CSR>(meta_graph(), built);
  *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), built);
  return in_csr_;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
1415
1416
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1417
    if (!(formats_ & CSR_CODE))
1418
1419
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create CSR matrix.";
1420
  CSRPtr ret = out_csr_;
1421
1422
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1423
  if (!out_csr_->defined()) {
1424
1425
1426
1427
    // inplace new formats materialization is not allowed for pinned graphs
    if (inplace && IsPinned())
      LOG(FATAL) << "Cannot create new formats for pinned graphs, " <<
        "please create the CSR format before pinning.";
1428
1429
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1430

1431
      if (inplace)
1432
1433
1434
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1435
    } else {
1436
1437
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1438

1439
      if (inplace)
1440
1441
1442
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1443
1444
    }
  }
1445
  return ret;
1446
1447
1448
}

/* !\brief Return coo. If not exist, create from csr.*/
1449
1450
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1451
    if (!(formats_ & COO_CODE))
1452
1453
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create COO matrix.";
1454
  COOPtr ret = coo_;
1455
  if (!coo_->defined()) {
1456
1457
1458
1459
    // inplace new formats materialization is not allowed for pinned graphs
    if (inplace && IsPinned())
      LOG(FATAL) << "Cannot create new formats for pinned graphs, " <<
        "please create the COO format before pinning.";
1460
    if (in_csr_->defined()) {
1461
      const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1462

1463
      if (inplace)
1464
1465
1466
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1467
    } else {
1468
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1469
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1470

1471
      if (inplace)
1472
1473
1474
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1475
1476
    }
  }
1477
  return ret;
1478
1479
}

1480
/*! \brief CSC adjacency matrix (the in-CSR); `etype` is unused. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr in_csr = GetInCSR();
  return in_csr->adj();
}

1484
/*! \brief CSR adjacency matrix (the out-CSR); `etype` is unused. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr out_csr = GetOutCSR();
  return out_csr->adj();
}

1488
/*! \brief COO adjacency matrix; `etype` is unused. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo = GetCOO();
  return coo->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1492
/*!
 * \brief First materialized format in the order CSC, CSR, COO.
 * COO is returned as the fallback even when it is not materialized.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined())
    return in_csr_;
  if (out_csr_->defined())
    return out_csr_;
  return coo_;
}

1502
/*! \brief Bitmask of the formats that are currently materialized. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t materialized = 0;
  materialized |= in_csr_->defined() ? CSC_CODE : 0;
  materialized |= out_csr_->defined() ? CSR_CODE : 0;
  materialized |= coo_->defined() ? COO_CODE : 0;
  return materialized;
}

1513
1514
1515
1516
/*! \brief Bitmask of the formats this graph is allowed to hold. */
dgl_format_code_t UnitGraph::GetAllowedFormats() const { return formats_; }

1517
1518
/*!
 * \brief Return the graph stored in `format`, materializing it via the default Get* call.
 * Any value other than kCSR/kCSC falls back to COO.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR)
    return GetOutCSR();
  if (format == SparseFormat::kCSC)
    return GetInCSR();
  return GetCOO();
}

1528
/*!
 * \brief New graph restricted to the allowed formats `formats`.
 *
 * For ALL_CODE, every materialized format is copied. Otherwise the new graph
 * is built in one format, chosen by priority COO > CSR > CSC among those
 * requested. This graph is not modified, because conversions use
 * inplace=false.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  if (formats == ALL_CODE) {
    // TODO(xiangsx) Make it as graph storage.Clone()
    CSRPtr in_copy = in_csr_->defined() ? CSRPtr(new CSR(*in_csr_)) : nullptr;
    CSRPtr out_copy = out_csr_->defined() ? CSRPtr(new CSR(*out_csr_)) : nullptr;
    COOPtr coo_copy = coo_->defined() ? COOPtr(new COO(*coo_)) : nullptr;
    return HeteroGraphPtr(new UnitGraph(meta_graph_, in_copy, out_copy, coo_copy, formats));
  }

  const int64_t num_vtypes = NumVertexTypes();
  if (formats & COO_CODE)
    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
  if (formats & CSR_CODE)
    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}

/*!
 * \brief Pick the format to answer a query with.
 *
 * Order of preference:
 *  1. a preferred, allowed format that already exists;
 *  2. a preferred, allowed format (it will have to be built);
 *  3. any format that already exists.
 */
SparseFormat UnitGraph::SelectFormat(dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t common = preferred_formats & formats_;
  const dgl_format_code_t created = GetCreatedFormats();
  const dgl_format_code_t ready = common & created;
  if (ready)
    return DecodeFormat(ready);

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on COO have
  // not been implemented yet.
  // if (coo_->defined() && coo_->IsHypersparse())  // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  return common ? DecodeFormat(common) : DecodeFormat(created);
}

1566
1567
1568
1569
/*!
 * \brief Convert a homogeneous UnitGraph into a dgl::ImmutableGraph.
 * Only formats that are already materialized are carried over.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr in_csr_ptr = nullptr;
  dgl::CSRPtr out_csr_ptr = nullptr;
  dgl::COOPtr coo_ptr = nullptr;

  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    // The dgl::COO constructor used here takes no edge-id array. When the COO
    // carries explicit ids, row/col are first scattered by coo.data
    // (presumably into edge-id order).
    const bool has_eids = COOHasData(coo);
    const IdArray src = has_eids ? Scatter(coo.row, coo.data) : coo.row;
    const IdArray dst = has_eids ? Scatter(coo.col, coo.data) : coo.col;
    coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), src, dst));
  }
  return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
}

1591
1592
/*!
 * \brief Line graph of this (homogeneous) graph, returned as a COO-backed UnitGraph.
 * \param backtracking Forwarded to aten::COOLineGraph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const SparseFormat fmt = SelectFormat(ALL_CODE);

  if (fmt == SparseFormat::kCOO)
    return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));

  if (fmt == SparseFormat::kCSR) {
    const aten::COOMatrix edges = aten::CSRToCOO(GetCSRMatrix(0), true);
    return CreateFromCOO(1, aten::COOLineGraph(edges, backtracking));
  }

  if (fmt == SparseFormat::kCSC) {
    // Turn the CSC back into a CSR before extracting (src, dst) pairs.
    const aten::CSRMatrix csr = aten::CSRTranspose(GetCSCMatrix(0));
    const aten::COOMatrix edges = aten::CSRToCOO(csr, true);
    return CreateFromCOO(1, aten::COOLineGraph(edges, backtracking));
  }

  LOG(FATAL) << "None of CSC, CSR, COO exist";
  return nullptr;
}

1616
1617
1618
1619
1620
1621
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Deserialize a unit graph in the layout produced by UnitGraph::Save.
 *
 * Stream layout:
 *  - the magic number;
 *  - the stored sparse format (int64);
 *  - the formats code (int64). When any bit above the low 32 is set it is the
 *    new encoding and the low 32 bits are the format code. Otherwise it is a
 *    legacy value 0-3;
 *  - one sparse matrix in the stored format.
 * Formats that were not stored are filled with empty placeholders.
 *
 * \param fs Input stream positioned at a serialized UnitGraph.
 * \return true on success. Malformed or truncated input fails through
 *         CHECK / LOG(FATAL) rather than by returning false.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code, formats_code;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";
  auto save_format = static_cast<SparseFormat>(save_format_code);
  if (formats_code >> 32) {
    // New encoding: Save() ORs in bit 32, and the low 32 bits are the code.
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
      case 0:
        formats_ = ALL_CODE;
        break;
      case 1:
        formats_ = COO_CODE;
        break;
      case 2:
        formats_ = CSR_CODE;
        break;
      case 3:
        formats_ = CSC_CODE;
        break;
      default:
        LOG(FATAL) << "Load graph failed, formats code " << formats_code
                   << " not recognized.";
        break;
    }
  }

  // Exactly one matrix was stored. A failed read means truncated or corrupt
  // data, so fail loudly like the header reads above instead of continuing
  // with a half-initialized graph.
  switch (save_format) {
    case SparseFormat::kCOO:
      CHECK(fs->Read(&coo_)) << "Invalid UnitGraph Data: failed to read COO";
      break;
    case SparseFormat::kCSR:
      CHECK(fs->Read(&out_csr_)) << "Invalid UnitGraph Data: failed to read CSR";
      break;
    case SparseFormat::kCSC:
      CHECK(fs->Read(&in_csr_)) << "Invalid UnitGraph Data: failed to read CSC";
      break;
    default:
      LOG(FATAL) << "unsupported format code";
      break;
  }

  // Formats that were not stored get empty placeholders, so none of the
  // three pointers is left null.
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  meta_graph_ = GetAny()->meta_graph();

  return true;
}

1680

1681
1682
/*!
 * \brief Serialize this unit graph to \p fs.
 *
 * Writes, in order:
 *  - the magic number;
 *  - the chosen storage format as int64;
 *  - formats_ with bit 32 set, as int64;
 *  - the single sparse matrix in the chosen format.
 *
 * UnitGraph::meta_graph_ is not written because it is included in the
 * underlying sparse matrix.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);

  const auto stored_fmt = SelectFormat(ALL_CODE);
  const int64_t stored_fmt_word = static_cast<int64_t>(stored_fmt);
  // Bit 32 flags the new formats encoding; Load() uses it to tell current
  // files apart from legacy ones.
  const int64_t formats_word = static_cast<int64_t>(formats_ | 0x100000000);
  fs->Write(stored_fmt_word);
  fs->Write(formats_word);

  if (stored_fmt == SparseFormat::kCOO) {
    fs->Write(GetCOO());
  } else if (stored_fmt == SparseFormat::kCSR) {
    fs->Write(GetOutCSR());
  } else if (stored_fmt == SparseFormat::kCSC) {
    fs->Write(GetInCSR());
  } else {
    LOG(FATAL) << "unsupported format code";
  }
}

1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
/*!
 * \brief Return a new unit graph with every edge direction flipped.
 *
 * The in-CSR and out-CSR pointers are exchanged, so the matrices are shared
 * with this graph rather than copied. The COO, if defined, is transposed into
 * a new matrix. Otherwise an empty pointer is passed.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  // Flipping edges swaps the roles of the two CSR views.
  const CSRPtr reversed_in = out_csr_;
  const CSRPtr reversed_out = in_csr_;
  const COOPtr reversed_coo = coo_->defined()
      ? COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())))
      : COOPtr();
  return UnitGraphPtr(
      new UnitGraph(meta_graph(), reversed_in, reversed_out, reversed_coo));
}

1714
1715
1716
1717
1718
1719
1720
std::tuple<UnitGraphPtr, IdArray, IdArray>
UnitGraph::ToSimple() const {
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1721
  auto avail_fmt = SelectFormat(ALL_CODE);
1722
1723
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1724
      auto ret = aten::COOToSimple(GetCOO()->adj());
1725
1726
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1727
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1728
1729
1730
      break;
    }
    case SparseFormat::kCSR: {
1731
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1732
1733
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1734
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1735
1736
1737
      break;
    }
    case SparseFormat::kCSC: {
1738
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1739
1740
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1741
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

  return std::make_tuple(UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
                         count,
                         edge_map);
}

1754
}  // namespace dgl