/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/unit_graph.cc
 * @brief UnitGraph graph implementation
 */
7
#include "./unit_graph.h"

8
#include <dgl/array.h>
9
#include <dgl/base_heterograph.h>
10
11
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
12
13

#include "../c_api_common.h"
14
#include "./serialize/dglstream.h"
15
16

namespace dgl {
17

18
namespace {
19
20
21

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
22
23
24
25
26
27
28
29
30
31
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
  // a self-loop edge 0->0
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 0);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
  return g;
}
32

Minjie Wang's avatar
Minjie Wang committed
33
34
35
36
37
// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
38
39
40
41
42
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
43
44
45
46
47
48
49
50
51
52
53
54

inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  if (num_vtypes == 1)
    return mg1;
  else if (num_vtypes == 2)
    return mg2;
  else
    LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  return {};
}
55
56

};  // namespace
57
58
59
60
61
62
63

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
64
/**
 * @brief Coordinate-format (COO) backend of a UnitGraph.
 *
 * The single relation is stored as parallel source (`row`) and destination
 * (`col`) id arrays. Edge ids are implicit (edge i is (row[i], col[i])), so
 * the `data` array of the internal COOMatrix is always kept null.
 * The graph is read-only: every mutator aborts.
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /**
   * @brief Build a COO graph from endpoint arrays.
   * @param metagraph Metagraph with one or two node types.
   * @param num_src Number of source nodes (matrix rows).
   * @param num_dst Number of destination nodes (matrix columns).
   * @param src Source node id of every edge.
   * @param dst Destination node id of every edge; same length as @p src.
   * @param row_sorted Sortedness hint stored as COOMatrix::row_sorted.
   * @param col_sorted Sortedness hint stored as COOMatrix::col_sorted.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
      : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0])
        << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src,     num_dst,    src,       dst,
                           NullArray(), row_sorted, col_sorted};
  }

  /** @brief Wrap an existing COO matrix, which must not carry a data array. */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
      : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /** @brief Construct an undefined placeholder (see defined()). */
  COO() {
    // Set magic num_rows/num_cols to mark it as undefined.
    // adj_.num_rows == 0 and adj_.num_cols == 0 means an empty UnitGraph,
    // which is supported.
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /** @brief Whether this object holds a real matrix (not default-built). */
  bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); }

  // Type ids of the single relation. With one node type the relation is a
  // self-relation, so source and destination both use type 0.
  inline dgl_type_t SrcType() const { return 0; }

  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }

  inline dgl_type_t EdgeType() const { return 0; }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
               << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }

  DGLDataType DataType() const override { return adj_.row->dtype; }

  DGLContext Context() const override { return adj_.row->ctx; }

  bool IsPinned() const override { return adj_.is_pinned; }

  uint8_t NumBits() const override { return adj_.row->dtype.bits; }

  /**
   * @brief Return a copy whose id arrays use @p bits-wide integers.
   *
   * A dtype cast keeps element order, so the sortedness flags of the current
   * matrix are carried over instead of being reset to false.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) return *this;

    return COO(
        meta_graph_, adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.col, bits),
        adj_.row_sorted, adj_.col_sorted);
  }

  /** @brief Return a copy on device @p ctx (or *this if already there). */
  COO CopyTo(const DGLContext& ctx) const {
    if (Context() == ctx) return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }

  /** @brief Pin the adj_: COOMatrix of the COO graph. */
  void PinMemory_() { adj_.PinMemory_(); }

  /** @brief Unpin the adj_: COOMatrix of the COO graph. */
  void UnpinMemory_() { adj_.UnpinMemory_(); }

  /** @brief Record stream for the adj_: COOMatrix of the COO graph. */
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

  bool IsMultigraph() const override { return aten::COOHasDuplicate(adj_); }

  bool IsReadonly() const override { return true; }

  /** @brief Rows count source nodes, columns count destination nodes. */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  // In-direction queries operate on the transposed matrix so that the
  // destination node becomes the row being looked up.
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
    // Same input validation as EdgeIdsAll.
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    return aten::COOGetData(adj_, src, dst);
  }

  /** @brief Endpoints of edge @p eid (edge ids are positions in row/col). */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    BUG_IF_FAIL(aten::IsNullArray(adj_.data))
        << "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{
        aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
        eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    // Validate like InDegree/Predecessors do for a scalar destination id.
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) =
        aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid);
    IdArray ret_dst =
        aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Rows of the sliced matrix are positions in `vids`; map them back.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    // Validate like OutDegree/Successors do for a scalar source id.
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src =
        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Rows of the sliced matrix are positions in `vids`; map them back.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /** @brief All edges in edge-id order (the only order COO supports). */
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
        << "COO only support Edges of order \"eid\", but got \"" << order
        << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // Iterator-based neighbor access is unsupported for COO; these log and
  // return an empty iterator range.
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  /** @brief Return [row; col] (or [col; row] if @p transpose) stacked. */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { return adj_; }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /**
   * @brief Subgraph induced by the given source/destination node ids.
   *
   * The induced edge ids are the `data` of the sliced matrix; the new graph
   * itself re-numbers its edges implicitly from 0.
   */
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes())
        << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    // NOTE(review): the context lookup is kept for its input validation
    // (it presumably rejects id arrays on different devices). The edge-id
    // Range it used to feed was never used, so it is no longer allocated.
    aten::GetContextOf(vids);
    subg.graph = std::make_shared<COO>(
        meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /**
   * @brief Subgraph formed by the edges @p eids[0].
   * @param preserve_nodes If false, nodes are compacted to those touched by
   *        the selected edges and the induced node ids are returned;
   *        if true, all nodes are kept and induced node ids are null arrays.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    // Endpoints of the selected edges (needed by both branches).
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      // Relabel_ rewrites new_src/new_dst into compact ids and returns the
      // original ids of the kept nodes.
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      subg.induced_vertices.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
      subg.induced_vertices.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src,
          new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  aten::COOMatrix adj() const { return adj_; }

  /**
   * @brief Determines whether the graph is "hypersparse", i.e. having
   * significantly more nodes than edges.
   *
   * True when there are over 1,000,000 source nodes and more than 8x as
   * many source nodes as edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /** @brief Deserialize the metagraph followed by the adjacency matrix. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /** @brief Serialize the metagraph followed by the adjacency matrix. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /** @brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

450
/** @brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
451
class UnitGraph::CSR : public BaseHeteroGraph {
452
 public:
453
454
455
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr,
      IdArray indices, IdArray edge_ids)
      : BaseHeteroGraph(metagraph) {
456
457
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
458
    if (aten::IsValidIdArray(edge_ids))
459
460
461
462
463
      CHECK(
          (indices->shape[0] == edge_ids->shape[0]) ||
          aten::IsNullArray(edge_ids))
          << "edge id arrays should have the same length as indices if not "
             "empty";
464
    CHECK_EQ(num_src, indptr->shape[0] - 1)
465
        << "number of nodes do not match the length of indptr minus 1.";
466

467
468
469
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

470
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
471
      : BaseHeteroGraph(metagraph), adj_(csr) {}
472

473
474
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
475
476
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
    // supported
477
478
479
480
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

481
  bool defined() const { return (adj_.num_rows >= 0) || (adj_.num_cols >= 0); }
482

483
  inline dgl_type_t SrcType() const { return 0; }
Minjie Wang's avatar
Minjie Wang committed
484

485
  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
Minjie Wang's avatar
Minjie Wang committed
486

487
  inline dgl_type_t EdgeType() const { return 0; }
488
489

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
490
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
491
               << "The relation graph is simply this graph itself.";
492
493
494
495
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
496
    LOG(FATAL) << "UnitGraph graph is not mutable.";
497
498
499
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
500
    LOG(FATAL) << "UnitGraph graph is not mutable.";
501
502
503
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
504
    LOG(FATAL) << "UnitGraph graph is not mutable.";
505
506
  }

507
  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
508

509
  DGLDataType DataType() const override { return adj_.indices->dtype; }
510

511
  DGLContext Context() const override { return adj_.indices->ctx; }
512

513
  bool IsPinned() const override { return adj_.is_pinned; }
514

515
  uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
516

517
518
519
520
521
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
522
          meta_graph_, adj_.num_rows, adj_.num_cols,
523
524
525
526
527
528
529
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

530
  CSR CopyTo(const DGLContext& ctx) const {
531
532
533
    if (Context() == ctx) {
      return *this;
    } else {
534
      return CSR(meta_graph_, adj_.CopyTo(ctx));
535
536
537
    }
  }

538
  /** @brief Pin the adj_: CSRMatrix of the CSR graph. */
539
  void PinMemory_() { adj_.PinMemory_(); }
540

541
  /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */
542
  void UnpinMemory_() { adj_.UnpinMemory_(); }
543

544
  /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
545
546
547
548
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

549
  bool IsMultigraph() const override { return aten::CSRHasDuplicate(adj_); }
550

551
  bool IsReadonly() const override { return true; }
552
553

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
554
    if (vtype == SrcType()) {
555
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
556
    } else if (vtype == DstType()) {
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

577
578
  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
579
580
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
581
582
583
    return aten::CSRIsNonZero(adj_, src, dst);
  }

584
585
  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
586
587
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
588
589
590
591
592
593
594
595
596
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
597
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
598
599
600
601
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
602
603
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
604
    return aten::CSRGetAllData(adj_, src, dst);
605
606
  }

607
608
  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
609
610
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
611
612
613
614
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

615
616
  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
617
618
619
    return aten::CSRGetData(adj_, src, dst);
  }

620
621
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
622
    LOG(FATAL) << "Not enabled for CSR graph.";
623
624
625
626
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
627
    LOG(FATAL) << "Not enabled for CSR graph.";
628
629
630
631
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
632
    LOG(FATAL) << "Not enabled for CSR graph.";
633
634
635
636
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
637
    LOG(FATAL) << "Not enabled for CSR graph.";
638
639
640
641
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
642
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
643
644
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
645
646
    IdArray ret_src =
        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
647
648
649
650
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
651
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
652
653
654
655
656
657
658
659
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

660
661
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
662
    CHECK(order.empty() || order == std::string("srcdst"))
663
664
        << "CSR only support Edges of order \"srcdst\","
        << " but got \"" << order << "\".";
665
666
667
668
669
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
670
671
672
673
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
674
    LOG(FATAL) << "Not enabled for CSR graph.";
675
676
677
678
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
679
    LOG(FATAL) << "Not enabled for CSR graph.";
680
681
682
683
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
684
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
685
686
687
688
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
689
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
690
691
692
693
694
695
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
696
    CHECK_EQ(NumBits(), 64);
697
698
699
700
701
702
703
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

704
705
706
707
708
709
710
711
712
713
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

714
715
716
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
717
    CHECK_EQ(NumBits(), 64);
718
719
720
721
722
723
724
725
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
726
    LOG(FATAL) << "Not enabled for CSR graph.";
727
728
729
730
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
731
    LOG(FATAL) << "Not enabled for CSR graph.";
732
733
734
735
    return {};
  }

  std::vector<IdArray> GetAdj(
736
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
737
738
739
740
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

741
742
743
744
745
746
747
748
749
750
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

751
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { return adj_; }
752

753
754
  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
755
    LOG(FATAL) << "Not enabled for CSR graph";
756
    return SparseFormat::kCSR;
757
758
  }

759
760
761
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
762
763
  }

764
  dgl_format_code_t GetCreatedFormats() const override {
765
766
767
768
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

769
770
771
772
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes())
        << "Number of vertex types mismatch";
Minjie Wang's avatar
Minjie Wang committed
773
774
775
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
776
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
777
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
778
    DGLContext ctx = aten::GetContextOf(vids);
779
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
780
781
782
    subg.graph = std::make_shared<CSR>(
        meta_graph(), submat.num_rows, submat.num_cols, submat.indptr,
        submat.indices, sub_eids);
783
784
785
786
787
788
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
789
790
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override {
791
    LOG(FATAL) << "Not enabled for CSR graph.";
792
793
794
    return {};
  }

795
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
796
797
798
799
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

800
  aten::CSRMatrix adj() const { return adj_; }
801

802
803
804
805
806
807
808
809
810
811
812
813
814
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

815
 private:
816
817
  friend class Serializer;

818
  /** @brief internal adjacency matrix. Data array stores edge ids */
819
820
821
822
823
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// UnitGraph implementation
//
//////////////////////////////////////////////////////////

828
/** @brief Id data type, read from whichever sparse structure GetAny() picks. */
DGLDataType UnitGraph::DataType() const {
  const HeteroGraphPtr structure = GetAny();
  return structure->DataType();
}
829

830
/** @brief Device context, read from whichever sparse structure GetAny() picks. */
DGLContext UnitGraph::Context() const {
  const HeteroGraphPtr structure = GetAny();
  return structure->Context();
}
831

832
bool UnitGraph::IsPinned() const { return GetAny()->IsPinned(); }
833

834
uint8_t UnitGraph::NumBits() const { return GetAny()->NumBits(); }
835

Minjie Wang's avatar
Minjie Wang committed
836
bool UnitGraph::IsMultigraph() const {
837
  const SparseFormat fmt = SelectFormat(CSC_CODE);
838
839
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
840
841
}

Minjie Wang's avatar
Minjie Wang committed
842
/**
 * @brief Number of vertices of type @p vtype.
 *
 * The CSC structure stores the transposed graph, so the source and destination
 * types are exchanged before querying it.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const auto structure = GetFormat(fmt);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  dgl_type_t query_type = vtype;
  if (fmt == SparseFormat::kCSC) {
    query_type = (vtype == SrcType()) ? DstType() : SrcType();
  }
  return structure->NumVertices(query_type);
}

Minjie Wang's avatar
Minjie Wang committed
852
/** @brief Edge count, read from whichever sparse structure GetAny() picks. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const HeteroGraphPtr structure = GetAny();
  return structure->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
856
/**
 * @brief Whether vertex @p vid of type @p vtype exists.
 *
 * The vertex type is flipped when the CSC (transposed) structure answers.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const auto structure = GetFormat(fmt);
  const bool transposed = (fmt == SparseFormat::kCSC);
  const dgl_type_t query_type =
      !transposed ? vtype : ((vtype == SrcType()) ? DstType() : SrcType());
  return structure->HasVertex(query_type, vid);
}

Minjie Wang's avatar
Minjie Wang committed
864
/**
 * @brief Element-wise existence test for the ids in @p vids.
 *
 * Only the upper bound is checked: an id exists iff it is less than the number
 * of vertices of type @p vtype.
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const uint64_t vertex_count = NumVertices(vtype);
  return aten::LT(vids, vertex_count);
}

869
870
/**
 * @brief Whether an edge @p src -> @p dst exists.
 *
 * CSC is preferred; because it stores the transposed graph, the endpoints are
 * exchanged when it is the structure being queried.
 */
bool UnitGraph::HasEdgeBetween(
    dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  const bool transposed = (fmt == SparseFormat::kCSC);
  return transposed ? structure->HasEdgeBetween(etype, dst, src)
                    : structure->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
879
/**
 * @brief Element-wise existence test for the edges @p src[i] -> @p dst[i].
 *
 * Endpoint arrays are exchanged when the CSC (transposed) structure answers.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC) {
    return structure->HasEdgesBetween(etype, src, dst);
  }
  return structure->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
889
/**
 * @brief Predecessors of @p dst; CSC is the preferred structure.
 *
 * In the transposed (CSC) storage, predecessors show up as successors.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC) {
    return structure->Predecessors(etype, dst);
  }
  return structure->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
898
/**
 * @brief Successors of @p src along edge type @p etype.
 *
 * Only CSR and COO can answer this directly. If the graph is restricted to
 * CSC, SelectFormat(CSR_CODE) falls back to the CSC structure. That structure
 * stores the transposed graph, so calling Successors on it would silently
 * return predecessors. That case is rejected, matching OutDegree().
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Successors cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return GetFormat(fmt)->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
904
/**
 * @brief Ids of all edges @p src -> @p dst; CSR is the preferred structure.
 *
 * Endpoints are exchanged when the CSC (transposed) structure answers.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto structure = GetFormat(fmt);
  const bool transposed = (fmt == SparseFormat::kCSC);
  const dgl_id_t row = transposed ? dst : src;
  const dgl_id_t col = transposed ? src : dst;
  return structure->EdgeId(etype, row, col);
}

913
914
/**
 * @brief All edges matching each (@p src[i], @p dst[i]) pair.
 *
 * With CSC storage, the query uses exchanged endpoints and the returned src/dst
 * arrays are exchanged back.
 */
EdgeArray UnitGraph::EdgeIdsAll(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto structure = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC) {
    return structure->EdgeIdsAll(etype, src, dst);
  }
  const EdgeArray flipped = structure->EdgeIdsAll(etype, dst, src);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

925
926
/**
 * @brief One edge id per (@p src[i], @p dst[i]) pair.
 *
 * Endpoint arrays are exchanged when the CSC (transposed) structure answers.
 */
IdArray UnitGraph::EdgeIdsOne(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto structure = GetFormat(fmt);
  return (fmt == SparseFormat::kCSC) ? structure->EdgeIdsOne(etype, dst, src)
                                     : structure->EdgeIdsOne(etype, src, dst);
}

936
937
/** @brief Endpoints of edge @p eid; COO is the preferred structure. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(
    dgl_type_t etype, dgl_id_t eid) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
943
/** @brief Endpoints of every edge in @p eids; COO is the preferred structure. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
949
/**
 * @brief In-edges of vertex @p vid; CSC is the preferred structure.
 *
 * In the transposed (CSC) storage these are out-edges, whose src/dst arrays
 * are exchanged back before returning.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC) {
    return structure->InEdges(etype, vid);
  }
  const EdgeArray& flipped = structure->OutEdges(etype, vid);
  return {flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
960
/**
 * @brief In-edges of every vertex in @p vids; CSC is the preferred structure.
 *
 * In the transposed (CSC) storage these are out-edges, whose src/dst arrays
 * are exchanged back before returning.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC) {
    return structure->InEdges(etype, vids);
  }
  const EdgeArray& flipped = structure->OutEdges(etype, vids);
  return {flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
971
/**
 * @brief Out-edges of vertex @p vid.
 *
 * Only CSR and COO can answer this directly. A CSC-only graph would otherwise
 * be queried on its transposed storage, silently yielding in-edges. That case
 * is rejected, matching OutDegree().
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out edges cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return GetFormat(fmt)->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
977
/**
 * @brief Out-edges of every vertex in @p vids.
 *
 * Only CSR and COO can answer this directly. A CSC-only graph would otherwise
 * be queried on its transposed storage, silently yielding in-edges. That case
 * is rejected, matching OutDegrees().
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out edges cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return GetFormat(fmt)->OutEdges(etype, vids);
}

983
/**
 * @brief All edges of the graph in the requested @p order.
 *
 * "eid" prefers COO, "srcdst" prefers CSR, and "" accepts any order. Any other
 * value is fatal. When the CSC (transposed) structure answers, the src/dst
 * arrays are exchanged back.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string& order) const {
  SparseFormat fmt;
  if (order.empty()) {
    // arbitrary order
    fmt = SelectFormat(ALL_CODE);
  } else if (order == "eid") {
    fmt = SelectFormat(COO_CODE);
  } else if (order == "srcdst") {
    fmt = SelectFormat(CSR_CODE);
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const auto& edges = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC) {
    return edges;
  }
  return EdgeArray{edges.dst, edges.src, edges.id};
}

Minjie Wang's avatar
Minjie Wang committed
1004
/**
 * @brief In-degree of @p vid; requires CSC or COO.
 *
 * The CSC stores the transposed graph, so its out-degree is the in-degree.
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  const bool transposed = (fmt == SparseFormat::kCSC);
  CHECK(transposed || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (transposed) {
    return structure->OutDegree(etype, vid);
  }
  return structure->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1014
/**
 * @brief In-degrees of every vertex in @p vids; requires CSC or COO.
 *
 * The CSC stores the transposed graph, so its out-degrees are the in-degrees.
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  const bool transposed = (fmt == SparseFormat::kCSC);
  CHECK(transposed || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (transposed) {
    return structure->OutDegrees(etype, vids);
  }
  return structure->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1024
/** @brief Out-degree of @p vid; requires CSR or COO. */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto structure = GetFormat(fmt);
  const bool supported =
      (fmt == SparseFormat::kCSR) || (fmt == SparseFormat::kCOO);
  CHECK(supported)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return structure->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1033
/** @brief Out-degrees of every vertex in @p vids; requires CSR or COO. */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto structure = GetFormat(fmt);
  const bool supported =
      (fmt == SparseFormat::kCSR) || (fmt == SparseFormat::kCOO);
  CHECK(supported)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return structure->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1042
/**
 * @brief Iterator range over the successors of @p vid.
 *
 * A CSC-only graph would otherwise be queried on its transposed storage,
 * silently yielding predecessors. Only CSR/COO are accepted, as in OutDegree().
 */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Successors cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return GetFormat(fmt)->SuccVec(etype, vid);
}

1048
/**
 * @brief 32-bit iterator range over the successors of @p vid.
 *
 * Requires the out-CSR. The CSC structure is also a CSR object, so the dynamic
 * cast alone would accept it and return predecessors. The format is therefore
 * checked explicitly.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR)
      << "SuccVec32 requires the CSR format, which is not allowed for this "
         "graph. Please enable it.";
  const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1055
/**
 * @brief Iterator range over the out-edge ids of @p vid.
 *
 * A CSC-only graph would otherwise be queried on its transposed storage,
 * silently yielding in-edge ids. Only CSR/COO are accepted, as in OutDegree().
 */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out edges cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return GetFormat(fmt)->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1061
/**
 * @brief Iterator range over the predecessors of @p vid; CSC preferred.
 *
 * In the transposed (CSC) storage, predecessors are successors.
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  return (fmt == SparseFormat::kCSC) ? structure->SuccVec(etype, vid)
                                     : structure->PredVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1070
/**
 * @brief Iterator range over the in-edge ids of @p vid; CSC preferred.
 *
 * In the transposed (CSC) storage, in-edges are out-edges.
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto structure = GetFormat(fmt);
  return (fmt == SparseFormat::kCSC) ? structure->OutEdgeVec(etype, vid)
                                     : structure->InEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1079
std::vector<IdArray> UnitGraph::GetAdj(
1080
1081
1082
1083
1084
1085
1086
1087
1088
    dgl_type_t etype, bool transpose, const std::string& fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst
  // nodes and col for src nodes. Therefore, we need to flip the transpose flag.
  // For example,
  //   transpose=False is equal to in edge CSR. We have this behavior because
  //   previously we use framework's SPMM and we don't cache reverse adj. This
  //   is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
  //   change the behavior and make row for src and col for dst.
1089
  if (fmt == std::string("csr")) {
1090
    return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
1091
                      : GetInCSR()->GetAdj(etype, false, "csr");
1092
  } else if (fmt == std::string("coo")) {
1093
    return GetCOO()->GetAdj(etype, transpose, fmt);
1094
1095
1096
1097
1098
1099
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

1100
1101
/**
 * @brief Subgraph induced by @p vids (one id array per vertex type).
 *
 * The out-CSR is preferred. If the CSC structure has to be used, note that it
 * stores the transposed graph: its rows are destination vertices. The backend
 * slices rows with vids[SrcType()], so the per-type id arrays are exchanged
 * before slicing, and the caller's @p vids are reported as the induced
 * vertices. For a single vertex type the exchange is a no-op.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(
    const std::vector<IdArray>& vids) const {
  CHECK_EQ(vids.size(), NumVertexTypes())
      << "Number of vertex types mismatch";
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(CSR_CODE);

  std::vector<IdArray> query_vids = vids;
  if (fmt == SparseFormat::kCSC) {
    query_vids[SrcType()] = vids[DstType()];
    query_vids[DstType()] = vids[SrcType()];
  }
  HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(query_vids);
  HeteroSubgraph ret;

  CSRPtr subcsr = nullptr;
  CSRPtr subcsc = nullptr;
  COOPtr subcoo = nullptr;
  switch (fmt) {
    case SparseFormat::kCSR:
      subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
      break;
    case SparseFormat::kCSC:
      subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
      break;
    case SparseFormat::kCOO:
      subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
      break;
    default:
      LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
      return ret;
  }

  ret.graph =
      HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  if (fmt == SparseFormat::kCSC) {
    // The backend saw the exchanged ids; report them in the caller's order.
    ret.induced_vertices = vids;
  } else {
    ret.induced_vertices = std::move(sg.induced_vertices);
  }
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
1132
/**
 * @brief Subgraph induced by the edge ids in @p eids; COO is preferred.
 *
 * The backend's subgraph is wrapped in a new UnitGraph that holds it in the
 * slot matching the format that produced it.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  HeteroSubgraph sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph ret;

  CSRPtr subcsr = nullptr;
  CSRPtr subcsc = nullptr;
  COOPtr subcoo = nullptr;
  if (fmt == SparseFormat::kCOO) {
    subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
  } else if (fmt == SparseFormat::kCSR) {
    subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCSC) {
    subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return ret;
  }

  ret.graph =
      HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
1163
/**
 * @brief Build a UnitGraph from COO row/col arrays.
 *
 * @p num_vtypes must be 1 or 2; with one vertex type the graph must be square.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
    IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr coo(
      new COO(metagraph, num_src, num_dst, row, col, row_sorted, col_sorted));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, nullptr, coo, formats));
}

1174
/**
 * @brief Build a UnitGraph from an existing COO matrix.
 *
 * @p num_vtypes must be 1 or 2; with one vertex type the matrix must be square.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr coo(new COO(metagraph, mat));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, nullptr, coo, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1184
/**
 * @brief Build a UnitGraph whose out-CSR is given by indptr/indices/edge_ids.
 *
 * @p num_vtypes must be 1 or 2; with one vertex type the graph must be square.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_csr(
      new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, out_csr, nullptr, formats));
}

1194
/**
 * @brief Build a UnitGraph whose out-CSR is the given matrix.
 *
 * @p num_vtypes must be 1 or 2; with one vertex type the matrix must be square.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_csr(new CSR(metagraph, mat));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, out_csr, nullptr, formats));
}

1203
/**
 * @brief Build a UnitGraph whose in-CSR (CSC) is given by indptr/indices.
 *
 * The CSC is stored as the CSR of the transposed graph, so it has @p num_dst
 * rows and @p num_src columns.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_csr(
      new CSR(metagraph, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, in_csr, nullptr, nullptr, formats));
}

/**
 * @brief Build a UnitGraph whose in-CSR (CSC) is the given matrix.
 *
 * @p num_vtypes must be 1 or 2; with one vertex type the matrix must be square.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_csr(new CSR(metagraph, mat));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, in_csr, nullptr, nullptr, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1222
/**
 * @brief Return @p g with ids of width @p bits.
 *
 * @p g is returned unchanged if it already has that width. Otherwise every
 * defined structure is converted and a new UnitGraph is built with the same
 * allowed formats.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) return g;

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);
  CSRPtr converted_in = nullptr;
  CSRPtr converted_out = nullptr;
  COOPtr converted_coo = nullptr;
  if (bg->in_csr_->defined())
    converted_in = CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)));
  if (bg->out_csr_->defined())
    converted_out = CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)));
  if (bg->coo_->defined())
    converted_coo = COOPtr(new COO(bg->coo_->AsNumBits(bits)));
  return HeteroGraphPtr(new UnitGraph(
      g->meta_graph(), converted_in, converted_out, converted_coo,
      bg->formats_));
}

1242
/**
 * @brief Return @p g placed on device @p ctx.
 *
 * @p g is returned unchanged if it is already there. Otherwise every defined
 * structure is copied and a new UnitGraph is built with the same allowed
 * formats.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
  if (ctx == g->Context()) return g;

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);
  CSRPtr copied_in = nullptr;
  CSRPtr copied_out = nullptr;
  COOPtr copied_coo = nullptr;
  if (bg->in_csr_->defined())
    copied_in = CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)));
  if (bg->out_csr_->defined())
    copied_out = CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)));
  if (bg->coo_->defined())
    copied_coo = COOPtr(new COO(bg->coo_->CopyTo(ctx)));
  return HeteroGraphPtr(new UnitGraph(
      g->meta_graph(), copied_in, copied_out, copied_coo, bg->formats_));
}

1262
void UnitGraph::PinMemory_() {
1263
1264
1265
  if (this->in_csr_->defined()) this->in_csr_->PinMemory_();
  if (this->out_csr_->defined()) this->out_csr_->PinMemory_();
  if (this->coo_->defined()) this->coo_->PinMemory_();
1266
1267
1268
}

void UnitGraph::UnpinMemory_() {
1269
1270
1271
  if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_();
  if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_();
  if (this->coo_->defined()) this->coo_->UnpinMemory_();
1272
1273
}

1274
/**
 * @brief Record @p stream on every defined structure.
 *
 * The stream is also remembered so that structures materialized later can
 * record it too.
 */
void UnitGraph::RecordStream(DGLStreamHandle stream) {
  if (in_csr_->defined()) {
    in_csr_->RecordStream(stream);
  }
  if (out_csr_->defined()) {
    out_csr_->RecordStream(stream);
  }
  if (coo_->defined()) {
    coo_->RecordStream(stream);
  }
  recorded_streams.push_back(stream);
}

1281
void UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); }
1282

1283
void UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); }
1284

1285
void UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); }
1286

1287
1288
1289
1290
1291
1292
1293
/**
 * @brief Construct from (possibly null) sparse structures.
 *
 * Null pointers are replaced with undefined placeholders, so the members are
 * never null. The materialized formats must be a subset of @p formats, and at
 * least one format must actually be materialized.
 */
UnitGraph::UnitGraph(
    GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
    dgl_format_code_t formats)
    : BaseHeteroGraph(metagraph),
      in_csr_(in_csr),
      out_csr_(out_csr),
      coo_(coo) {
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }
  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  if ((formats | created) != formats)
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created)
               << ", which is not compatible with available formats: "
               << CodeToStr(formats);
  // GetAny() can never return null here because every member was defaulted
  // above, so checking it would always pass. Check materialized formats instead.
  CHECK(created != 0) << "At least one graph structure should exist.";
}

1312
/**
 * @brief Build a UnitGraph from raw matrices.
 *
 * Each has_* flag decides whether the matching matrix is used or replaced by
 * an undefined placeholder.
 */
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
    int num_vtypes, const aten::CSRMatrix& in_csr,
    const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr,
    bool has_out_csr, bool has_coo, dgl_format_code_t formats) {
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);

  const CSRPtr in_csr_ptr =
      has_in_csr ? CSRPtr(new CSR(mg, in_csr)) : CSRPtr(new CSR());
  const CSRPtr out_csr_ptr =
      has_out_csr ? CSRPtr(new CSR(mg, out_csr)) : CSRPtr(new CSR());
  const COOPtr coo_ptr = has_coo ? COOPtr(new COO(mg, coo)) : COOPtr(new COO());

  return HeteroGraphPtr(
      new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}

1339
1340
/**
 * @brief Return the in-CSR (CSC), deriving it if it is not yet defined.
 *
 * The CSC is built from COO when possible, otherwise by transposing the
 * out-CSR.
 *
 * @param inplace If true, the derived CSC is stored in this graph. This
 *        requires CSC to be an allowed format. The stored CSC is pinned when
 *        the graph is pinned and records every recorded stream. If false, a
 *        new CSC is returned and the graph is not modified.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace && !(formats_ & CSC_CODE))
    LOG(FATAL) << "The graph have restricted sparse format "
               << CodeToStr(formats_) << ", cannot create CSC matrix.";
  if (in_csr_->defined()) return in_csr_;

  CHECK(coo_->defined() || out_csr_->defined()) << "None of CSR, COO exist";
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
  const aten::CSRMatrix newadj =
      coo_->defined() ? aten::COOToCSR(aten::COOTranspose(coo_->adj()))
                      : aten::CSRTranspose(out_csr_->adj());
  if (!inplace) return std::make_shared<CSR>(meta_graph(), newadj);

  // Read the pinned state before installing the new CSC. GetAny() prefers
  // in_csr_, so asking afterwards would inspect the freshly built (unpinned)
  // matrix, and pinning would never be propagated.
  const bool pinned = IsPinned();
  *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
  if (pinned) in_csr_->PinMemory_();
  for (auto stream : recorded_streams) in_csr_->RecordStream(stream);
  return in_csr_;
}

1372
/** @brief Return out csr. If not exist, transpose the other one.*/
1373
1374
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1375
    if (!(formats_ & CSR_CODE))
1376
1377
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create CSR matrix.";
1378
  CSRPtr ret = out_csr_;
1379
1380
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1381
  if (!out_csr_->defined()) {
1382
1383
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1384

1385
      if (inplace)
1386
1387
1388
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1389
    } else {
1390
1391
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1392

1393
      if (inplace)
1394
1395
1396
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1397
    }
1398
    if (inplace) {
1399
1400
      if (IsPinned()) out_csr_->PinMemory_();
      for (auto stream : recorded_streams) out_csr_->RecordStream(stream);
1401
    }
1402
  }
1403
  return ret;
1404
1405
}

1406
/** @brief Return coo. If not exist, create from csr.*/
1407
1408
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1409
    if (!(formats_ & COO_CODE))
1410
1411
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create COO matrix.";
1412
  COOPtr ret = coo_;
1413
1414
  if (!coo_->defined()) {
    if (in_csr_->defined()) {
1415
1416
      const auto& newadj =
          aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1417

1418
      if (inplace)
1419
1420
1421
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1422
    } else {
1423
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1424
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1425

1426
      if (inplace)
1427
1428
1429
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1430
    }
1431
    if (inplace) {
1432
1433
      if (IsPinned()) coo_->PinMemory_();
      for (auto stream : recorded_streams) coo_->RecordStream(stream);
1434
    }
1435
  }
1436
  return ret;
1437
1438
}

1439
/** @brief CSC adjacency matrix (in-CSR); @p etype is not used. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const auto csc = GetInCSR();
  return csc->adj();
}

1443
/** @brief CSR adjacency matrix (out-CSR); @p etype is not used. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const auto csr = GetOutCSR();
  return csr->adj();
}

1447
/** @brief COO adjacency matrix; @p etype is not used. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const auto coo = GetCOO();
  return coo->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1451
/**
 * @brief First defined structure, in the order in-CSR, out-CSR, COO.
 *
 * Falls back to the COO placeholder when nothing is defined.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined()) return in_csr_;
  if (out_csr_->defined()) return out_csr_;
  return coo_;
}

1461
/** @brief Bit mask of the formats whose structures are currently defined. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t created = 0;
  if (coo_->defined()) created |= COO_CODE;
  if (out_csr_->defined()) created |= CSR_CODE;
  if (in_csr_->defined()) created |= CSC_CODE;
  return created;
}

1469
/** @brief Bit mask of the formats this graph is allowed to hold. */
dgl_format_code_t UnitGraph::GetAllowedFormats() const {
  return formats_;
}
1470

1471
1472
/**
 * @brief Structure for @p format, materialized in place if needed.
 *
 * Any value other than kCSR or kCSC yields the COO.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR) return GetOutCSR();
  if (format == SparseFormat::kCSC) return GetInCSR();
  return GetCOO();
}

1482
/**
 * @brief New graph restricted to @p formats.
 *
 * ALL_CODE copies every defined structure. Otherwise the graph is rebuilt from
 * a single structure, chosen in the order COO, CSR, CSC. That structure is
 * derived without modifying this graph.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  if (formats == ALL_CODE) {
    // TODO(xiangsx) Make it as graph storage.Clone()
    CSRPtr in_copy = nullptr;
    CSRPtr out_copy = nullptr;
    COOPtr coo_copy = nullptr;
    if (in_csr_->defined()) in_copy = CSRPtr(new CSR(*in_csr_));
    if (out_csr_->defined()) out_copy = CSRPtr(new CSR(*out_csr_));
    if (coo_->defined()) coo_copy = COOPtr(new COO(*coo_));
    return HeteroGraphPtr(
        new UnitGraph(meta_graph_, in_copy, out_copy, coo_copy, formats));
  }

  const int64_t num_vtypes = NumVertexTypes();
  if (formats & COO_CODE) {
    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
  }
  if (formats & CSR_CODE) {
    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
  }
  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}

1499
1500
/**
 * @brief Choose the format that should answer a query.
 *
 * The order of preference is:
 *  1. a preferred, allowed format that is already created;
 *  2. a preferred, allowed format that is not yet created;
 *  3. any created format.
 */
SparseFormat UnitGraph::SelectFormat(
    dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t allowed = preferred_formats & formats_;
  const dgl_format_code_t created = GetCreatedFormats();
  const dgl_format_code_t ready = allowed & created;
  if (ready) return DecodeFormat(ready);

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on
  // COO have not been implemented yet. if (coo_->defined() &&
  // coo_->IsHypersparse())  // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  if (allowed) return DecodeFormat(allowed);
  return DecodeFormat(created);
}

1513
1514
1515
1516
/**
 * @brief Convert a homogeneous UnitGraph into an ImmutableGraph.
 *
 * Every defined structure is carried over. A COO that has edge data is
 * rearranged with Scatter using that data as the target positions. Presumably
 * this orders the endpoint arrays by edge id; confirm against aten::Scatter.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr in_csr_ptr = nullptr;
  dgl::CSRPtr out_csr_ptr = nullptr;
  dgl::COOPtr coo_ptr = nullptr;

  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    IdArray src = coo.row;
    IdArray dst = coo.col;
    if (COOHasData(coo)) {
      src = Scatter(coo.row, coo.data);
      dst = Scatter(coo.col, coo.data);
    }
    coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), src, dst));
  }
  return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
}

1538
1539
/**
 * @brief Line graph of this graph, returned as a single-vertex-type graph.
 *
 * A CSR or CSC source is first turned into a COO (CSC via a transpose), and
 * the line graph is computed with aten::COOLineGraph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  if (fmt == SparseFormat::kCOO) {
    return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));
  }
  if (fmt == SparseFormat::kCSR || fmt == SparseFormat::kCSC) {
    const aten::CSRMatrix csr = (fmt == SparseFormat::kCSR)
                                    ? GetCSRMatrix(0)
                                    : aten::CSRTranspose(GetCSCMatrix(0));
    const aten::COOMatrix coo =
        aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
    return CreateFromCOO(1, coo);
  }
  LOG(FATAL) << "None of CSC, CSR, COO exist";
  return nullptr;
}

1565
1566
1567
1568
1569
1570
/** @brief Magic number UnitGraph::Load expects at the start of the stream. */
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127ULL;

/**
 * @brief Deserialize this UnitGraph from a stream written by UnitGraph::Save.
 *
 * Stream layout:
 *   1. magic number (kDGLSerialize_UnitGraphMagic);
 *   2. int64 "save format" word: which sparse matrices follow in the stream;
 *   3. int64 "formats" word: which formats the graph may materialize;
 *   4. the saved matrices, in the order COO, out-CSR, in-CSR (CSC).
 * A format word with bit 32 set carries a format bitmask in its low 32 bits.
 * Otherwise it is a legacy single-format value kept for older files.
 *
 * @param fs Input stream.
 * @return true on success; every failure is reported through CHECK/LOG(FATAL).
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code = 0;
  int64_t formats_code = 0;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";

  // Decode which matrices are physically present in the stream.
  dgl_format_code_t save_formats = ANY_CODE;
  if (save_format_code >> 32) {
    save_formats =
        static_cast<dgl_format_code_t>(0xffffffff & save_format_code);
  } else {
    // Legacy layout: a single SparseFormat enum value.
    save_formats =
        SparseFormatsToCode({static_cast<SparseFormat>(save_format_code)});
  }

  // Decode the set of formats this graph is allowed to use.
  if (formats_code >> 32) {
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
      case 0:
        formats_ = ALL_CODE;
        break;
      case 1:
        formats_ = COO_CODE;
        break;
      case 2:
        formats_ = CSR_CODE;
        break;
      case 3:
        formats_ = CSC_CODE;
        break;
      default:
        LOG(FATAL) << "Load graph failed, formats code " << formats_code
                   << " not recognized.";
    }
  }

  // Read the saved matrices in the same order Save() writes them; a failed
  // read means the stream is truncated or corrupt, so fail loudly instead of
  // continuing with a half-initialized graph.
  if (save_formats & COO_CODE) {
    CHECK(fs->Read(&coo_)) << "Invalid UnitGraph Data: failed to read COO";
  }
  if (save_formats & CSR_CODE) {
    CHECK(fs->Read(&out_csr_)) << "Invalid UnitGraph Data: failed to read CSR";
  }
  if (save_formats & CSC_CODE) {
    CHECK(fs->Read(&in_csr_)) << "Invalid UnitGraph Data: failed to read CSC";
  }
  if (!coo_ && !out_csr_ && !in_csr_) {
    LOG(FATAL) << "unsupported format code";
  }

  // Formats absent from the stream are represented by empty placeholder
  // objects rather than null pointers, so the members can be dereferenced
  // unconditionally (e.g. `coo_->defined()` elsewhere in this file).
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  // The metagraph is not stored separately; recover it from a loaded matrix.
  meta_graph_ = GetAny()->meta_graph();

  return true;
}

/**
 * @brief Serialize this UnitGraph into @p fs.
 *
 * Writes the magic number, two int64 format words (saved formats, then
 * allowed formats_), and finally the selected sparse matrices in the order
 * COO, out-CSR, in-CSR (CSC). When @p fs is a DGLStream that requests a
 * specific set of formats, those are saved. Otherwise only the format chosen
 * by SelectFormat(ALL_CODE) is saved.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);

  // meta_graph_ is not written on its own: the sparse matrices saved below
  // already carry it.
  auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
  if (auto* dgl_stream = dynamic_cast<dgl::serialize::DGLStream*>(fs)) {
    const auto requested_formats = dgl_stream->FormatsToSave();
    // ANY_CODE means "no preference": keep the default chosen above.
    if (requested_formats != ANY_CODE) {
      save_formats = requested_formats;
    }
  }

  // Bit 32 tags each word as a bitmask, distinguishing it from the legacy
  // single-format encoding that Load() still accepts.
  constexpr int64_t kBitmaskFlag = int64_t{1} << 32;
  fs->Write(static_cast<int64_t>(save_formats) | kBitmaskFlag);
  fs->Write(static_cast<int64_t>(formats_) | kBitmaskFlag);

  if (save_formats & COO_CODE) fs->Write(GetCOO());
  if (save_formats & CSR_CODE) fs->Write(GetOutCSR());
  if (save_formats & CSC_CODE) fs->Write(GetInCSR());
}

1659
1660
1661
1662
/**
 * @brief Return a new UnitGraph with every edge direction flipped.
 *
 * The existing out-CSR becomes the new graph's in-CSR and the existing in-CSR
 * becomes its out-CSR. The CSR objects are shared, not copied. A defined COO
 * is transposed into a new COO object. An undefined COO is passed on as a
 * null pointer. The metagraph is reused unchanged.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  COOPtr reversed_coo;  // stays null unless this graph has a COO
  if (coo_->defined()) {
    reversed_coo = COOPtr(
        new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
  }
  return UnitGraphPtr(
      new UnitGraph(meta_graph(), out_csr_, in_csr_, reversed_coo));
}

1671
std::tuple<UnitGraphPtr, IdArray, IdArray> UnitGraph::ToSimple() const {
1672
1673
1674
1675
1676
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1677
  auto avail_fmt = SelectFormat(ALL_CODE);
1678
1679
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1680
      auto ret = aten::COOToSimple(GetCOO()->adj());
1681
1682
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1683
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1684
1685
1686
      break;
    }
    case SparseFormat::kCSR: {
1687
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1688
1689
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1690
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1691
1692
1693
      break;
    }
    case SparseFormat::kCSC: {
1694
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1695
1696
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1697
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1698
1699
1700
1701
1702
1703
1704
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

1705
1706
1707
  return std::make_tuple(
      UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
      count, edge_map);
1708
1709
}

1710
}  // namespace dgl