/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <dgl/array.h>
7
#include <dgl/base_heterograph.h>
8
9
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
10
11

#include "../c_api_common.h"
12
#include "./unit_graph.h"
13
14

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
  // a self-loop edge 0->0
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 0);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
  return g;
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
36
37
38
39
40
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  if (num_vtypes == 1)
    return mg1;
  else if (num_vtypes == 2)
    return mg2;
  else
    LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
/*!
 * \brief Single-relation graph backed by a COO (row/col pair) matrix.
 *
 * Rows are source vertices and columns are destination vertices. Edge ids are
 * implicit: the i-th (row, col) pair is edge i. For that reason the constructors
 * keep the data array of the internal matrix empty.
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from explicit endpoint arrays.
   * \param metagraph Metagraph with one or two vertex types.
   * \param num_src Number of source vertices (matrix rows).
   * \param num_dst Number of destination vertices (matrix columns).
   * \param src Source vertex id of each edge.
   * \param dst Destination vertex id of each edge; must match \a src in length.
   * \param row_sorted Sortedness flag forwarded to the COO matrix.
   * \param col_sorted Sortedness flag forwarded to the COO matrix.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
        NullArray(),
        row_sorted, col_sorted};
  }

  /*!
   * \brief Wrap an existing COO matrix.
   *
   * The matrix must not carry a data (edge id) array.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /*! \brief Construct an undefined placeholder; defined() returns false. */
  COO() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /*! \brief Whether this object holds a real (possibly empty) graph. */
  bool defined() const {
    return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
  }

  /*! \brief Source vertex type id (always 0). */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Destination vertex type id: 0 for one vertex type, else 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief The only edge type id (always 0). */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  DLContext Context() const override {
    return adj_.row->ctx;
  }

  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a copy whose id arrays use \a bits bits per id.
   *
   * Returns *this unchanged when the width already matches.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;
    // Changing the id width does not reorder edges, so the sortedness flags
    // of the original matrix still hold for the converted arrays.
    return COO(meta_graph_,
               adj_.num_rows, adj_.num_cols,
               aten::AsNumBits(adj_.row, bits),
               aten::AsNumBits(adj_.col, bits),
               adj_.row_sorted, adj_.col_sorted);
  }

  /*! \brief Return a copy on device \a ctx (or *this if already there). */
  COO CopyTo(const DLContext& ctx) const {
    if (Context() == ctx)
      return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }

  bool IsMultigraph() const override {
    return aten::COOHasDuplicate(adj_);
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::COOGetData(adj_, src, dst);
  }

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    // Edge ids are used directly as positions into row/col. That is only
    // correct while the matrix carries no explicit data (edge id) array.
    CHECK(aten::IsNullArray(adj_.data)) <<
      "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    IdArray ret_src, ret_eid;
    // In-edges of vid are the entries of row vid of the transposed matrix.
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Rows of the sliced matrix are positions into vids; map them back to the
    // original destination ids.
    auto dst = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, dst, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Rows of the sliced matrix are positions into vids; map them back to the
    // original source ids.
    auto src = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{src, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  /*!
   * \brief Return the adjacency as a single stacked id array ("coo" only).
   *
   * The result is HStack(row, col), or HStack(col, row) when \a transpose is set.
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /*!
   * \brief Induced subgraph on the given source/destination vertex sets.
   *
   * The induced edge ids are the data array returned by COOSliceMatrix. The new
   * graph itself is renumbered with implicit edge ids 0..E-1.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Subgraph made of the given edges.
   * \param eids One id array (this graph has a single edge type).
   * \param preserve_nodes If true, keep all vertices. Otherwise keep only the
   *        endpoints of the selected edges and compact their ids.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    int64_t new_nsrc = 0;
    int64_t new_ndst = 0;
    if (!preserve_nodes) {
      // NOTE(review): Relabel_ appears to rewrite new_src/new_dst in place to
      // compact ids and return the original ids they came from — confirm.
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      new_nsrc = subg.induced_vertices[0]->shape[0];
      new_ndst = subg.induced_vertices[1]->shape[0];
    } else {
      subg.induced_vertices.emplace_back(
          aten::Range(0, NumVertices(SrcType()), NumBits(), Context()));
      subg.induced_vertices.emplace_back(
          aten::Range(0, NumVertices(DstType()), NumBits(), Context()));
      new_nsrc = static_cast<int64_t>(NumVertices(SrcType()));
      new_ndst = static_cast<int64_t>(NumVertices(DstType()));
    }
    subg.graph = std::make_shared<COO>(
        meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  /*! \brief The internal adjacency matrix (returned by value). */
  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   *
   * Concretely: more than 1,000,000 source vertices and more than 8 source
   * vertices per edge.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /*! \brief Deserialize the metagraph and then the adjacency matrix. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Serialize the metagraph (as an ImmutableGraph) and then the adjacency matrix. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
449
class UnitGraph::CSR : public BaseHeteroGraph {
450
 public:
Minjie Wang's avatar
Minjie Wang committed
451
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
452
      IdArray indptr, IdArray indices, IdArray edge_ids)
Minjie Wang's avatar
Minjie Wang committed
453
    : BaseHeteroGraph(metagraph) {
454
455
456
457
458
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
    CHECK(aten::IsValidIdArray(edge_ids));
    CHECK_EQ(indices->shape[0], edge_ids->shape[0])
      << "indices and edge id arrays should have the same length";
459

460
461
462
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

463
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
Da Zheng's avatar
Da Zheng committed
464
465
    : BaseHeteroGraph(metagraph), adj_(csr) {
  }
466

467
468
469
470
471
472
473
474
475
476
477
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

  bool defined() const {
    return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
  }

Minjie Wang's avatar
Minjie Wang committed
478
479
  inline dgl_type_t SrcType() const {
    return 0;
480
  }
Minjie Wang's avatar
Minjie Wang committed
481
482
483
484
485
486
487

  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
488
489
490
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
491
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
492
493
494
495
496
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
497
    LOG(FATAL) << "UnitGraph graph is not mutable.";
498
499
500
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
501
    LOG(FATAL) << "UnitGraph graph is not mutable.";
502
503
504
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
505
    LOG(FATAL) << "UnitGraph graph is not mutable.";
506
507
508
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
509
    LOG(FATAL) << "UnitGraph graph is not mutable.";
510
511
  }

512
513
514
515
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

516
517
518
519
520
521
522
523
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

524
525
526
527
528
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
529
          meta_graph_,
530
531
532
533
534
535
536
537
538
539
540
541
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

  CSR CopyTo(const DLContext& ctx) const {
    if (Context() == ctx) {
      return *this;
    } else {
542
      return CSR(meta_graph_, adj_.CopyTo(ctx));
543
544
545
    }
  }

546
  bool IsMultigraph() const override {
547
    return aten::CSRHasDuplicate(adj_);
548
549
550
551
552
553
554
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
555
    if (vtype == SrcType()) {
556
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
557
    } else if (vtype == DstType()) {
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
579
580
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
581
582
583
584
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
585
586
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
587
588
589
590
591
592
593
594
595
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
596
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
597
598
599
600
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
601
602
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
603
    return aten::CSRGetAllData(adj_, src, dst);
604
605
  }

606
  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
607
608
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
609
610
611
612
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

613
614
615
616
  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::CSRGetData(adj_, src, dst);
  }

617
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
618
    LOG(FATAL) << "Not enabled for CSR graph.";
619
620
621
622
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
623
    LOG(FATAL) << "Not enabled for CSR graph.";
624
625
626
627
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
628
    LOG(FATAL) << "Not enabled for CSR graph.";
629
630
631
632
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
633
    LOG(FATAL) << "Not enabled for CSR graph.";
634
635
636
637
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
638
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
639
640
641
642
643
644
645
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
646
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
647
648
649
650
651
652
653
654
655
656
657
658
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
659
660
661
662
663
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
664
665
666
667
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
668
    LOG(FATAL) << "Not enabled for CSR graph.";
669
670
671
672
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
673
    LOG(FATAL) << "Not enabled for CSR graph.";
674
675
676
677
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
678
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
679
680
681
682
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
683
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
684
685
686
687
688
689
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
690
    CHECK_EQ(NumBits(), 64);
691
692
693
694
695
696
697
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

698
699
700
701
702
703
704
705
706
707
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

708
709
710
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
711
    CHECK_EQ(NumBits(), 64);
712
713
714
715
716
717
718
719
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
720
    LOG(FATAL) << "Not enabled for CSR graph.";
721
722
723
724
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
725
    LOG(FATAL) << "Not enabled for CSR graph.";
726
727
728
729
730
731
732
733
734
    return {};
  }

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

735
736
737
738
739
740
741
742
743
744
745
746
747
748
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return adj_;
  }

749
  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
750
    LOG(FATAL) << "Not enabled for CSR graph";
751
    return SparseFormat::kCSR;
752
753
  }

754
755
756
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
757
758
  }

759
  dgl_format_code_t GetCreatedFormats() const override {
760
761
762
763
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

764
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
Minjie Wang's avatar
Minjie Wang committed
765
766
767
768
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
769
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
770
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
771
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
Minjie Wang's avatar
Minjie Wang committed
772
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
773
774
775
776
777
778
779
780
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
781
    LOG(FATAL) << "Not enabled for CSR graph.";
782
783
784
    return {};
  }

785
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
786
787
788
789
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

790
791
792
793
  aten::CSRMatrix adj() const {
    return adj_;
  }

794
795
796
797
798
799
800
801
802
803
804
805
806
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

807
 private:
808
809
  friend class Serializer;

810
811
812
813
814
815
  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// unit graph implementation
//
//////////////////////////////////////////////////////////

820
821
822
823
/*! \brief Id dtype, taken from whichever format GetAny() returns. */
DLDataType UnitGraph::DataType() const {
  const HeteroGraphPtr storage = GetAny();
  return storage->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
824
/*! \brief Device context, taken from whichever format GetAny() returns. */
DLContext UnitGraph::Context() const {
  const HeteroGraphPtr storage = GetAny();
  return storage->Context();
}

Minjie Wang's avatar
Minjie Wang committed
828
uint8_t UnitGraph::NumBits() const {
829
830
831
  return GetAny()->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
832
bool UnitGraph::IsMultigraph() const {
833
  const SparseFormat fmt = SelectFormat(CSC_CODE);
834
835
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
836
837
}

Minjie Wang's avatar
Minjie Wang committed
838
/*! \brief Number of vertices of type \a vtype. */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const auto storage = GetFormat(fmt);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  // The CSC is kept as a CSR of the transposed graph, so its source type is
  // our destination type and vice versa.
  dgl_type_t query_type = vtype;
  if (fmt == SparseFormat::kCSC) {
    query_type = (vtype == SrcType()) ? DstType() : SrcType();
  }
  return storage->NumVertices(query_type);
}

Minjie Wang's avatar
Minjie Wang committed
848
/*! \brief Number of edges; transposition preserves it, so any materialized format works. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const HeteroGraphPtr storage = GetAny();
  return storage->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
852
/*! \brief Whether vertex \a vid of type \a vtype exists. */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const auto storage = GetFormat(fmt);
  // The CSC storage swaps the roles of the source and destination types.
  const bool transposed = (fmt == SparseFormat::kCSC);
  const dgl_type_t query_type =
      transposed ? ((vtype == SrcType()) ? DstType() : SrcType()) : vtype;
  return storage->HasVertex(query_type, vid);
}

Minjie Wang's avatar
Minjie Wang committed
860
/*!
 * \brief Element-wise existence test for vertex ids of type \a vtype.
 * \note Only the upper bound (the vertex count) is compared against.
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const uint64_t upper_bound = NumVertices(vtype);
  return aten::LT(vids, upper_bound);
}

Minjie Wang's avatar
Minjie Wang committed
865
/*! \brief Whether an edge src -> dst exists. */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  // The CSC stores the transposed graph: query it with the endpoints swapped.
  const bool transposed = (fmt == SparseFormat::kCSC);
  return transposed ? storage->HasEdgeBetween(etype, dst, src)
                    : storage->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
874
/*! \brief Element-wise version of HasEdgeBetween over paired id arrays. */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->HasEdgesBetween(etype, src, dst);
  // Transposed storage: swap the endpoint arrays.
  return storage->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
884
/*! \brief Source vertices of the edges entering \a dst. */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  // In the transposed (CSC) storage, predecessors are stored as successors.
  return (fmt == SparseFormat::kCSC) ? storage->Successors(etype, dst)
                                     : storage->Predecessors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
893
/*!
 * \brief Destination vertices of the edges leaving \a src.
 *
 * SelectFormat(CSR_CODE) returns kCSC only when CSR is not an allowed format.
 * The CSC is the CSR of the transposed graph, so calling Successors() on it
 * would silently return predecessors. In that case a transient, uncached
 * out-CSR is built with GetOutCSR(false), which does not violate the format
 * restriction. Each such call costs a full transpose.
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr ptr = (fmt == SparseFormat::kCSC)
      ? HeteroGraphPtr(GetOutCSR(false))
      : GetFormat(fmt);
  return ptr->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
899
/*! \brief Ids of all edges src -> dst. */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto storage = GetFormat(fmt);
  // For the transposed (CSC) storage the endpoints are swapped.
  const bool transposed = (fmt == SparseFormat::kCSC);
  const dgl_id_t first = transposed ? dst : src;
  const dgl_id_t second = transposed ? src : dst;
  return storage->EdgeId(etype, first, second);
}

908
/*! \brief All edges (with ids) between each pair of (src[i], dst[i]). */
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->EdgeIdsAll(etype, src, dst);
  // Query the transposed graph with swapped endpoints, then swap them back.
  const EdgeArray found = storage->EdgeIdsAll(etype, dst, src);
  return EdgeArray{found.dst, found.src, found.id};
}

/*! \brief One edge id per (src[i], dst[i]) pair, as chosen by the backend. */
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto storage = GetFormat(fmt);
  // Transposed (CSC) storage expects the endpoint arrays swapped.
  const bool transposed = (fmt == SparseFormat::kCSC);
  IdArray first = transposed ? dst : src;
  IdArray second = transposed ? src : dst;
  return storage->EdgeIdsOne(etype, first, second);
}

Minjie Wang's avatar
Minjie Wang committed
929
/*! \brief Endpoints (src, dst) of edge \a eid; the COO format is preferred. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
935
/*! \brief Endpoints of every edge id in \a eids; the COO format is preferred. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
941
/*! \brief Edges entering vertex \a vid. */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->InEdges(etype, vid);
  // Out-edges of the transposed graph are our in-edges with endpoints swapped.
  const EdgeArray flipped = storage->OutEdges(etype, vid);
  return {flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
952
/*! \brief Edges entering any vertex in \a vids. */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->InEdges(etype, vids);
  // Out-edges of the transposed graph are our in-edges with endpoints swapped.
  const EdgeArray flipped = storage->OutEdges(etype, vids);
  return {flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
963
/*!
 * \brief Edges leaving vertex \a vid.
 *
 * SelectFormat(CSR_CODE) falls back to kCSC only when CSR is not allowed.
 * Querying the transposed CSC directly would return in-edges, so a transient,
 * uncached out-CSR (GetOutCSR(false)) is used instead.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr ptr = (fmt == SparseFormat::kCSC)
      ? HeteroGraphPtr(GetOutCSR(false))
      : GetFormat(fmt);
  return ptr->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
969
/*!
 * \brief Edges leaving any vertex in \a vids.
 *
 * SelectFormat(CSR_CODE) falls back to kCSC only when CSR is not allowed.
 * Querying the transposed CSC directly would return in-edges, so a transient,
 * uncached out-CSR (GetOutCSR(false)) is used instead.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr ptr = (fmt == SparseFormat::kCSC)
      ? HeteroGraphPtr(GetOutCSR(false))
      : GetFormat(fmt);
  return ptr->OutEdges(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
975
/*!
 * \brief All edges of the graph in the requested order.
 * \param order "eid" (prefers COO), "" (arbitrary, any format) or
 *        "srcdst" (prefers out-CSR); anything else is fatal.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  dgl_format_code_t preferred = 0;
  if (order == "eid") {
    preferred = COO_CODE;
  } else if (order.empty()) {
    preferred = ALL_CODE;  // arbitrary order
  } else if (order == "srcdst") {
    preferred = CSR_CODE;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const SparseFormat fmt = SelectFormat(preferred);
  const EdgeArray edges = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC)
    return edges;
  // The CSC holds the transposed graph: swap the endpoints back.
  return EdgeArray{edges.dst, edges.src, edges.id};
}

Minjie Wang's avatar
Minjie Wang committed
996
/*! \brief Number of edges entering \a vid. */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  // In-degree in the original graph is out-degree in the transposed (CSC) one.
  if (fmt != SparseFormat::kCSC)
    return storage->InDegree(etype, vid);
  return storage->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1005
/*! \brief In-degree of every vertex in \a vids. */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  // In-degree in the original graph is out-degree in the transposed (CSC) one.
  if (fmt != SparseFormat::kCSC)
    return storage->InDegrees(etype, vids);
  return storage->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1014
/*!
 * \brief Number of edges leaving \a vid.
 *
 * When CSR is not an allowed format, SelectFormat(CSR_CODE) may return kCSC.
 * The transposed CSC would report the in-degree, so a transient, uncached
 * out-CSR (GetOutCSR(false)) is used instead.
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr ptr = (fmt == SparseFormat::kCSC)
      ? HeteroGraphPtr(GetOutCSR(false))
      : GetFormat(fmt);
  return ptr->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1020
/*!
 * \brief Out-degree of every vertex in \a vids.
 *
 * When CSR is not an allowed format, SelectFormat(CSR_CODE) may return kCSC.
 * The transposed CSC would report in-degrees, so a transient, uncached
 * out-CSR (GetOutCSR(false)) is used instead.
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr ptr = (fmt == SparseFormat::kCSC)
      ? HeteroGraphPtr(GetOutCSR(false))
      : GetFormat(fmt);
  return ptr->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1026
/*!
 * \brief Iterator range over the successors of \a vid.
 *
 * SelectFormat(CSR_CODE) yields kCSC only when CSR is not allowed. The CSC
 * holds the transposed graph, so its SuccVec would be the predecessor list.
 * The returned iterators presumably reference the backend's storage, so a
 * transient out-CSR cannot be used here. Fail loudly instead of returning the
 * wrong direction.
 */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt != SparseFormat::kCSC)
    << "SuccVec requires the CSR format, which is not allowed for this graph "
    << "(allowed formats: " << CodeToStr(formats_) << ")";
  const auto ptr = GetFormat(fmt);
  return ptr->SuccVec(etype, vid);
}

1032
/*!
 * \brief 32-bit iterator range over the successors of \a vid (CSR backend only).
 *
 * SelectFormat(CSR_CODE) yields kCSC only when CSR is not allowed. The CSC
 * holds the transposed graph and would silently produce predecessors. The
 * iterators presumably reference backend storage, so a transient CSR cannot be
 * used; fail loudly instead.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt != SparseFormat::kCSC)
    << "SuccVec32 requires the CSR format, which is not allowed for this graph "
    << "(allowed formats: " << CodeToStr(formats_) << ")";
  const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1039
/*!
 * \brief Iterator range over the ids of edges leaving \a vid.
 *
 * SelectFormat(CSR_CODE) yields kCSC only when CSR is not allowed. The CSC
 * holds the transposed graph, so its OutEdgeVec would list in-edges. The
 * iterators presumably reference backend storage, so a transient out-CSR
 * cannot be used; fail loudly instead.
 */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt != SparseFormat::kCSC)
    << "OutEdgeVec requires the CSR format, which is not allowed for this graph "
    << "(allowed formats: " << CodeToStr(formats_) << ")";
  const auto ptr = GetFormat(fmt);
  return ptr->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1045
/*! \brief Iterator range over the predecessors of \a vid. */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  // Predecessors in the original graph are successors in the transposed (CSC) one.
  if (fmt != SparseFormat::kCSC)
    return storage->PredVec(etype, vid);
  return storage->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1054
/*! \brief Iterator range over the ids of edges entering \a vid. */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto storage = GetFormat(fmt);
  // In-edges of the original graph are out-edges of the transposed (CSC) one.
  if (fmt != SparseFormat::kCSC)
    return storage->InEdgeVec(etype, vid);
  return storage->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1063
std::vector<IdArray> UnitGraph::GetAdj(
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(etype, false, "csr")
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(etype, !transpose, fmt);
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
1083
/*!
 * \brief Vertex-induced subgraph.
 * \param vids One id array per vertex type, in (src type, dst type) order.
 * \return Subgraph stored in the same format the slice was taken from.
 *
 * The CSC backend is a CSR over the transposed graph: its rows are destination
 * vertices and its columns are source vertices. CSR::VertexSubgraph, however,
 * slices rows with vids[SrcType()] and columns with vids[DstType()]. For a
 * bipartite graph the two id lists are therefore swapped before slicing the
 * CSC, and the caller's order is restored in induced_vertices. A homogeneous
 * graph has SrcType() == DstType(), so nothing needs to be swapped.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(CSR_CODE);

  const bool swap_ids = (fmt == SparseFormat::kCSC) && (SrcType() != DstType()) &&
                        (vids.size() == NumVertexTypes());
  std::vector<IdArray> slice_vids = vids;
  if (swap_ids) {
    slice_vids[SrcType()] = vids[DstType()];
    slice_vids[DstType()] = vids[SrcType()];
  }

  HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(slice_vids);
  HeteroSubgraph ret;

  CSRPtr subcsr = nullptr;
  CSRPtr subcsc = nullptr;
  COOPtr subcoo = nullptr;
  switch (fmt) {
    case SparseFormat::kCSR:
      subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
      break;
    case SparseFormat::kCSC:
      subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
      break;
    case SparseFormat::kCOO:
      subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
      break;
    default:
      LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
      return ret;
  }

  ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  // Report induced vertices in the caller's (src, dst) order.
  if (swap_ids)
    ret.induced_vertices = vids;
  else
    ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
1113
/*!
 * \brief Edge-induced subgraph, preferably computed from the COO format.
 * \param eids One id array per edge type.
 * \param preserve_nodes Forwarded unchanged to the backend.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  auto sub = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);

  // Place the backend's result into the matching slot of the new UnitGraph.
  CSRPtr sub_in_csr = nullptr;
  CSRPtr sub_out_csr = nullptr;
  COOPtr sub_coo = nullptr;
  if (fmt == SparseFormat::kCSR) {
    sub_out_csr = std::dynamic_pointer_cast<CSR>(sub.graph);
  } else if (fmt == SparseFormat::kCSC) {
    sub_in_csr = std::dynamic_pointer_cast<CSR>(sub.graph);
  } else if (fmt == SparseFormat::kCOO) {
    sub_coo = std::dynamic_pointer_cast<COO>(sub.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return HeteroSubgraph();
  }

  HeteroSubgraph result;
  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), sub_in_csr, sub_out_csr, sub_coo));
  result.induced_vertices = std::move(sub.induced_vertices);
  result.induced_edges = std::move(sub.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1143
/*!
 * \brief Build a unit graph from COO endpoint arrays.
 * \param num_vtypes 1 (homogeneous, square matrix required) or 2 (bipartite).
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col,
    bool row_sorted, bool col_sorted,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_storage(new COO(metagraph, num_src, num_dst, row, col,
                             row_sorted, col_sorted));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, nullptr, coo_storage, formats));
}

1159
1160
/*! \brief Build a unit graph from an existing COO matrix (rows = src, cols = dst). */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_storage(new COO(metagraph, mat));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, nullptr, coo_storage, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1172
1173
/*! \brief Build a unit graph whose out-CSR is given by (indptr, indices, edge_ids). */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  // Out-CSR: one row per source vertex.
  CSRPtr out_storage(new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, out_storage, nullptr, formats));
}

1183
1184
/*! \brief Build a unit graph from an existing out-CSR matrix. */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_storage(new CSR(metagraph, mat));
  return HeteroGraphPtr(new UnitGraph(metagraph, nullptr, out_storage, nullptr, formats));
}

1194
1195
/*!
 * \brief Build a unit graph whose in-CSR (CSC) is given by (indptr, indices, edge_ids).
 *
 * A CSC is stored as the CSR of the transposed graph, with one row per
 * destination vertex and one column per source vertex. The row count is
 * therefore num_dst and the column count is num_src. This matches GetInCSR(),
 * which builds the in-CSR by transposing, and NumVertices(), which flips the
 * vertex type for CSC. The previous code passed (num_src, num_dst), which
 * gave wrong dimensions for bipartite graphs.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
}

/*!
 * \brief Build a unit graph from an existing CSC matrix, stored as the in-CSR as-is.
 * \note Presumably \a mat already has one row per destination vertex; verify at call sites.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_storage(new CSR(metagraph, mat));
  return HeteroGraphPtr(new UnitGraph(metagraph, in_storage, nullptr, nullptr, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1216
/*!
 * \brief Return \a g with ids stored in \a bits bits.
 *
 * Returns \a g unchanged if it already uses that width. Otherwise only the
 * formats that are materialized are converted; the allowed formats are kept.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;

  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);
  CSRPtr converted_in = nullptr;
  if (unit->in_csr_->defined())
    converted_in = CSRPtr(new CSR(unit->in_csr_->AsNumBits(bits)));
  CSRPtr converted_out = nullptr;
  if (unit->out_csr_->defined())
    converted_out = CSRPtr(new CSR(unit->out_csr_->AsNumBits(bits)));
  COOPtr converted_coo = nullptr;
  if (unit->coo_->defined())
    converted_coo = COOPtr(new COO(unit->coo_->AsNumBits(bits)));
  return HeteroGraphPtr(new UnitGraph(g->meta_graph(), converted_in, converted_out,
                                      converted_coo, unit->formats_));
}

Minjie Wang's avatar
Minjie Wang committed
1233
/*!
 * \brief Return \a g placed on device \a ctx.
 *
 * Returns \a g unchanged if it is already there. Otherwise only the
 * materialized formats are copied; the allowed formats are kept.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context())
    return g;

  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);
  CSRPtr copied_in = nullptr;
  if (unit->in_csr_->defined())
    copied_in = CSRPtr(new CSR(unit->in_csr_->CopyTo(ctx)));
  CSRPtr copied_out = nullptr;
  if (unit->out_csr_->defined())
    copied_out = CSRPtr(new CSR(unit->out_csr_->CopyTo(ctx)));
  COOPtr copied_coo = nullptr;
  if (unit->coo_->defined())
    copied_coo = COOPtr(new COO(unit->coo_->CopyTo(ctx)));
  return HeteroGraphPtr(new UnitGraph(g->meta_graph(), copied_in, copied_out,
                                      copied_coo, unit->formats_));
}

1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
void UnitGraph::InvalidateCSR() {
  this->out_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCSC() {
  this->in_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCOO() {
  this->coo_ = COOPtr(new COO());
}

1262
/*!
 * \brief Construct from any subset of the three formats.
 *
 * Null format pointers are replaced by empty placeholder objects, so the
 * members are never null. Every supplied format must be among \a formats.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     dgl_format_code_t formats)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  if (!in_csr_)
    in_csr_.reset(new CSR());
  if (!out_csr_)
    out_csr_.reset(new CSR());
  if (!coo_)
    coo_.reset(new COO());

  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  const bool created_within_allowed = ((formats | created) == formats);
  if (!created_within_allowed) {
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) <<
      ", which is not compatible with available formats: " << CodeToStr(formats);
  }
  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1282
1283
1284
1285
1286
1287
1288
/*!
 * \brief Build a homogeneous unit graph from any subset of pre-built matrices.
 * \param has_in_csr / has_out_csr / has_coo Flags selecting which matrices are used.
 * \param formats Allowed formats of the resulting graph.
 */
HeteroGraphPtr UnitGraph::CreateHomographFrom(
    const aten::CSRMatrix &in_csr,
    const aten::CSRMatrix &out_csr,
    const aten::COOMatrix &coo,
    bool has_in_csr,
    bool has_out_csr,
    bool has_coo,
    dgl_format_code_t formats) {
  // Reuse the cached single-vertex-type metagraph, as the other factories do,
  // instead of rebuilding one on every call.
  auto mg = CreateUnitGraphMetaGraph(1);

  // Absent matrices stay null; the UnitGraph constructor substitutes empty
  // placeholders for null pointers.
  CSRPtr in_csr_ptr = has_in_csr ? CSRPtr(new CSR(mg, in_csr)) : nullptr;
  CSRPtr out_csr_ptr = has_out_csr ? CSRPtr(new CSR(mg, out_csr)) : nullptr;
  COOPtr coo_ptr = has_coo ? COOPtr(new COO(mg, coo)) : nullptr;

  return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}

1312
1313
/*!
 * \brief Return the in-CSR (CSC), building it if needed.
 * \param inplace If true, the CSC must be an allowed format and the built
 *        matrix is cached in this graph. If false, a new uncached object is
 *        returned when none exists.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace && !(formats_ & CSC_CODE)) {
    LOG(FATAL) << "The graph have restricted sparse format " <<
      CodeToStr(formats_) << ", cannot create CSC matrix.";
  }
  if (in_csr_->defined())
    return in_csr_;

  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
  const aten::CSRMatrix newadj = [this]() {
    if (coo_->defined())
      return aten::COOToCSR(aten::COOTranspose(coo_->adj()));
    CHECK(out_csr_->defined()) << "None of CSR, COO exist";
    return aten::CSRTranspose(out_csr_->adj());
  }();

  if (!inplace)
    return std::make_shared<CSR>(meta_graph(), newadj);
  // Assign into the pointee so every holder of in_csr_ sees the new matrix.
  *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
  return in_csr_;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
1343
1344
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1345
    if (!(formats_ & CSR_CODE))
1346
1347
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create CSR matrix.";
1348
  CSRPtr ret = out_csr_;
1349
1350
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1351
  if (!out_csr_->defined()) {
1352
1353
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1354

1355
      if (inplace)
1356
1357
1358
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1359
    } else {
1360
1361
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1362

1363
      if (inplace)
1364
1365
1366
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1367
1368
    }
  }
1369
  return ret;
1370
1371
1372
}

/* !\brief Return coo. If not exist, create from csr.*/
1373
1374
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1375
    if (!(formats_ & COO_CODE))
1376
1377
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create COO matrix.";
1378
  COOPtr ret = coo_;
1379
1380
  if (!coo_->defined()) {
    if (in_csr_->defined()) {
1381
      const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1382

1383
      if (inplace)
1384
1385
1386
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1387
    } else {
1388
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1389
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1390

1391
      if (inplace)
1392
1393
1394
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1395
1396
    }
  }
1397
  return ret;
1398
1399
}

1400
/*! \brief In-CSR (CSC) matrix, materialized and cached if needed; \a etype is unused. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr csc = GetInCSR();
  return csc->adj();
}

1404
/*! \brief Out-CSR matrix, materialized and cached if needed; \a etype is unused. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr csr = GetOutCSR();
  return csr->adj();
}

1408
/*! \brief COO matrix, materialized and cached if needed; \a etype is unused. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo = GetCOO();
  return coo->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1412
/*!
 * \brief First materialized format in the order in-CSR, out-CSR, COO.
 * \note The COO placeholder is returned when neither CSR is defined, even if it is empty.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined())
    return in_csr_;
  if (out_csr_->defined())
    return out_csr_;
  return coo_;
}

1422
/*! \brief Bitmask of the formats that are currently materialized. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t materialized = 0;
  if (coo_->defined())
    materialized |= COO_CODE;
  if (out_csr_->defined())
    materialized |= CSR_CODE;
  if (in_csr_->defined())
    materialized |= CSC_CODE;
  return materialized;
}

1433
1434
1435
1436
/*! \brief Bitmask of the formats this graph may materialize. */
dgl_format_code_t UnitGraph::GetAllowedFormats() const { return formats_; }

1437
1438
/*! \brief Storage for \a format, materialized in place if needed; other values map to COO. */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR)
    return GetOutCSR();
  if (format == SparseFormat::kCSC)
    return GetInCSR();
  return GetCOO();
}

1448
/*!
 * \brief Return a new graph restricted to the allowed \a formats.
 *
 * With ALL_CODE, the materialized formats are copied as-is. Otherwise a
 * single format is built without caching it on this graph, picking the first
 * allowed one in the order COO, CSR, CSC.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  if (formats == ALL_CODE) {
    // TODO(xiangsx) Make it as graph storage.Clone()
    CSRPtr in_copy = in_csr_->defined() ? CSRPtr(new CSR(*in_csr_)) : nullptr;
    CSRPtr out_copy = out_csr_->defined() ? CSRPtr(new CSR(*out_csr_)) : nullptr;
    COOPtr coo_copy = coo_->defined() ? COOPtr(new COO(*coo_)) : nullptr;
    return HeteroGraphPtr(new UnitGraph(meta_graph_, in_copy, out_copy, coo_copy, formats));
  }

  const int64_t num_vtypes = NumVertexTypes();
  if (formats & COO_CODE)
    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
  if (formats & CSR_CODE)
    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}

/*!
 * \brief Pick the storage format to answer a query.
 *
 * Priority: (1) a preferred, allowed format that already exists; (2) a
 * preferred, allowed format that will have to be built; (3) any format that
 * already exists.
 */
SparseFormat UnitGraph::SelectFormat(dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t wanted = preferred_formats & formats_;
  const dgl_format_code_t created = GetCreatedFormats();
  const dgl_format_code_t ready = wanted & created;
  if (ready)
    return DecodeFormat(ready);

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on COO have
  // not been implmented yet.
  // if (coo_->defined() && coo_->IsHypersparse())  // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  return DecodeFormat(wanted ? wanted : created);
}

1486
1487
1488
1489
/*!
 * \brief Convert a homogeneous unit graph into the legacy dgl::ImmutableGraph.
 *
 * Only materialized formats are converted. A COO with an edge-id data array
 * has its endpoints scattered by those ids (aten::Scatter with coo.data),
 * since the legacy COO is built from endpoint arrays alone.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr legacy_in = nullptr;
  dgl::CSRPtr legacy_out = nullptr;
  dgl::COOPtr legacy_coo = nullptr;

  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    legacy_in = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    legacy_out = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    if (COOHasData(coo)) {
      IdArray src_by_eid = Scatter(coo.row, coo.data);
      IdArray dst_by_eid = Scatter(coo.col, coo.data);
      legacy_coo = dgl::COOPtr(new dgl::COO(NumVertices(0), src_by_eid, dst_by_eid));
    } else {
      legacy_coo = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col));
    }
  }
  return GraphPtr(new dgl::ImmutableGraph(legacy_in, legacy_out, legacy_coo));
}

1511
1512
/*!
 * \brief Line graph of this (homogeneous) graph, returned as a COO-backed unit graph.
 * \param backtracking Forwarded to aten::COOLineGraph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  if (fmt == SparseFormat::kCOO) {
    return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));
  }
  if (fmt == SparseFormat::kCSR || fmt == SparseFormat::kCSC) {
    // A CSC holds the transposed graph, so transpose it back to (src, dst) rows first.
    const aten::CSRMatrix csr = (fmt == SparseFormat::kCSR)
        ? GetCSRMatrix(0)
        : aten::CSRTranspose(GetCSCMatrix(0));
    const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
    return CreateFromCOO(1, coo);
  }
  LOG(FATAL) << "None of CSC, CSR, COO exist";
  return nullptr;
}

1536
1537
1538
1539
1540
1541
// Magic number written first by UnitGraph::Save and checked by UnitGraph::Load
// to recognize a serialized UnitGraph.
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Deserialize a graph written by UnitGraph::Save.
 *
 * Layout: magic number, saved format (int64), allowed-formats code (int64),
 * then the single serialized matrix. Save() sets bit 32 of the formats code
 * to mark the current encoding. Codes without that bit use the legacy
 * 0..3 enumeration.
 * \return true on success; malformed input fails a CHECK / LOG(FATAL).
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code, formats_code;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";
  auto save_format = static_cast<SparseFormat>(save_format_code);

  if (formats_code >> 32) {
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
    case 0:
      formats_ = ALL_CODE;
      break;
    case 1:
      formats_ = COO_CODE;
      break;
    case 2:
      formats_ = CSR_CODE;
      break;
    case 3:
      formats_ = CSC_CODE;
      break;
    default:
      LOG(FATAL) << "Load graph failed, formats code " << formats_code <<
        " not recognized.";
    }
  }

  // Previously these reads were unchecked, so a truncated stream silently
  // produced a graph with no defined format.
  switch (save_format) {
    case SparseFormat::kCOO:
      CHECK(fs->Read(&coo_)) << "Invalid COO matrix";
      break;
    case SparseFormat::kCSR:
      CHECK(fs->Read(&out_csr_)) << "Invalid CSR matrix";
      break;
    case SparseFormat::kCSC:
      CHECK(fs->Read(&in_csr_)) << "Invalid CSC matrix";
      break;
    default:
      LOG(FATAL) << "unsupported format code";
      break;
  }

  // Keep the invariant that format members are never null.
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  // The metagraph is stored inside the serialized matrix, not separately.
  meta_graph_ = GetAny()->meta_graph();

  return true;
}

1600

1601
1602
/*!
 * \brief Serialize this UnitGraph into \p fs.
 *
 * Output, in order: the UnitGraph magic number, the code of the sparse format
 * chosen by SelectFormat(ALL_CODE), the allowed-formats code OR-ed with bit 32
 * (the marker UnitGraph::Load uses to distinguish it from the legacy encoding),
 * and finally the matrix stored in the chosen format.
 *
 * \param fs Output stream.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);

  // meta_graph_ itself is not written: the serialized sparse matrix carries it.
  const auto chosen = SelectFormat(ALL_CODE);
  fs->Write(static_cast<int64_t>(chosen));
  fs->Write(static_cast<int64_t>(formats_ | 0x100000000));

  if (chosen == SparseFormat::kCOO) {
    fs->Write(GetCOO());
  } else if (chosen == SparseFormat::kCSR) {
    fs->Write(GetOutCSR());
  } else if (chosen == SparseFormat::kCSC) {
    fs->Write(GetInCSR());
  } else {
    LOG(FATAL) << "unsupported format code";
  }
}

1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
/*!
 * \brief Build a UnitGraph whose edges point the opposite way.
 *
 * The in-/out-CSR pointers are simply exchanged (the matrices are reused, not
 * copied). If a COO matrix is defined, it is transposed into a new COO object;
 * otherwise the result has no COO. The metagraph is passed through unchanged.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  // Out-edges of the reversed graph are the in-edges of this one, and vice versa.
  CSRPtr reversed_out = in_csr_;
  CSRPtr reversed_in = out_csr_;

  COOPtr reversed_coo = nullptr;
  if (coo_->defined()) {
    reversed_coo = COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
  }

  return UnitGraphPtr(new UnitGraph(meta_graph(), reversed_in, reversed_out, reversed_coo));
}

1634
1635
1636
1637
1638
1639
1640
std::tuple<UnitGraphPtr, IdArray, IdArray>
UnitGraph::ToSimple() const {
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1641
  auto avail_fmt = SelectFormat(ALL_CODE);
1642
1643
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1644
      auto ret = aten::COOToSimple(GetCOO()->adj());
1645
1646
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1647
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1648
1649
1650
      break;
    }
    case SparseFormat::kCSR: {
1651
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1652
1653
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1654
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1655
1656
1657
      break;
    }
    case SparseFormat::kCSC: {
1658
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1659
1660
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1661
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

  return std::make_tuple(UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
                         count,
                         edge_map);
}

1674
}  // namespace dgl