unit_graph.cc 56 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <dgl/array.h>
7
#include <dgl/base_heterograph.h>
8
9
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
10
11

#include "../c_api_common.h"
12
#include "./unit_graph.h"
13
14

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
/*!
 * \brief Create the metagraph used by a unit graph with one node type.
 * \return A one-vertex graph whose only edge is the self-loop 0->0.
 */
inline GraphPtr CreateUnitGraphMetaGraph1() {
  const std::vector<int64_t> src_ids{0};
  const std::vector<int64_t> dst_ids{0};
  const IdArray src = aten::VecToIdArray(src_ids);
  const IdArray dst = aten::VecToIdArray(dst_ids);
  return ImmutableGraph::CreateFromCOO(1, src, dst);
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
/*!
 * \brief Create the metagraph used by a unit graph with two node types.
 * \return A two-vertex graph whose only edge is 0->1.
 */
inline GraphPtr CreateUnitGraphMetaGraph2() {
  const std::vector<int64_t> src_ids{0};
  const std::vector<int64_t> dst_ids{1};
  const IdArray src = aten::VecToIdArray(src_ids);
  const IdArray dst = aten::VecToIdArray(dst_ids);
  return ImmutableGraph::CreateFromCOO(2, src, dst);
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

/*!
 * \brief Return the shared metagraph for the given number of vertex types.
 *
 * Both metagraphs are function-local statics, so each is built once and the
 * same pointer is handed out on every call.
 *
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \return The cached metagraph; any other value is reported via LOG(FATAL).
 */
inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr single_type_mg = CreateUnitGraphMetaGraph1();
  static GraphPtr two_type_mg = CreateUnitGraphMetaGraph2();
  switch (num_vtypes) {
    case 1:
      return single_type_mg;
    case 2:
      return two_type_mg;
    default:
      LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
      break;
  }
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
62
/*!
 * \brief COO backend of UnitGraph.
 *
 * Stores the single relation as an aten::COOMatrix whose data array is kept
 * null. Mutation and format-management entry points are not supported and
 * abort with LOG(FATAL).
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from source/destination id arrays.
   * \param metagraph Metagraph forwarded to BaseHeteroGraph.
   * \param num_src Stored as the number of rows of the adjacency matrix.
   * \param num_dst Stored as the number of columns of the adjacency matrix.
   * \param src Row (source) ids; must be a valid id array.
   * \param dst Column (destination) ids; must have the same length as \a src.
   * \param row_sorted Forwarded to the COOMatrix.
   * \param col_sorted Forwarded to the COOMatrix.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
        aten::NullArray(),
        row_sorted, col_sorted};
  }

  /*!
   * \brief Construct from an existing COO matrix.
   * \param metagraph Metagraph forwarded to BaseHeteroGraph.
   * \param coo Matrix to copy; it must not carry a data array.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /*! \brief Construct an undefined COO graph (see defined()). */
  COO() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /*! \brief Whether this object holds a real graph rather than the undefined marker. */
  bool defined() const {
    return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
  }

  /*! \brief Source vertex type; always 0. */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Destination vertex type: 0 when there is one vertex type, otherwise 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief Edge type; always 0. */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Data type of the row index array. */
  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  /*! \brief Device context of the row index array. */
  DLContext Context() const override {
    return adj_.row->ctx;
  }

  bool IsPinned() const override {
    return adj_.is_pinned;
  }

  /*! \brief Bit width of the row index array's data type. */
  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a COO graph whose index arrays use the given bit width.
   * \param bits Requested bit width.
   * \return A copy of this graph if the width already matches, otherwise a new
   *         graph built from the converted row/col arrays.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits),
        aten::AsNumBits(adj_.col, bits));
    return ret;
  }

  /*!
   * \brief Return a COO graph on the given device context.
   * \param ctx Target context.
   * \return A copy of this graph if it is already on \a ctx, otherwise a graph
   *         wrapping the copied adjacency matrix.
   */
  COO CopyTo(const DLContext &ctx) const {
    if (Context() == ctx)
      return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }

  /*! \brief Pin the adj_: COOMatrix of the COO graph. */
  void PinMemory_() {
    adj_.PinMemory_();
  }

  /*! \brief Unpin the adj_: COOMatrix of the COO graph. */
  void UnpinMemory_() {
    adj_.UnpinMemory_();
  }

  bool IsMultigraph() const override {
    return aten::COOHasDuplicate(adj_);
  }

  bool IsReadonly() const override {
    return true;
  }

  /*!
   * \brief Number of vertices of the given type.
   * \return num_rows for the source type, num_cols for the destination type;
   *         any other type is reported via LOG(FATAL).
   */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /*! \brief Number of edges, i.e. the length of the row index array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  /*! \brief Sources of edges entering \a dst, read from a row of the transposed matrix. */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  /*! \brief Destinations of edges leaving \a src, read from a row of the matrix. */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::COOGetData(adj_, src, dst);
  }

  /*! \brief Look up the (src, dst) endpoints stored at position \a eid. */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  /*! \brief Look up endpoints for each edge id by indexing the row/col arrays. */
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    // Indexing row/col by edge id is only valid while no explicit EID array exists.
    BUG_IF_FAIL(aten::IsNullArray(adj_.data)) <<
      "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Map the sliced row positions back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Map the sliced row positions back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /*!
   * \brief Return all edges as (row, col, 0..num_edges-1).
   * \param order Must be empty or "eid".
   */
  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // The four iterator accessors below are unsupported on COO. They abort like
  // every other unsupported operation in this class instead of silently
  // handing back an empty range.
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*!
   * \brief Return the adjacency as a single stacked array.
   * \param transpose If true, stack (col, row) instead of (row, col).
   * \param fmt Must be "coo".
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /*!
   * \brief Induce a subgraph on the given source/destination vertex ids.
   * \param vids One id array per vertex type.
   * \return Subgraph whose induced_edges entry is the data array of the
   *         sliced matrix.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    // NOTE(review): the returned context is not needed here; the call is kept
    // in case GetContextOf validates its inputs -- TODO confirm and drop.
    static_cast<void>(aten::GetContextOf(vids));
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Induce a subgraph on the given edge ids.
   * \param eids Exactly one id array (there is a single edge type).
   * \param preserve_nodes If false, endpoints are relabeled and the induced
   *        vertex arrays come from Relabel_; if true, the vertex counts of this
   *        graph are kept and the induced vertex arrays are null arrays.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      subg.induced_vertices.emplace_back(
          aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
      subg.induced_vertices.emplace_back(
          aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  /*! \brief Return a copy of the internal adjacency matrix. */
  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /*! \brief Read the metagraph and then the adjacency matrix from the stream. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Write the metagraph and then the adjacency matrix to the stream. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
465
class UnitGraph::CSR : public BaseHeteroGraph {
466
 public:
Minjie Wang's avatar
Minjie Wang committed
467
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
468
      IdArray indptr, IdArray indices, IdArray edge_ids)
Minjie Wang's avatar
Minjie Wang committed
469
    : BaseHeteroGraph(metagraph) {
470
471
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
472
473
474
    if (aten::IsValidIdArray(edge_ids))
      CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids))
        << "edge id arrays should have the same length as indices if not empty";
475
476
    CHECK_EQ(num_src, indptr->shape[0] - 1)
      << "number of nodes do not match the length of indptr minus 1.";
477

478
479
480
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

481
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
Da Zheng's avatar
Da Zheng committed
482
483
    : BaseHeteroGraph(metagraph), adj_(csr) {
  }
484

485
486
487
488
489
490
491
492
493
494
495
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

  bool defined() const {
    return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
  }

Minjie Wang's avatar
Minjie Wang committed
496
497
  inline dgl_type_t SrcType() const {
    return 0;
498
  }
Minjie Wang's avatar
Minjie Wang committed
499
500
501
502
503
504
505

  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
506
507
508
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
509
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
510
511
512
513
514
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
515
    LOG(FATAL) << "UnitGraph graph is not mutable.";
516
517
518
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
519
    LOG(FATAL) << "UnitGraph graph is not mutable.";
520
521
522
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
523
    LOG(FATAL) << "UnitGraph graph is not mutable.";
524
525
526
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
527
    LOG(FATAL) << "UnitGraph graph is not mutable.";
528
529
  }

530
531
532
533
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

534
535
536
537
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

538
  bool IsPinned() const override {
539
    return adj_.is_pinned;
540
541
  }

542
543
544
545
  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

546
547
548
549
550
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
551
          meta_graph_,
552
553
554
555
556
557
558
559
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

560
  CSR CopyTo(const DLContext &ctx) const {
561
562
563
    if (Context() == ctx) {
      return *this;
    } else {
564
      return CSR(meta_graph_, adj_.CopyTo(ctx));
565
566
567
    }
  }

568
569
570
571
572
573
574
575
576
577
  /*! \brief Pin the adj_: CSRMatrix of the CSR graph. */
  void PinMemory_() {
    adj_.PinMemory_();
  }

  /*! \brief Unpin the adj_: CSRMatrix of the CSR graph. */
  void UnpinMemory_() {
    adj_.UnpinMemory_();
  }

578
  bool IsMultigraph() const override {
579
    return aten::CSRHasDuplicate(adj_);
580
581
582
583
584
585
586
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
587
    if (vtype == SrcType()) {
588
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
589
    } else if (vtype == DstType()) {
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
611
612
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
613
614
615
616
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
617
618
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
619
620
621
622
623
624
625
626
627
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
628
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
629
630
631
632
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
633
634
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
635
    return aten::CSRGetAllData(adj_, src, dst);
636
637
  }

638
  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
639
640
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
641
642
643
644
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

645
646
647
648
  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::CSRGetData(adj_, src, dst);
  }

649
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
650
    LOG(FATAL) << "Not enabled for CSR graph.";
651
652
653
654
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
655
    LOG(FATAL) << "Not enabled for CSR graph.";
656
657
658
659
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
660
    LOG(FATAL) << "Not enabled for CSR graph.";
661
662
663
664
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
665
    LOG(FATAL) << "Not enabled for CSR graph.";
666
667
668
669
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
670
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
671
672
673
674
675
676
677
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
678
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
679
680
681
682
683
684
685
686
687
688
689
690
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
691
692
693
694
695
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
696
697
698
699
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
700
    LOG(FATAL) << "Not enabled for CSR graph.";
701
702
703
704
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
705
    LOG(FATAL) << "Not enabled for CSR graph.";
706
707
708
709
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
710
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
711
712
713
714
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
715
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
716
717
718
719
720
721
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
722
    CHECK_EQ(NumBits(), 64);
723
724
725
726
727
728
729
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

730
731
732
733
734
735
736
737
738
739
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

740
741
742
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
743
    CHECK_EQ(NumBits(), 64);
744
745
746
747
748
749
750
751
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
752
    LOG(FATAL) << "Not enabled for CSR graph.";
753
754
755
756
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
757
    LOG(FATAL) << "Not enabled for CSR graph.";
758
759
760
761
762
763
764
765
766
    return {};
  }

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

767
768
769
770
771
772
773
774
775
776
777
778
779
780
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return adj_;
  }

781
  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
782
    LOG(FATAL) << "Not enabled for CSR graph";
783
    return SparseFormat::kCSR;
784
785
  }

786
787
788
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
789
790
  }

791
  dgl_format_code_t GetCreatedFormats() const override {
792
793
794
795
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

796
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
Minjie Wang's avatar
Minjie Wang committed
797
798
799
800
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
801
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
802
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
803
804
    DLContext ctx = aten::GetContextOf(vids);
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
Minjie Wang's avatar
Minjie Wang committed
805
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
806
807
808
809
810
811
812
813
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
814
    LOG(FATAL) << "Not enabled for CSR graph.";
815
816
817
    return {};
  }

818
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
819
820
821
822
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

823
824
825
826
  aten::CSRMatrix adj() const {
    return adj_;
  }

827
828
829
830
831
832
833
834
835
836
837
838
839
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

840
 private:
841
842
  friend class Serializer;

843
844
845
846
847
848
  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
Minjie Wang's avatar
Minjie Wang committed
849
// unit graph implementation
850
851
852
//
//////////////////////////////////////////////////////////

853
854
855
856
/*! \brief Data type reported by whichever sparse structure GetAny() yields. */
DLDataType UnitGraph::DataType() const {
  const HeteroGraphPtr any = GetAny();
  return any->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
857
/*! \brief Device context reported by whichever sparse structure GetAny() yields. */
DLContext UnitGraph::Context() const {
  const HeteroGraphPtr any = GetAny();
  return any->Context();
}

861
862
863
864
bool UnitGraph::IsPinned() const {
  return GetAny()->IsPinned();
}

Minjie Wang's avatar
Minjie Wang committed
865
uint8_t UnitGraph::NumBits() const {
866
867
868
  return GetAny()->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
869
bool UnitGraph::IsMultigraph() const {
870
  const SparseFormat fmt = SelectFormat(CSC_CODE);
871
872
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
873
874
}

Minjie Wang's avatar
Minjie Wang committed
875
/*!
 * \brief Number of vertices of the given type.
 * \param vtype Vertex type id.
 * \return The count reported by the selected sparse structure.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->NumVertices(vtype);

  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  // The CSC structure is queried with source and destination types swapped.
  const dgl_type_t flipped = (vtype == SrcType()) ? DstType() : SrcType();
  return graph->NumVertices(flipped);
}

Minjie Wang's avatar
Minjie Wang committed
885
/*! \brief Number of edges, as reported by whichever sparse structure GetAny() yields. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const HeteroGraphPtr any = GetAny();
  return any->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
889
/*!
 * \brief Whether vertex \a vid of type \a vtype exists.
 * \param vtype Vertex type id.
 * \param vid Vertex id.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->HasVertex(vtype, vid);

  // The CSC structure is queried with source and destination types swapped.
  const dgl_type_t flipped = (vtype == SrcType()) ? DstType() : SrcType();
  return graph->HasVertex(flipped, vid);
}

Minjie Wang's avatar
Minjie Wang committed
897
/*!
 * \brief Element-wise existence test for an array of vertex ids.
 *
 * An id is reported as present when it is less than NumVertices(vtype).
 * \param vtype Vertex type id.
 * \param vids Vertex ids to test; must be a valid id array.
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

Minjie Wang's avatar
Minjie Wang committed
902
/*!
 * \brief Whether an edge src->dst exists.
 *
 * Prefers the CSC structure; when it is used, the endpoints are passed swapped.
 */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->HasEdgeBetween(etype, src, dst);
  return graph->HasEdgeBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
911
/*!
 * \brief Batched version of HasEdgeBetween.
 *
 * Prefers the CSC structure; when it is used, the endpoint arrays are passed swapped.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->HasEdgesBetween(etype, src, dst);
  return graph->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
921
/*!
 * \brief Predecessors of vertex \a dst.
 *
 * Prefers the CSC structure, on which the query is issued as Successors().
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->Predecessors(etype, dst);
  return graph->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
930
/*!
 * \brief Successors of vertex \a src, answered by the structure chosen with a CSR preference.
 *
 * NOTE(review): unlike Predecessors(), there is no special handling for the case
 * where SelectFormat() falls back to kCSC -- confirm that cannot happen for callers.
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const auto graph = GetFormat(SelectFormat(CSR_CODE));
  return graph->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
936
/*!
 * \brief Ids of the edges between \a src and \a dst.
 *
 * Prefers the CSR structure; if the CSC structure ends up selected, the
 * endpoints are passed swapped.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->EdgeId(etype, src, dst);
  return graph->EdgeId(etype, dst, src);
}

945
/*!
 * \brief All edges between the given endpoint pairs.
 *
 * Prefers the CSR structure. If the CSC structure ends up selected, the query
 * is issued with endpoints swapped and the src/dst arrays of the result are
 * swapped back before returning.
 */
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->EdgeIdsAll(etype, src, dst);

  const EdgeArray reversed = graph->EdgeIdsAll(etype, dst, src);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

/*!
 * \brief One edge id per given endpoint pair.
 *
 * Prefers the CSR structure; if the CSC structure ends up selected, the
 * endpoint arrays are passed swapped.
 */
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->EdgeIdsOne(etype, src, dst);
  return graph->EdgeIdsOne(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
966
/*! \brief Endpoints of edge \a eid, answered by the structure chosen with a COO preference. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  const auto graph = GetFormat(SelectFormat(COO_CODE));
  return graph->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
972
/*! \brief Endpoints of the edges \a eids, answered by the structure chosen with a COO preference. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  const auto graph = GetFormat(SelectFormat(COO_CODE));
  return graph->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
978
/*!
 * \brief In-edges of vertex \a vid.
 *
 * Prefers the CSC structure, on which the query is issued as OutEdges() and
 * the src/dst arrays of the result are swapped before returning.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->InEdges(etype, vid);

  const EdgeArray reversed = graph->OutEdges(etype, vid);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
989
/*!
 * \brief In-edges of the vertices \a vids.
 *
 * Prefers the CSC structure, on which the query is issued as OutEdges() and
 * the src/dst arrays of the result are swapped before returning.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->InEdges(etype, vids);

  const EdgeArray reversed = graph->OutEdges(etype, vids);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
1000
/*! \brief Out-edges of vertex \a vid, answered by the structure chosen with a CSR preference. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const auto graph = GetFormat(SelectFormat(CSR_CODE));
  return graph->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1006
/*! \brief Out-edges of the vertices \a vids, answered by the structure chosen with a CSR preference. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const auto graph = GetFormat(SelectFormat(CSR_CODE));
  return graph->OutEdges(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1012
/*!
 * \brief Return all edges in the requested order.
 * \param etype Edge type id.
 * \param order "eid" (COO preferred), "srcdst" (CSR preferred), or empty for
 *        an arbitrary order (any format). Any other value is a fatal error.
 * \return The edges; src/dst are swapped back if the CSC structure was used.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  // Map the requested order to the preferred storage format(s).
  dgl_format_code_t preferred;
  if (order == "eid") {
    preferred = COO_CODE;
  } else if (order.empty()) {
    // arbitrary order
    preferred = ALL_CODE;
  } else if (order == "srcdst") {
    preferred = CSR_CODE;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return EdgeArray{};
  }
  const SparseFormat fmt = SelectFormat(preferred);

  const EdgeArray found = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC)
    return found;
  return EdgeArray{found.dst, found.src, found.id};
}

Minjie Wang's avatar
Minjie Wang committed
1033
/*!
 * \brief In-degree of vertex \a vid.
 *
 * Requires that the selected format is CSC or COO. On CSC the query is issued
 * as OutDegree().
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (fmt == SparseFormat::kCSC)
    return ptr->OutDegree(etype, vid);
  return ptr->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1043
/*!
 * \brief In-degrees of the vertices \a vids.
 *
 * Requires that the selected format is CSC or COO. On CSC the query is issued
 * as OutDegrees().
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (fmt == SparseFormat::kCSC)
    return ptr->OutDegrees(etype, vids);
  return ptr->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1053
/*!
 * \brief Out-degree of vertex \a vid.
 *
 * Requires that the selected format is CSR or COO.
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  const uint64_t degree = ptr->OutDegree(etype, vid);
  return degree;
}

Minjie Wang's avatar
Minjie Wang committed
1062
/*!
 * \brief Out-degrees of the vertices \a vids.
 *
 * Requires that the selected format is CSR or COO.
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  DegreeArray degrees = ptr->OutDegrees(etype, vids);
  return degrees;
}

Minjie Wang's avatar
Minjie Wang committed
1071
/*! \brief Successor id range of vertex \a vid, answered by the structure chosen with a CSR preference. */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  const auto graph = GetFormat(SelectFormat(CSR_CODE));
  return graph->SuccVec(etype, vid);
}

1077
/*!
 * \brief 32-bit variant of SuccVec.
 *
 * The selected structure must be a CSR object; a failed downcast is a fatal error.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const HeteroGraphPtr selected = GetFormat(SelectFormat(CSR_CODE));
  const auto ptr = std::dynamic_pointer_cast<CSR>(selected);
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1084
/*! \brief Out-edge id range of vertex \a vid, answered by the structure chosen with a CSR preference. */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const auto graph = GetFormat(SelectFormat(CSR_CODE));
  return graph->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1090
/*!
 * \brief Predecessor id range of vertex \a vid.
 *
 * Prefers the CSC structure, on which the query is issued as SuccVec().
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->PredVec(etype, vid);
  return graph->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1099
/*!
 * \brief In-edge id range of vertex \a vid.
 *
 * Prefers the CSC structure, on which the query is issued as OutEdgeVec().
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return graph->InEdgeVec(etype, vid);
  return graph->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1108
std::vector<IdArray> UnitGraph::GetAdj(
1109
1110
1111
1112
1113
1114
1115
1116
1117
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
1118
    return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
1119
1120
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
1121
    return GetCOO()->GetAdj(etype, transpose, fmt);
1122
1123
1124
1125
1126
1127
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
1128
/*!
 * \brief Vertex-induced subgraph.
 *
 * Delegates to the structure chosen with a CSR preference and wraps the
 * resulting sparse structure in a new UnitGraph that holds only that format.
 * \param vids Per-vertex-type arrays of vertex ids to keep.
 * \return The subgraph together with its induced vertices and edges.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  HeteroSubgraph inner = GetFormat(fmt)->VertexSubgraph(vids);
  HeteroSubgraph result;

  // Only the slot matching the format that was actually used gets filled.
  CSRPtr subcsr = nullptr;
  CSRPtr subcsc = nullptr;
  COOPtr subcoo = nullptr;
  if (fmt == SparseFormat::kCSR) {
    subcsr = std::dynamic_pointer_cast<CSR>(inner.graph);
  } else if (fmt == SparseFormat::kCSC) {
    subcsc = std::dynamic_pointer_cast<CSR>(inner.graph);
  } else if (fmt == SparseFormat::kCOO) {
    subcoo = std::dynamic_pointer_cast<COO>(inner.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  result.induced_vertices = std::move(inner.induced_vertices);
  result.induced_edges = std::move(inner.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1158
/*!
 * \brief Edge-induced subgraph.
 *
 * Delegates to the structure chosen with a COO preference and wraps the
 * resulting sparse structure in a new UnitGraph that holds only that format.
 * \param eids Per-edge-type arrays of edge ids to keep.
 * \param preserve_nodes Forwarded unchanged to the underlying structure.
 * \return The subgraph together with its induced vertices and edges.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  auto inner = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph result;

  // Only the slot matching the format that was actually used gets filled.
  CSRPtr subcsr = nullptr;
  CSRPtr subcsc = nullptr;
  COOPtr subcoo = nullptr;
  if (fmt == SparseFormat::kCSR) {
    subcsr = std::dynamic_pointer_cast<CSR>(inner.graph);
  } else if (fmt == SparseFormat::kCSC) {
    subcsc = std::dynamic_pointer_cast<CSR>(inner.graph);
  } else if (fmt == SparseFormat::kCOO) {
    subcoo = std::dynamic_pointer_cast<COO>(inner.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  result.induced_vertices = std::move(inner.induced_vertices);
  result.induced_edges = std::move(inner.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1188
/*!
 * \brief Create a UnitGraph that initially holds only a COO structure.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param num_src Number of source vertices.
 * \param num_dst Number of destination vertices; must equal num_src when num_vtypes is 1.
 * \param row Row (source) ids of the edges.
 * \param col Column (destination) ids of the edges.
 * \param row_sorted Forwarded to the COO constructor.
 * \param col_sorted Forwarded to the COO constructor.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col,
    bool row_sorted, bool col_sorted,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr storage(
      new COO(metagraph, num_src, num_dst, row, col, row_sorted, col_sorted));
  UnitGraph* const graph = new UnitGraph(metagraph, nullptr, nullptr, storage, formats);
  return HeteroGraphPtr(graph);
}

1204
1205
/*!
 * \brief Create a UnitGraph that initially holds only a COO structure built from a matrix.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param mat COO matrix; must be square when num_vtypes is 1.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr storage(new COO(metagraph, mat));
  UnitGraph* const graph = new UnitGraph(metagraph, nullptr, nullptr, storage, formats);
  return HeteroGraphPtr(graph);
}

Minjie Wang's avatar
Minjie Wang committed
1217
1218
/*!
 * \brief Create a UnitGraph that initially holds only an out-CSR structure.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param num_src Number of source vertices.
 * \param num_dst Number of destination vertices; must equal num_src when num_vtypes is 1.
 * \param indptr CSR index pointer array.
 * \param indices CSR indices array.
 * \param edge_ids CSR data array holding edge ids.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr storage(new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  UnitGraph* const graph = new UnitGraph(metagraph, nullptr, storage, nullptr, formats);
  return HeteroGraphPtr(graph);
}

1228
1229
/*!
 * \brief Create a UnitGraph that initially holds only an out-CSR structure built from a matrix.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param mat CSR matrix; must be square when num_vtypes is 1.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr storage(new CSR(metagraph, mat));
  UnitGraph* const graph = new UnitGraph(metagraph, nullptr, storage, nullptr, formats);
  return HeteroGraphPtr(graph);
}

1239
1240
/*!
 * \brief Create a UnitGraph that initially holds only an in-CSR (CSC) structure.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param num_src Number of source vertices.
 * \param num_dst Number of destination vertices; must equal num_src when num_vtypes is 1.
 * \param indptr Index pointer array.
 * \param indices Indices array.
 * \param edge_ids Data array holding edge ids.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  // The underlying CSR object is built with the dst count first and the src count second.
  const CSRPtr storage(new CSR(metagraph, num_dst, num_src, indptr, indices, edge_ids));
  UnitGraph* const graph = new UnitGraph(metagraph, storage, nullptr, nullptr, formats);
  return HeteroGraphPtr(graph);
}

/*!
 * \brief Create a UnitGraph that initially holds only an in-CSR (CSC) structure built from a matrix.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param mat Matrix used as the in-CSR; must be square when num_vtypes is 1.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr storage(new CSR(metagraph, mat));
  UnitGraph* const graph = new UnitGraph(metagraph, storage, nullptr, nullptr, formats);
  return HeteroGraphPtr(graph);
}

Minjie Wang's avatar
Minjie Wang committed
1261
/*!
 * \brief Return a graph whose ids use the given bit width.
 * \param g Input graph; must be a UnitGraph when a conversion is needed.
 * \param bits Target bit width.
 * \return \a g itself if it already uses \a bits; otherwise a new UnitGraph in
 *         which every defined sparse structure has been converted. The allowed
 *         formats are carried over.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  // Structures that are not defined stay null; the constructor fills in empty ones.
  CSRPtr new_incsr = nullptr;
  if (bg->in_csr_->defined())
    new_incsr = CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)));
  CSRPtr new_outcsr = nullptr;
  if (bg->out_csr_->defined())
    new_outcsr = CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)));
  COOPtr new_coo = nullptr;
  if (bg->coo_->defined())
    new_coo = COOPtr(new COO(bg->coo_->AsNumBits(bits)));

  return HeteroGraphPtr(
      new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}

1278
/*!
 * \brief Return a graph that lives on the given device context.
 * \param g Input graph; must be a UnitGraph when a copy is needed.
 * \param ctx Target device context.
 * \return \a g itself if it is already on \a ctx; otherwise a new UnitGraph in
 *         which every defined sparse structure has been copied. The allowed
 *         formats are carried over.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx) {
  if (ctx == g->Context())
    return g;

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  // Structures that are not defined stay null; the constructor fills in empty ones.
  CSRPtr new_incsr = nullptr;
  if (bg->in_csr_->defined())
    new_incsr = CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)));
  CSRPtr new_outcsr = nullptr;
  if (bg->out_csr_->defined())
    new_outcsr = CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)));
  COOPtr new_coo = nullptr;
  if (bg->coo_->defined())
    new_coo = COOPtr(new COO(bg->coo_->CopyTo(ctx)));

  return HeteroGraphPtr(
      new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}

1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
void UnitGraph::PinMemory_() {
  if (this->in_csr_->defined())
    this->in_csr_->PinMemory_();
  if (this->out_csr_->defined())
    this->out_csr_->PinMemory_();
  if (this->coo_->defined())
    this->coo_->PinMemory_();
}

void UnitGraph::UnpinMemory_() {
  if (this->in_csr_->defined())
    this->in_csr_->UnpinMemory_();
  if (this->out_csr_->defined())
    this->out_csr_->UnpinMemory_();
  if (this->coo_->defined())
    this->coo_->UnpinMemory_();
}

1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
void UnitGraph::InvalidateCSR() {
  this->out_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCSC() {
  this->in_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCOO() {
  this->coo_ = COOPtr(new COO());
}

1328
/*!
 * \brief Construct a UnitGraph from up to three sparse structures.
 *
 * Null structure pointers are replaced with default-constructed placeholders so
 * that the members can always be dereferenced. It is a fatal error if a
 * structure that is already created is not among the allowed \a formats.
 * \param metagraph Metagraph of the unit graph.
 * \param in_csr In-CSR (CSC) structure, or null.
 * \param out_csr Out-CSR structure, or null.
 * \param coo COO structure, or null.
 * \param formats Allowed sparse formats.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     dgl_format_code_t formats)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  if (!in_csr_)
    in_csr_.reset(new CSR());
  if (!out_csr_)
    out_csr_.reset(new CSR());
  if (!coo_)
    coo_.reset(new COO());

  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  const bool compatible = ((formats | created) == formats);
  if (!compatible) {
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) <<
      ", which is not compatible with available formats: " << CodeToStr(formats);
  }
  // NOTE(review): after the null-replacement above GetAny() always returns a
  // non-null pointer, so this check looks like it cannot fail -- confirm intent.
  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1348
1349
/*!
 * \brief Create a UnitGraph from raw matrices and per-format presence flags.
 * \param num_vtypes Number of vertex types (1 or 2).
 * \param in_csr Matrix for the in-CSR; used only if \a has_in_csr.
 * \param out_csr Matrix for the out-CSR; used only if \a has_out_csr.
 * \param coo Matrix for the COO; used only if \a has_coo.
 * \param has_in_csr Whether \a in_csr holds data.
 * \param has_out_csr Whether \a out_csr holds data.
 * \param has_coo Whether \a coo holds data.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
    int num_vtypes,
    const aten::CSRMatrix &in_csr,
    const aten::CSRMatrix &out_csr,
    const aten::COOMatrix &coo,
    bool has_in_csr,
    bool has_out_csr,
    bool has_coo,
    dgl_format_code_t formats) {
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);

  // A missing structure is represented by a default-constructed placeholder.
  const CSRPtr in_csr_ptr =
    has_in_csr ? CSRPtr(new CSR(mg, in_csr)) : CSRPtr(new CSR());
  const CSRPtr out_csr_ptr =
    has_out_csr ? CSRPtr(new CSR(mg, out_csr)) : CSRPtr(new CSR());
  const COOPtr coo_ptr =
    has_coo ? COOPtr(new COO(mg, coo)) : COOPtr(new COO());

  return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}

1379
1380
/*!
 * \brief Return the in-CSR (CSC). If it does not exist, build it from the COO
 *        (preferred) or by transposing the out-CSR.
 * \param inplace If true, the newly built structure is stored in this graph,
 *        which requires CSC to be among the allowed formats (fatal error
 *        otherwise). If false, it is only returned.
 * \return The existing or newly built in-CSR.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace && !(formats_ & CSC_CODE)) {
    LOG(FATAL) << "The graph have restricted sparse format " <<
      CodeToStr(formats_) << ", cannot create CSC matrix.";
  }
  if (in_csr_->defined())
    return in_csr_;

  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
  const bool from_coo = coo_->defined();
  if (!from_coo) {
    CHECK(out_csr_->defined()) << "None of CSR, COO exist";
  }
  const aten::CSRMatrix newadj = from_coo
      ? aten::COOToCSR(aten::COOTranspose(coo_->adj()))
      : aten::CSRTranspose(out_csr_->adj());

  if (!inplace)
    return std::make_shared<CSR>(meta_graph(), newadj);

  // IsPinned() asks GetAny(), and GetAny() returns in_csr_ as soon as it is
  // defined. The pinned state therefore has to be sampled *before* the
  // assignment below; otherwise the freshly built structure would be queried
  // instead of the pre-existing ones (presumably it reports "not pinned").
  const bool pinned = IsPinned();
  *in_csr_ = CSR(meta_graph(), newadj);
  if (pinned)
    in_csr_->PinMemory_();
  return in_csr_;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
1412
1413
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1414
    if (!(formats_ & CSR_CODE))
1415
1416
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create CSR matrix.";
1417
  CSRPtr ret = out_csr_;
1418
1419
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1420
  if (!out_csr_->defined()) {
1421
1422
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1423

1424
      if (inplace)
1425
1426
1427
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1428
    } else {
1429
1430
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1431

1432
      if (inplace)
1433
1434
1435
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1436
    }
1437
1438
    if (inplace && IsPinned())
      out_csr_->PinMemory_();
1439
  }
1440
  return ret;
1441
1442
1443
}

/* !\brief Return coo. If not exist, create from csr.*/
1444
1445
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1446
    if (!(formats_ & COO_CODE))
1447
1448
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create COO matrix.";
1449
  COOPtr ret = coo_;
1450
1451
  if (!coo_->defined()) {
    if (in_csr_->defined()) {
1452
      const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1453

1454
      if (inplace)
1455
1456
1457
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1458
    } else {
1459
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1460
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1461

1462
      if (inplace)
1463
1464
1465
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1466
    }
1467
1468
    if (inplace && IsPinned())
      coo_->PinMemory_();
1469
  }
1470
  return ret;
1471
1472
}

1473
/*! \brief Adjacency matrix of the in-CSR (CSC), creating the structure if needed. \a etype is unused. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr csc = GetInCSR();
  return csc->adj();
}

1477
/*! \brief Adjacency matrix of the out-CSR, creating the structure if needed. \a etype is unused. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr csr = GetOutCSR();
  return csr->adj();
}

1481
/*! \brief Adjacency matrix of the COO, creating the structure if needed. \a etype is unused. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo = GetCOO();
  return coo->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1485
/*!
 * \brief Return one of the sparse structures without creating anything.
 *
 * Order of preference: in-CSR, then out-CSR, then COO. The COO is returned as
 * the last resort even if it is not defined.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined())
    return in_csr_;
  if (out_csr_->defined())
    return out_csr_;
  return coo_;
}

1495
/*! \brief Bitmask of the formats (CSC, CSR, COO) whose structure is currently defined. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t created = 0;
  if (in_csr_->defined()) {
    created |= CSC_CODE;
  }
  if (out_csr_->defined()) {
    created |= CSR_CODE;
  }
  if (coo_->defined()) {
    created |= COO_CODE;
  }
  return created;
}

1506
1507
1508
1509
/*! \brief Bitmask of the formats this graph is allowed to use. */
dgl_format_code_t UnitGraph::GetAllowedFormats() const {
  return this->formats_;
}

1510
1511
/*!
 * \brief Return the sparse structure for the given format, creating it if needed.
 * \param format kCSR selects the out-CSR, kCSC the in-CSR; every other value
 *        selects the COO.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR)
    return GetOutCSR();
  if (format == SparseFormat::kCSC)
    return GetInCSR();
  return GetCOO();
}

1521
/*!
 * \brief Return a new graph restricted to the given formats.
 *
 * With ALL_CODE, every defined structure is copied into the new graph.
 * Otherwise the new graph is created from a single structure, picked in the
 * order COO, CSR, CSC among the requested formats; that structure is obtained
 * without being stored in this graph.
 * \param formats Allowed sparse formats of the returned graph.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  if (formats == ALL_CODE) {
    // TODO(xiangsx) Make it as graph storage.Clone()
    const CSRPtr in_copy =
      in_csr_->defined() ? CSRPtr(new CSR(*in_csr_)) : nullptr;
    const CSRPtr out_copy =
      out_csr_->defined() ? CSRPtr(new CSR(*out_csr_)) : nullptr;
    const COOPtr coo_copy =
      coo_->defined() ? COOPtr(new COO(*coo_)) : nullptr;
    return HeteroGraphPtr(
        new UnitGraph(meta_graph_, in_copy, out_copy, coo_copy, formats));
  }

  const int64_t num_vtypes = NumVertexTypes();
  if (formats & COO_CODE) {
    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
  }
  if (formats & CSR_CODE) {
    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
  }
  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}

/*!
 * \brief Choose the sparse format to use for an operation.
 *
 * Decision order:
 *   1. a preferred format that is allowed and already created;
 *   2. a preferred format that is allowed (to be created on demand);
 *   3. any format that is already created.
 * \param preferred_formats Bitmask of the formats the caller would like to use.
 */
SparseFormat UnitGraph::SelectFormat(dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t allowed_preferred = preferred_formats & formats_;
  const dgl_format_code_t existing = GetCreatedFormats();
  if (allowed_preferred & existing) {
    return DecodeFormat(allowed_preferred & existing);
  }

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on COO have
  // not been implmented yet.
  // if (coo_->defined() && coo_->IsHypersparse())  // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  return allowed_preferred ? DecodeFormat(allowed_preferred) : DecodeFormat(existing);
}

1559
1560
1561
1562
/*!
 * \brief Convert this graph into an ImmutableGraph.
 *
 * Only graphs with a single vertex type are accepted. Each defined sparse
 * structure is converted to its dgl:: counterpart; structures that are not
 * defined are passed on as null.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr in_csr_ptr = nullptr;
  dgl::CSRPtr out_csr_ptr = nullptr;
  dgl::COOPtr coo_ptr = nullptr;

  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    in_csr_ptr.reset(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    out_csr_ptr.reset(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    if (COOHasData(coo)) {
      // The COO carries a data array: pass row/col through Scatter() with it
      // before building the dgl::COO.
      const IdArray new_src = Scatter(coo.row, coo.data);
      const IdArray new_dst = Scatter(coo.col, coo.data);
      coo_ptr.reset(new dgl::COO(NumVertices(0), new_src, new_dst));
    } else {
      coo_ptr.reset(new dgl::COO(NumVertices(0), coo.row, coo.col));
    }
  }
  return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
}

1584
1585
/*!
 * \brief Build the line graph of this graph.
 *
 * The line graph is always computed on a COO matrix; CSR/CSC inputs are
 * converted to COO first. The result is created as a single-vertex-type
 * graph in COO format.
 * \param backtracking Forwarded to aten::COOLineGraph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const auto chosen = SelectFormat(ALL_CODE);
  switch (chosen) {
    case SparseFormat::kCOO: {
      return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));
    }
    case SparseFormat::kCSR: {
      const aten::CSRMatrix out_adj = GetCSRMatrix(0);
      const aten::COOMatrix as_coo = aten::CSRToCOO(out_adj, true);
      const aten::COOMatrix line = aten::COOLineGraph(as_coo, backtracking);
      return CreateFromCOO(1, line);
    }
    case SparseFormat::kCSC: {
      // Transpose the in-CSR to get an out-CSR before converting to COO.
      const aten::CSRMatrix in_adj = GetCSCMatrix(0);
      const aten::CSRMatrix out_adj = aten::CSRTranspose(in_adj);
      const aten::COOMatrix as_coo = aten::CSRToCOO(out_adj, true);
      const aten::COOMatrix line = aten::COOLineGraph(as_coo, backtracking);
      return CreateFromCOO(1, line);
    }
    default:
      LOG(FATAL) << "None of CSC, CSR, COO exist";
      break;
  }
  return nullptr;
}

1609
1610
1611
1612
1613
1614
// Magic number written first by UnitGraph::Save and verified by UnitGraph::Load.
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Restore this graph from a stream written by UnitGraph::Save.
 *
 * Values are read in this order: the UnitGraph magic number, the code of the
 * sparse format that was stored, the allowed-formats code, and finally the
 * stored sparse structure itself. Formats that were not stored are replaced
 * by default-constructed objects, and meta_graph_ is taken from whichever
 * structure GetAny() returns.
 *
 * \param fs Stream to read from.
 * \return Always true; any malformed or truncated input aborts through
 *         CHECK / LOG(FATAL) instead of returning false.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code, formats_code;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";
  auto save_format = static_cast<SparseFormat>(save_format_code);

  if (formats_code >> 32) {
    // Save() sets bit 32 on the value it writes; the low 32 bits then carry
    // the format code itself.
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
      case 0:
        formats_ = ALL_CODE;
        break;
      case 1:
        formats_ = COO_CODE;
        break;
      case 2:
        formats_ = CSR_CODE;
        break;
      case 3:
        formats_ = CSC_CODE;
        break;
      default:
        LOG(FATAL) << "Load graph failed, formats code " << formats_code
                   << " not recognized.";
    }
  }

  // The payload reads are checked like the header reads above, so that a
  // truncated or corrupted stream fails here rather than leaving a partially
  // read structure behind.
  switch (save_format) {
    case SparseFormat::kCOO:
      CHECK(fs->Read(&coo_)) << "Failed to read COO data";
      break;
    case SparseFormat::kCSR:
      CHECK(fs->Read(&out_csr_)) << "Failed to read CSR data";
      break;
    case SparseFormat::kCSC:
      CHECK(fs->Read(&in_csr_)) << "Failed to read CSC data";
      break;
    default:
      LOG(FATAL) << "unsupported format code";
      break;
  }

  // Give every format that was not loaded a default-constructed object so the
  // member pointers are never null.
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  meta_graph_ = GetAny()->meta_graph();

  return true;
}

1673

1674
1675
/*!
 * \brief Serialize this graph to a stream.
 *
 * Written in order: the UnitGraph magic number, the code of the sparse format
 * chosen by SelectFormat(ALL_CODE), formats_ with bit 32 set, and the sparse
 * structure of the chosen format.
 *
 * \param fs Stream to write to.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);

  // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
  // sparse matrix
  const auto stored_fmt = SelectFormat(ALL_CODE);
  const int64_t stored_fmt_code = static_cast<int64_t>(stored_fmt);
  // Load() inspects the bits above bit 31 to tell this encoding apart from
  // the older one, hence the extra bit 32.
  const int64_t tagged_formats = static_cast<int64_t>(formats_ | 0x100000000);
  fs->Write(stored_fmt_code);
  fs->Write(tagged_formats);

  if (stored_fmt == SparseFormat::kCOO) {
    fs->Write(GetCOO());
  } else if (stored_fmt == SparseFormat::kCSR) {
    fs->Write(GetOutCSR());
  } else if (stored_fmt == SparseFormat::kCSC) {
    fs->Write(GetInCSR());
  } else {
    LOG(FATAL) << "unsupported format code";
  }
}

1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
/*!
 * \brief Build the reversed counterpart of this graph.
 *
 * The existing in-CSR and out-CSR pointers swap roles without being copied
 * or recomputed; the COO, when defined, is transposed into a new COO object.
 *
 * \return A new UnitGraph sharing this graph's metagraph.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  COOPtr transposed_coo = nullptr;
  if (coo_->defined()) {
    const auto flipped_adj = aten::COOTranspose(coo_->adj());
    transposed_coo = COOPtr(new COO(coo_->meta_graph(), flipped_adj));
  }

  // Same argument positions as before: this graph's out-CSR goes where the
  // in-CSR was passed, and vice versa.
  return UnitGraphPtr(new UnitGraph(meta_graph(), out_csr_, in_csr_, transposed_coo));
}

1707
1708
1709
1710
1711
1712
1713
std::tuple<UnitGraphPtr, IdArray, IdArray>
UnitGraph::ToSimple() const {
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1714
  auto avail_fmt = SelectFormat(ALL_CODE);
1715
1716
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1717
      auto ret = aten::COOToSimple(GetCOO()->adj());
1718
1719
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1720
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1721
1722
1723
      break;
    }
    case SparseFormat::kCSR: {
1724
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1725
1726
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1727
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1728
1729
1730
      break;
    }
    case SparseFormat::kCSC: {
1731
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1732
1733
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1734
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

  return std::make_tuple(UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
                         count,
                         edge_map);
}

1747
}  // namespace dgl