unit_graph.cc 60.2 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
3
 *  Copyright (c) 2019 by Contributors
4
5
 * @file graph/unit_graph.cc
 * @brief UnitGraph graph implementation
6
 */
sangwzh's avatar
sangwzh committed
7
#include "unit_graph.h"
8

9
#include <dgl/array.h>
10
#include <dgl/base_heterograph.h>
11
12
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
13
14

#include "../c_api_common.h"
sangwzh's avatar
sangwzh committed
15
#include "serialize/dglstream.h"
16
17

namespace dgl {
18

19
namespace {
20
21
22

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
23
24
25
26
27
28
29
30
31
32
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
  // a self-loop edge 0->0
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 0);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
  return g;
}
33

Minjie Wang's avatar
Minjie Wang committed
34
35
36
37
38
// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
39
40
41
42
43
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
44
45
46
47
48
49
50
51
52
53
54
55

inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  if (num_vtypes == 1)
    return mg1;
  else if (num_vtypes == 2)
    return mg2;
  else
    LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  return {};
}
56
57

};  // namespace
58
59
60
61
62
63
64

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
65
/**
 * @brief COO-format backend of UnitGraph.
 *
 * Wraps a single aten::COOMatrix. The constructors that take arrays or a
 * matrix always leave the matrix's data array null, so edge ids are the
 * positions 0..num_edges-1 in the row/col arrays.
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /**
   * @brief Construct from source/destination id arrays.
   * @param metagraph Metagraph passed to BaseHeteroGraph.
   * @param num_src Number of source nodes (rows).
   * @param num_dst Number of destination nodes (columns).
   * @param src Source node id of every edge.
   * @param dst Destination node id of every edge; same length as src.
   * @param row_sorted Forwarded to the COOMatrix.
   * @param col_sorted Forwarded to the COOMatrix.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
      : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0])
        << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src,     num_dst,    src,       dst,
                           NullArray(), row_sorted, col_sorted};
  }

  /** @brief Construct from an existing COO matrix that carries no data. */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
      : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /** @brief Construct an undefined COO (see defined()). */
  COO() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
    // supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /** @brief Whether both dimensions are non-negative, i.e. not the marker. */
  bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); }

  /** @brief Source vertex type; always 0. */
  inline dgl_type_t SrcType() const { return 0; }

  /** @brief Destination vertex type; 0 with one vertex type, otherwise 1. */
  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }

  /** @brief Edge type; always 0. */
  inline dgl_type_t EdgeType() const { return 0; }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
               << "The relation graph is simply this graph itself.";
    return {};
  }

  // The mutation methods below are all rejected with LOG(FATAL).
  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }

  // Data type, context and bit width are all taken from the row array.
  DGLDataType DataType() const override { return adj_.row->dtype; }

  DGLContext Context() const override { return adj_.row->ctx; }

  bool IsPinned() const override { return adj_.is_pinned; }

  uint8_t NumBits() const override { return adj_.row->dtype.bits; }

  /**
   * @brief Return a graph whose id arrays use the given bit width.
   * @return A copy of this graph when the width already matches; otherwise a
   *   new COO built from the converted row/col arrays.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) return *this;

    COO ret(
        meta_graph_, adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.col, bits));
    return ret;
  }

  /**
   * @brief Return a graph on the given context.
   * @return A copy of this graph when the context already matches.
   */
  COO CopyTo(const DGLContext& ctx) const {
    if (Context() == ctx) return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }

  /**
   * @brief Copy the adj_ to pinned memory.
   * @return A COO graph built from the pinned copy of adj_, or a copy of this
   *   graph when adj_ is already pinned.
   */
  COO PinMemory() {
    if (adj_.is_pinned) return *this;
    return COO(meta_graph_, adj_.PinMemory());
  }

  /** @brief Pin the adj_: COOMatrix of the COO graph. */
  void PinMemory_() { adj_.PinMemory_(); }

  /** @brief Unpin the adj_: COOMatrix of the COO graph. */
  void UnpinMemory_() { adj_.UnpinMemory_(); }

  /** @brief Record stream for the adj_: COOMatrix of the COO graph. */
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

  bool IsMultigraph() const override { return aten::COOHasDuplicate(adj_); }

  bool IsReadonly() const override { return true; }

  /** @brief Rows for the source type, columns for the destination type. */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /** @brief Number of edges, i.e. the length of the row array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  /** @brief Predecessors of dst, read from row dst of the transposed matrix. */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  /** @brief Successors of src, read from row src of the matrix. */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::COOGetData(adj_, src, dst);
  }

  /** @brief Look up the (src, dst) endpoints of a single edge id. */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  /** @brief Look up the endpoints of a batch of edge ids. */
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    // Edge ids index row/col directly, which only holds without a data array.
    BUG_IF_FAIL(aten::IsNullArray(adj_.data))
        << "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{
        aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
        eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) =
        aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid);
    // Every returned edge ends at vid.
    IdArray ret_dst =
        aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Map the sliced row positions back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    // Every returned edge starts at vid.
    IdArray ret_src =
        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Map the sliced row positions back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /**
   * @brief Return all edges in edge-id order.
   * @param order Must be empty or "eid".
   */
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
        << "COO only support Edges of order \"eid\", but got \"" << order
        << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // NOTE(review): the four iterator accessors below only log at INFO level
  // and return an empty range, whereas the other unsupported methods in this
  // class use LOG(FATAL). Left unchanged; confirm whether FATAL was intended.
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  /**
   * @brief Return the adjacency as a single stacked array.
   * @param transpose When true the columns are stacked before the rows.
   * @param fmt Must be "coo".
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { return adj_; }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /**
   * @brief Build the subgraph induced by the given vertices.
   * @param vids One id array per vertex type.
   * @return Subgraph whose induced_vertices is vids and whose single
   *   induced_edges entry is the data array of the sliced matrix.
   */
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes())
        << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()];
    auto dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    // The result is unused: a COO graph stores no edge-id array, so the id
    // range that used to be built from this context was never consumed.
    // The call itself is kept in case GetContextOf validates that all input
    // arrays share one context -- TODO confirm and drop if it does not.
    static_cast<void>(aten::GetContextOf(vids));
    subg.graph = std::make_shared<COO>(
        meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /**
   * @brief Build the subgraph induced by the given edges.
   * @param eids Exactly one edge id array.
   * @param preserve_nodes When false the endpoints are relabeled and the
   *   relabeling is reported in induced_vertices; when true the node counts
   *   of this graph are kept and induced_vertices holds two null arrays.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      // Relabel_ is applied to each endpoint array separately; the arrays are
      // then used as the edges of the new graph.
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      subg.induced_vertices.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
      subg.induced_vertices.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src,
          new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  /** @brief Return a copy of the internal adjacency matrix. */
  aten::COOMatrix adj() const { return adj_; }

  /**
   * @brief Determines whether the graph is "hypersparse", i.e. having
   * significantly more nodes than edges.
   *
   * True when the source node count exceeds both 8x the edge count and
   * 1000000.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /** @brief Read the metagraph and then the adjacency matrix from fs. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /** @brief Write the metagraph and then the adjacency matrix to fs. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /** @brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

460
/** @brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
461
class UnitGraph::CSR : public BaseHeteroGraph {
462
 public:
463
464
465
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr,
      IdArray indices, IdArray edge_ids)
      : BaseHeteroGraph(metagraph) {
466
467
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
468
    if (aten::IsValidIdArray(edge_ids))
469
470
471
472
473
      CHECK(
          (indices->shape[0] == edge_ids->shape[0]) ||
          aten::IsNullArray(edge_ids))
          << "edge id arrays should have the same length as indices if not "
             "empty";
474
    CHECK_EQ(num_src, indptr->shape[0] - 1)
475
        << "number of nodes do not match the length of indptr minus 1.";
476

477
478
479
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

480
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
481
      : BaseHeteroGraph(metagraph), adj_(csr) {}
482

483
484
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
485
486
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
    // supported
487
488
489
490
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

491
  bool defined() const { return (adj_.num_rows >= 0) || (adj_.num_cols >= 0); }
492

493
  inline dgl_type_t SrcType() const { return 0; }
Minjie Wang's avatar
Minjie Wang committed
494

495
  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
Minjie Wang's avatar
Minjie Wang committed
496

497
  inline dgl_type_t EdgeType() const { return 0; }
498
499

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
500
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
501
               << "The relation graph is simply this graph itself.";
502
503
504
505
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
506
    LOG(FATAL) << "UnitGraph graph is not mutable.";
507
508
509
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
510
    LOG(FATAL) << "UnitGraph graph is not mutable.";
511
512
513
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
514
    LOG(FATAL) << "UnitGraph graph is not mutable.";
515
516
  }

517
  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
518

519
  DGLDataType DataType() const override { return adj_.indices->dtype; }
520

521
  DGLContext Context() const override { return adj_.indices->ctx; }
522

523
  bool IsPinned() const override { return adj_.is_pinned; }
524

525
  uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
526

527
528
529
530
531
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
532
          meta_graph_, adj_.num_rows, adj_.num_cols,
533
534
535
536
537
538
539
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

540
  CSR CopyTo(const DGLContext& ctx) const {
541
542
543
    if (Context() == ctx) {
      return *this;
    } else {
544
      return CSR(meta_graph_, adj_.CopyTo(ctx));
545
546
547
    }
  }

548
549
550
551
552
553
554
555
556
  /**
   * @brief Copy the adj_ to pinned memory.
   * @return CSRMatrix of the CSR graph.
   */
  CSR PinMemory() {
    if (adj_.is_pinned) return *this;
    return CSR(meta_graph_, adj_.PinMemory());
  }

557
  /** @brief Pin the adj_: CSRMatrix of the CSR graph. */
558
  void PinMemory_() { adj_.PinMemory_(); }
559

560
  /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */
561
  void UnpinMemory_() { adj_.UnpinMemory_(); }
562

563
  /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
564
565
566
567
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

568
  bool IsMultigraph() const override { return aten::CSRHasDuplicate(adj_); }
569

570
  bool IsReadonly() const override { return true; }
571
572

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
573
    if (vtype == SrcType()) {
574
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
575
    } else if (vtype == DstType()) {
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

596
597
  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
598
599
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
600
601
602
    return aten::CSRIsNonZero(adj_, src, dst);
  }

603
604
  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
605
606
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
607
608
609
610
611
612
613
614
615
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
616
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
617
618
619
620
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
621
622
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
623
    return aten::CSRGetAllData(adj_, src, dst);
624
625
  }

626
627
  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
628
629
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
630
631
632
633
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

634
635
  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
636
637
638
    return aten::CSRGetData(adj_, src, dst);
  }

639
640
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
641
    LOG(FATAL) << "Not enabled for CSR graph.";
642
643
644
645
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
646
    LOG(FATAL) << "Not enabled for CSR graph.";
647
648
649
650
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
651
    LOG(FATAL) << "Not enabled for CSR graph.";
652
653
654
655
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
656
    LOG(FATAL) << "Not enabled for CSR graph.";
657
658
659
660
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
661
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
662
663
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
664
665
    IdArray ret_src =
        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
666
667
668
669
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
670
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
671
672
673
674
675
676
677
678
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

679
680
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
681
    CHECK(order.empty() || order == std::string("srcdst"))
682
683
        << "CSR only support Edges of order \"srcdst\","
        << " but got \"" << order << "\".";
684
685
686
687
688
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
689
690
691
692
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
693
    LOG(FATAL) << "Not enabled for CSR graph.";
694
695
696
697
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
698
    LOG(FATAL) << "Not enabled for CSR graph.";
699
700
701
702
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
703
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
704
705
706
707
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
708
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
709
710
711
712
713
714
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
715
    CHECK_EQ(NumBits(), 64);
716
717
718
719
720
721
722
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

723
724
725
726
727
728
729
730
731
732
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

733
734
735
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
736
    CHECK_EQ(NumBits(), 64);
737
738
739
740
741
742
743
744
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
745
    LOG(FATAL) << "Not enabled for CSR graph.";
746
747
748
749
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
750
    LOG(FATAL) << "Not enabled for CSR graph.";
751
752
753
754
    return {};
  }

  std::vector<IdArray> GetAdj(
755
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
756
757
758
759
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

760
761
762
763
764
765
766
767
768
769
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

770
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { return adj_; }
771

772
773
  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
774
    LOG(FATAL) << "Not enabled for CSR graph";
775
    return SparseFormat::kCSR;
776
777
  }

778
779
780
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
781
782
  }

783
  dgl_format_code_t GetCreatedFormats() const override {
784
785
786
787
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

788
789
790
791
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes())
        << "Number of vertex types mismatch";
Minjie Wang's avatar
Minjie Wang committed
792
793
794
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
795
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
796
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
797
    DGLContext ctx = aten::GetContextOf(vids);
798
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
799
800
801
    subg.graph = std::make_shared<CSR>(
        meta_graph(), submat.num_rows, submat.num_cols, submat.indptr,
        submat.indices, sub_eids);
802
803
804
805
806
807
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
808
809
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override {
810
    LOG(FATAL) << "Not enabled for CSR graph.";
811
812
813
    return {};
  }

814
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
815
816
817
818
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

819
  aten::CSRMatrix adj() const { return adj_; }
820

821
822
823
824
825
826
827
828
829
830
831
832
833
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

834
 private:
835
836
  friend class Serializer;

837
  /** @brief internal adjacency matrix. Data array stores edge ids */
838
839
840
841
842
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
Minjie Wang's avatar
Minjie Wang committed
843
// unit graph implementation
844
845
846
//
//////////////////////////////////////////////////////////

847
/** @brief Forward DataType() to whichever sparse format GetAny() returns. */
DGLDataType UnitGraph::DataType() const {
  return GetAny()->DataType();
}
848

849
/** @brief Forward Context() to whichever sparse format GetAny() returns. */
DGLContext UnitGraph::Context() const {
  return GetAny()->Context();
}
850

851
/**
 * @brief Forward IsPinned() to whichever sparse format GetAny() returns.
 *
 * Only that one format is consulted; the other formats are not inspected.
 */
bool UnitGraph::IsPinned() const {
  return GetAny()->IsPinned();
}
852

853
/** @brief Forward NumBits() to whichever sparse format GetAny() returns. */
uint8_t UnitGraph::NumBits() const {
  return GetAny()->NumBits();
}
854

Minjie Wang's avatar
Minjie Wang committed
855
bool UnitGraph::IsMultigraph() const {
856
  const SparseFormat fmt = SelectFormat(CSC_CODE);
857
858
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
859
860
}

Minjie Wang's avatar
Minjie Wang committed
861
/**
 * @brief Number of vertices of the given vertex type.
 *
 * Any format may answer the query. When the CSC format is selected, the
 * source and destination types are swapped before forwarding.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->NumVertices(vtype);
  }
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  const dgl_type_t swapped = (vtype == SrcType()) ? DstType() : SrcType();
  return backend->NumVertices(swapped);
}

Minjie Wang's avatar
Minjie Wang committed
871
/** @brief Forward NumEdges() to whichever sparse format GetAny() returns. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const auto backend = GetAny();
  return backend->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
875
/**
 * @brief Whether vertex `vid` of type `vtype` exists.
 *
 * Any format may answer the query. When the CSC format is selected, the
 * source and destination types are swapped before forwarding.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->HasVertex(vtype, vid);
  }
  const dgl_type_t swapped = (vtype == SrcType()) ? DstType() : SrcType();
  return backend->HasVertex(swapped, vid);
}

Minjie Wang's avatar
Minjie Wang committed
883
/**
 * @brief Batched vertex-existence query.
 *
 * Validates the id array, then compares it against NumVertices(vtype) with
 * aten::LT (presumably an element-wise less-than).
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const uint64_t vertex_count = NumVertices(vtype);
  return aten::LT(vids, vertex_count);
}

888
889
/**
 * @brief Whether an edge src->dst exists.
 *
 * CSC is the preferred format. When CSC is selected, the endpoints are
 * swapped before forwarding.
 */
bool UnitGraph::HasEdgeBetween(
    dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  const bool reversed = (chosen == SparseFormat::kCSC);
  return reversed ? backend->HasEdgeBetween(etype, dst, src)
                  : backend->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
898
/**
 * @brief Batched edge-existence query.
 *
 * CSC is the preferred format. When CSC is selected, the endpoint arrays are
 * swapped before forwarding.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->HasEdgesBetween(etype, src, dst);
  }
  return backend->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
908
/**
 * @brief Predecessors of vertex `dst`.
 *
 * CSC is the preferred format. When CSC is selected, the query is forwarded
 * as Successors() on that backend; otherwise Predecessors() is called.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->Predecessors(etype, dst);
  }
  return backend->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
917
/** @brief Successors of vertex `src`; CSR is the preferred format. */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  return GetFormat(SelectFormat(CSR_CODE))->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
923
/**
 * @brief Edge ids between `src` and `dst`.
 *
 * CSR is the preferred format. If the selection falls back to CSC, the
 * endpoints are swapped before forwarding.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto backend = GetFormat(chosen);
  const bool reversed = (chosen == SparseFormat::kCSC);
  return reversed ? backend->EdgeId(etype, dst, src)
                  : backend->EdgeId(etype, src, dst);
}

932
933
/**
 * @brief Batched edge-id query returning the (src, dst, id) arrays.
 *
 * CSR is the preferred format. If the selection falls back to CSC, the
 * endpoint arrays are swapped on the way in and the src/dst fields of the
 * result are swapped back on the way out.
 */
EdgeArray UnitGraph::EdgeIdsAll(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeIdsAll(etype, src, dst);
  }
  const EdgeArray reversed = backend->EdgeIdsAll(etype, dst, src);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

944
945
/**
 * @brief Batched edge-id query forwarded to the backend's EdgeIdsOne().
 *
 * CSR is the preferred format. If the selection falls back to CSC, the
 * endpoint arrays are swapped before forwarding.
 */
IdArray UnitGraph::EdgeIdsOne(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeIdsOne(etype, src, dst);
  }
  return backend->EdgeIdsOne(etype, dst, src);
}

955
956
/** @brief Endpoints of edge `eid`; COO is the preferred format. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(
    dgl_type_t etype, dgl_id_t eid) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
962
/** @brief Endpoints of the edges in `eids`; COO is the preferred format. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
968
/**
 * @brief In-edges of a single vertex.
 *
 * CSC is the preferred format. When CSC is selected, the query is forwarded
 * as OutEdges() and the src/dst fields of the result are swapped.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdges(etype, vid);
  }
  const EdgeArray reversed = backend->OutEdges(etype, vid);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
979
/**
 * @brief In-edges of a batch of vertices.
 *
 * CSC is the preferred format. When CSC is selected, the query is forwarded
 * as OutEdges() and the src/dst fields of the result are swapped.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdges(etype, vids);
  }
  const EdgeArray reversed = backend->OutEdges(etype, vids);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
990
/** @brief Out-edges of a single vertex; CSR is the preferred format. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
996
/** @brief Out-edges of a batch of vertices; CSR is the preferred format. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdges(etype, vids);
}

1002
/**
 * @brief Return all edges in the requested order.
 *
 * The order string picks the preferred format: "eid" prefers COO, "srcdst"
 * prefers CSR, and an empty string accepts any format. Any other value is a
 * fatal error. If the selected format is CSC, the src/dst fields of the
 * backend's result are swapped before returning.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string& order) const {
  dgl_format_code_t preferred = ALL_CODE;
  if (order == "eid") {
    preferred = COO_CODE;
  } else if (order.empty()) {
    // arbitrary order
    preferred = ALL_CODE;
  } else if (order == "srcdst") {
    preferred = CSR_CODE;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const SparseFormat fmt = SelectFormat(preferred);
  const EdgeArray edges = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC) {
    return edges;
  }
  return EdgeArray{edges.dst, edges.src, edges.id};
}

Minjie Wang's avatar
Minjie Wang committed
1023
/**
 * @brief In-degree of a single vertex.
 *
 * CSC is the preferred format; the selected format must be CSC or COO. On CSC
 * the query is forwarded as OutDegree().
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (fmt == SparseFormat::kCSC) {
    return backend->OutDegree(etype, vid);
  }
  return backend->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1033
/**
 * @brief In-degrees of a batch of vertices.
 *
 * CSC is the preferred format; the selected format must be CSC or COO. On CSC
 * the query is forwarded as OutDegrees().
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (fmt == SparseFormat::kCSC) {
    return backend->OutDegrees(etype, vids);
  }
  return backend->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1043
/**
 * @brief Out-degree of a single vertex.
 *
 * CSR is the preferred format; the selected format must be CSR or COO.
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto backend = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return backend->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1052
/**
 * @brief Out-degrees of a batch of vertices.
 *
 * CSR is the preferred format; the selected format must be CSR or COO.
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto backend = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return backend->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1061
/** @brief Successor iterator range of `vid`; CSR is the preferred format. */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->SuccVec(etype, vid);
}

1067
/**
 * @brief 32-bit successor iterator range of `vid`.
 *
 * CSR is the preferred format. The selected backend is downcast to CSR, and
 * the result of the cast is checked for null before use.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(chosen));
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1074
/** @brief Out-edge iterator range of `vid`; CSR is the preferred format. */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1080
/**
 * @brief Predecessor iterator range of `vid`.
 *
 * CSC is the preferred format. On CSC the query is forwarded as SuccVec().
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->PredVec(etype, vid);
  }
  return backend->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1089
/**
 * @brief In-edge iterator range of `vid`.
 *
 * CSC is the preferred format. On CSC the query is forwarded as OutEdgeVec().
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdgeVec(etype, vid);
  }
  return backend->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1098
std::vector<IdArray> UnitGraph::GetAdj(
1099
1100
1101
1102
1103
1104
1105
1106
1107
    dgl_type_t etype, bool transpose, const std::string& fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst
  // nodes and col for src nodes. Therefore, we need to flip the transpose flag.
  // For example,
  //   transpose=False is equal to in edge CSR. We have this behavior because
  //   previously we use framework's SPMM and we don't cache reverse adj. This
  //   is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
  //   change the behavior and make row for src and col for dst.
1108
  if (fmt == std::string("csr")) {
1109
    return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
1110
                      : GetInCSR()->GetAdj(etype, false, "csr");
1111
  } else if (fmt == std::string("coo")) {
1112
    return GetCOO()->GetAdj(etype, transpose, fmt);
1113
1114
1115
1116
1117
1118
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

1119
1120
/**
 * @brief Build the vertex-induced subgraph.
 *
 * The selected backend (CSR preferred) extracts the subgraph. Its result is
 * then wrapped in a new UnitGraph that holds only that one format, and the
 * induced vertex/edge arrays are moved into the returned structure.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(
    const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  HeteroSubgraph extracted = GetFormat(fmt)->VertexSubgraph(vids);
  HeteroSubgraph result;

  // Exactly one of these is filled in, matching the format that was used.
  CSRPtr out_part = nullptr;
  CSRPtr in_part = nullptr;
  COOPtr coo_part = nullptr;
  if (fmt == SparseFormat::kCSR) {
    out_part = std::dynamic_pointer_cast<CSR>(extracted.graph);
  } else if (fmt == SparseFormat::kCSC) {
    in_part = std::dynamic_pointer_cast<CSR>(extracted.graph);
  } else if (fmt == SparseFormat::kCOO) {
    coo_part = std::dynamic_pointer_cast<COO>(extracted.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph =
      HeteroGraphPtr(new UnitGraph(meta_graph(), in_part, out_part, coo_part));
  result.induced_vertices = std::move(extracted.induced_vertices);
  result.induced_edges = std::move(extracted.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1151
/**
 * @brief Build the edge-induced subgraph.
 *
 * The selected backend (COO preferred) extracts the subgraph. Its result is
 * then wrapped in a new UnitGraph that holds only that one format, and the
 * induced vertex/edge arrays are moved into the returned structure.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  auto extracted = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph result;

  // Exactly one of these is filled in, matching the format that was used.
  CSRPtr out_part = nullptr;
  CSRPtr in_part = nullptr;
  COOPtr coo_part = nullptr;
  if (fmt == SparseFormat::kCSR) {
    out_part = std::dynamic_pointer_cast<CSR>(extracted.graph);
  } else if (fmt == SparseFormat::kCSC) {
    in_part = std::dynamic_pointer_cast<CSR>(extracted.graph);
  } else if (fmt == SparseFormat::kCOO) {
    coo_part = std::dynamic_pointer_cast<COO>(extracted.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph =
      HeteroGraphPtr(new UnitGraph(meta_graph(), in_part, out_part, coo_part));
  result.induced_vertices = std::move(extracted.induced_vertices);
  result.induced_edges = std::move(extracted.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1182
/**
 * @brief Create a unit graph from COO row/col arrays.
 *
 * @param num_vtypes Number of vertex types; must be 1 or 2. With one type the
 *        source and destination counts must match.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
    IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr coo_part = COOPtr(
      new COO(meta, num_src, num_dst, row, col, row_sorted, col_sorted));
  return HeteroGraphPtr(
      new UnitGraph(meta, nullptr, nullptr, coo_part, formats));
}

1193
/**
 * @brief Create a unit graph from an existing COO matrix.
 *
 * @param num_vtypes Number of vertex types; must be 1 or 2. With one type the
 *        matrix must be square.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr coo_part = COOPtr(new COO(meta, mat));
  return HeteroGraphPtr(
      new UnitGraph(meta, nullptr, nullptr, coo_part, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1203
/**
 * @brief Create a unit graph from CSR arrays (stored as the out-CSR).
 *
 * @param num_vtypes Number of vertex types; must be 1 or 2. With one type the
 *        source and destination counts must match.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_part =
      CSRPtr(new CSR(meta, num_src, num_dst, indptr, indices, edge_ids));
  return HeteroGraphPtr(
      new UnitGraph(meta, nullptr, out_part, nullptr, formats));
}

1213
/**
 * @brief Create a unit graph from an existing CSR matrix (stored as the
 *        out-CSR).
 *
 * @param num_vtypes Number of vertex types; must be 1 or 2. With one type the
 *        matrix must be square.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_part = CSRPtr(new CSR(meta, mat));
  return HeteroGraphPtr(
      new UnitGraph(meta, nullptr, out_part, nullptr, formats));
}

1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
/**
 * @brief Create a unit graph that starts with both an out-CSR and a COO.
 *
 * The two matrices must have the same shape. With one vertex type the shape
 * must also be square.
 * @param num_vtypes Number of vertex types; must be 1 or 2.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO(
    int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  CHECK_EQ(coo.num_rows, csr.num_rows);
  CHECK_EQ(coo.num_cols, csr.num_cols);
  if (num_vtypes == 1) CHECK_EQ(csr.num_rows, csr.num_cols);

  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_part = CSRPtr(new CSR(meta, csr));
  const COOPtr coo_part = COOPtr(new COO(meta, coo));
  return HeteroGraphPtr(
      new UnitGraph(meta, nullptr, out_part, coo_part, formats));
}

1237
/**
 * @brief Create a unit graph from CSC arrays (stored as the in-CSR).
 *
 * The backing CSR object is built with the destination count as its row
 * count and the source count as its column count.
 * @param num_vtypes Number of vertex types; must be 1 or 2. With one type the
 *        source and destination counts must match.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_part =
      CSRPtr(new CSR(meta, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(
      new UnitGraph(meta, in_part, nullptr, nullptr, formats));
}

/**
 * @brief Create a unit graph from an existing CSC matrix (stored as the
 *        in-CSR).
 *
 * @param num_vtypes Number of vertex types; must be 1 or 2. With one type the
 *        matrix must be square.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_part = CSRPtr(new CSR(meta, mat));
  return HeteroGraphPtr(
      new UnitGraph(meta, in_part, nullptr, nullptr, formats));
}

1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
/**
 * @brief Create a unit graph that starts with both an in-CSR (CSC) and a COO.
 *
 * The CSC shape must be the transpose of the COO shape: COO rows equal CSC
 * columns and COO columns equal CSC rows. With one vertex type the shape must
 * also be square.
 * @param num_vtypes Number of vertex types; must be 1 or 2.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO(
    int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  CHECK_EQ(coo.num_rows, csc.num_cols);
  CHECK_EQ(coo.num_cols, csc.num_rows);
  if (num_vtypes == 1) CHECK_EQ(csc.num_rows, csc.num_cols);

  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_part = CSRPtr(new CSR(meta, csc));
  const COOPtr coo_part = COOPtr(new COO(meta, coo));
  return HeteroGraphPtr(
      new UnitGraph(meta, in_part, nullptr, coo_part, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1271
/**
 * @brief Return the graph with the requested id bit width.
 *
 * If the graph already has that width, it is returned as is. Otherwise every
 * defined format is converted and a new UnitGraph is built from them;
 * undefined formats stay absent. The allowed-format code is carried over.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  CSRPtr converted_in = nullptr;
  if (bg->in_csr_->defined()) {
    converted_in = CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)));
  }
  CSRPtr converted_out = nullptr;
  if (bg->out_csr_->defined()) {
    converted_out = CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)));
  }
  COOPtr converted_coo = nullptr;
  if (bg->coo_->defined()) {
    converted_coo = COOPtr(new COO(bg->coo_->AsNumBits(bits)));
  }

  return HeteroGraphPtr(new UnitGraph(
      g->meta_graph(), converted_in, converted_out, converted_coo,
      bg->formats_));
}

1291
/**
 * @brief Return the graph on the requested device context.
 *
 * If the graph is already on that context, it is returned as is. Otherwise
 * every defined format is copied and a new UnitGraph is built from them;
 * undefined formats stay absent. The allowed-format code is carried over.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  CSRPtr copied_in = nullptr;
  if (bg->in_csr_->defined()) {
    copied_in = CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)));
  }
  CSRPtr copied_out = nullptr;
  if (bg->out_csr_->defined()) {
    copied_out = CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)));
  }
  COOPtr copied_coo = nullptr;
  if (bg->coo_->defined()) {
    copied_coo = COOPtr(new COO(bg->coo_->CopyTo(ctx)));
  }

  return HeteroGraphPtr(new UnitGraph(
      g->meta_graph(), copied_in, copied_out, copied_coo, bg->formats_));
}

1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
/**
 * @brief Return a new UnitGraph whose defined formats are all pinned.
 *
 * For each format: if it is undefined it stays absent; if it is already
 * pinned the existing object is shared; otherwise a pinned copy is created
 * through the backend's PinMemory().
 */
HeteroGraphPtr UnitGraph::PinMemory() {
  CSRPtr pinned_in_csr = nullptr;
  if (in_csr_->defined()) {
    pinned_in_csr =
        in_csr_->IsPinned() ? in_csr_ : CSRPtr(new CSR(in_csr_->PinMemory()));
  }

  CSRPtr pinned_out_csr = nullptr;
  if (out_csr_->defined()) {
    pinned_out_csr = out_csr_->IsPinned()
                         ? out_csr_
                         : CSRPtr(new CSR(out_csr_->PinMemory()));
  }

  COOPtr pinned_coo = nullptr;
  if (coo_->defined()) {
    pinned_coo = coo_->IsPinned() ? coo_ : COOPtr(new COO(coo_->PinMemory()));
  }

  return HeteroGraphPtr(new UnitGraph(
      meta_graph(), pinned_in_csr, pinned_out_csr, pinned_coo, formats_));
}

1342
void UnitGraph::PinMemory_() {
1343
1344
1345
  if (this->in_csr_->defined()) this->in_csr_->PinMemory_();
  if (this->out_csr_->defined()) this->out_csr_->PinMemory_();
  if (this->coo_->defined()) this->coo_->PinMemory_();
1346
1347
1348
}

void UnitGraph::UnpinMemory_() {
1349
1350
1351
  if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_();
  if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_();
  if (this->coo_->defined()) this->coo_->UnpinMemory_();
1352
1353
}

1354
/**
 * @brief Record `stream` on every defined format and remember it.
 *
 * The stream is appended to recorded_streams so that formats created later
 * by the in-place getters can record it too.
 */
void UnitGraph::RecordStream(DGLStreamHandle stream) {
  if (in_csr_->defined()) {
    in_csr_->RecordStream(stream);
  }
  if (out_csr_->defined()) {
    out_csr_->RecordStream(stream);
  }
  if (coo_->defined()) {
    coo_->RecordStream(stream);
  }
  recorded_streams.push_back(stream);
}

1361
void UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); }
1362

1363
void UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); }
1364

1365
void UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); }
1366

1367
1368
1369
1370
1371
1372
1373
/**
 * @brief Construct a unit graph from up to three sparse formats.
 *
 * Any format passed as a null pointer is replaced by a default-constructed
 * placeholder, so the three members are never null afterwards. It is a fatal
 * error if a format that is already created is not among the allowed
 * `formats`.
 */
UnitGraph::UnitGraph(
    GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
    dgl_format_code_t formats)
    : BaseHeteroGraph(metagraph),
      in_csr_(in_csr),
      out_csr_(out_csr),
      coo_(coo) {
  if (in_csr_ == nullptr) in_csr_ = CSRPtr(new CSR());
  if (out_csr_ == nullptr) out_csr_ = CSRPtr(new CSR());
  if (coo_ == nullptr) coo_ = COOPtr(new COO());

  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  // Every created format must be part of the allowed set.
  if ((formats | created) != formats) {
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created)
               << ", which is not compatible with available formats: "
               << CodeToStr(formats);
  }
  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1392
/**
 * @brief Create a unit graph from any combination of the three formats.
 *
 * Each `has_*` flag says whether the matching matrix argument should be used.
 * A format whose flag is false becomes a default-constructed placeholder.
 * @param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
    int num_vtypes, const aten::CSRMatrix& in_csr,
    const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr,
    bool has_out_csr, bool has_coo, dgl_format_code_t formats) {
  const auto meta = CreateUnitGraphMetaGraph(num_vtypes);

  const CSRPtr in_part =
      has_in_csr ? CSRPtr(new CSR(meta, in_csr)) : CSRPtr(new CSR());
  const CSRPtr out_part =
      has_out_csr ? CSRPtr(new CSR(meta, out_csr)) : CSRPtr(new CSR());
  const COOPtr coo_part =
      has_coo ? COOPtr(new COO(meta, coo)) : COOPtr(new COO());

  return HeteroGraphPtr(
      new UnitGraph(meta, in_part, out_part, coo_part, formats));
}

1419
1420
/**
 * @brief Return in csr (CSC). If not exist, build it from COO or by
 *        transposing the out csr.
 *
 * @param inplace If true, a newly built matrix is stored into this graph
 *        (CSC must then be among the allowed formats), is pinned if the graph
 *        was pinned, and records every stream in recorded_streams. If false,
 *        a newly built matrix is returned without modifying this graph.
 * @return The in-CSR format.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace)
    if (!(formats_ & CSC_CODE))
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create CSC matrix.";
  CSRPtr ret = in_csr_;
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
  if (!in_csr_->defined()) {
    // The pinned state must be read BEFORE the in-place assignment:
    // IsPinned() goes through GetAny(), which returns in_csr_ as soon as it
    // is defined, so asking afterwards would describe the freshly built
    // matrix rather than the formats the graph already had.
    bool was_pinned = false;
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj()));

      if (inplace) {
        was_pinned = IsPinned();
        *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
      } else {
        ret = std::make_shared<CSR>(meta_graph(), newadj);
      }
    } else {
      CHECK(out_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(out_csr_->adj());

      if (inplace) {
        was_pinned = IsPinned();
        *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
      } else {
        ret = std::make_shared<CSR>(meta_graph(), newadj);
      }
    }
    if (inplace) {
      // Bring the new format in line with the rest of the graph.
      if (was_pinned) in_csr_->PinMemory_();
      for (auto stream : recorded_streams) in_csr_->RecordStream(stream);
    }
  }
  return ret;
}

1452
/** @brief Return out csr. If not exist, transpose the other one.*/
1453
1454
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1455
    if (!(formats_ & CSR_CODE))
1456
1457
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create CSR matrix.";
1458
  CSRPtr ret = out_csr_;
1459
1460
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1461
  if (!out_csr_->defined()) {
1462
1463
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1464

1465
      if (inplace)
1466
1467
1468
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1469
    } else {
1470
1471
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1472

1473
      if (inplace)
1474
1475
1476
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1477
    }
1478
    if (inplace) {
1479
1480
      if (IsPinned()) out_csr_->PinMemory_();
      for (auto stream : recorded_streams) out_csr_->RecordStream(stream);
1481
    }
1482
  }
1483
  return ret;
1484
1485
}

1486
/** @brief Return coo. If not exist, create from csr.*/
1487
1488
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1489
    if (!(formats_ & COO_CODE))
1490
1491
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create COO matrix.";
1492
  COOPtr ret = coo_;
1493
1494
  if (!coo_->defined()) {
    if (in_csr_->defined()) {
1495
1496
      const auto& newadj =
          aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1497

1498
      if (inplace)
1499
1500
1501
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1502
    } else {
1503
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1504
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1505

1506
      if (inplace)
1507
1508
1509
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1510
    }
1511
    if (inplace) {
1512
1513
      if (IsPinned()) coo_->PinMemory_();
      for (auto stream : recorded_streams) coo_->RecordStream(stream);
1514
    }
1515
  }
1516
  return ret;
1517
1518
}

1519
/** @brief Adjacency matrix of the in-CSR format; `etype` is not used. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const auto in_csr = GetInCSR();
  return in_csr->adj();
}

1523
/** @brief Adjacency matrix of the out-CSR format; `etype` is not used. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const auto out_csr = GetOutCSR();
  return out_csr->adj();
}

1527
/** @brief Adjacency matrix of the COO format; `etype` is not used. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const auto coo = GetCOO();
  return coo->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1531
/**
 * @brief Return one available format.
 *
 * Preference order: in-CSR if defined, then out-CSR if defined, otherwise the
 * COO member (returned without checking whether it is defined).
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined()) return in_csr_;
  if (out_csr_->defined()) return out_csr_;
  return coo_;
}

1541
/** @brief Bitwise OR of the format codes whose matrices are defined. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t code = 0;
  if (in_csr_->defined()) {
    code |= CSC_CODE;
  }
  if (out_csr_->defined()) {
    code |= CSR_CODE;
  }
  if (coo_->defined()) {
    code |= COO_CODE;
  }
  return code;
}

1549
/** @brief Return the format code this graph is restricted to. */
dgl_format_code_t UnitGraph::GetAllowedFormats() const {
  return formats_;
}
1550

1551
1552
/**
 * @brief Return the backend for the given sparse format.
 *
 * kCSR maps to the out-CSR, kCSC to the in-CSR, and every other value to COO.
 * The getters are called with their default arguments.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR) {
    return GetOutCSR();
  }
  if (format == SparseFormat::kCSC) {
    return GetInCSR();
  }
  return GetCOO();
}

1562
/**
 * @brief Return a graph restricted to the given formats.
 *
 * If some requested formats are already created, exactly those are kept and
 * shared with the new graph. Otherwise one requested format is created,
 * trying COO first, then CSR, then CSC. This graph is not modified: all
 * getters are called with inplace == false.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  // Requested formats that already exist on this graph.
  const auto created_formats = GetCreatedFormats();
  const auto intersection = formats & created_formats;

  if (intersection == 0) {
    // Nothing requested exists yet: create one format, COO -> CSR -> CSC.
    const int64_t num_vtypes = NumVertexTypes();
    if (COO_CODE & formats) {
      return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
    }
    if (CSR_CODE & formats) {
      return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
    }
    return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
  }

  // Keep only the formats in the intersection.
  COOPtr kept_coo = nullptr;
  if (COO_CODE & intersection) {
    kept_coo = GetCOO(false);
  }
  CSRPtr kept_in = nullptr;
  if (CSC_CODE & intersection) {
    kept_in = GetInCSR(false);
  }
  CSRPtr kept_out = nullptr;
  if (CSR_CODE & intersection) {
    kept_out = GetOutCSR(false);
  }
  return HeteroGraphPtr(
      new UnitGraph(meta_graph_, kept_in, kept_out, kept_coo, formats));
}

1589
1590
/**
 * @brief Pick the sparse format to use for an operation.
 *
 * Candidates are tried in this order and the first non-empty set is decoded:
 *   1. preferred, allowed and already created;
 *   2. preferred and allowed;
 *   3. whatever is already created.
 */
SparseFormat UnitGraph::SelectFormat(
    dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t allowed_preferred = preferred_formats & formats_;
  const dgl_format_code_t existing = GetCreatedFormats();
  if (allowed_preferred & existing) {
    return DecodeFormat(allowed_preferred & existing);
  }

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators
  // on COO have not been implemented yet.
  // if (coo_->defined() && coo_->IsHypersparse())
  //   // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  if (allowed_preferred) {
    return DecodeFormat(allowed_preferred);
  }
  return DecodeFormat(existing);
}

1603
1604
1605
1606
/**
 * @brief Convert a single-vertex-type unit graph to an ImmutableGraph.
 *
 * Every defined format is carried over; undefined ones are passed as null.
 * When the COO matrix has a data array, row and col are passed through
 * Scatter() with that array before building the new COO.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";

  dgl::CSRPtr in_part = nullptr;
  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    in_part = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }

  dgl::CSRPtr out_part = nullptr;
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    out_part = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }

  dgl::COOPtr coo_part = nullptr;
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    if (COOHasData(coo)) {
      IdArray new_src = Scatter(coo.row, coo.data);
      IdArray new_dst = Scatter(coo.col, coo.data);
      coo_part = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst));
    } else {
      coo_part = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col));
    }
  }

  return GraphPtr(new dgl::ImmutableGraph(in_part, out_part, coo_part));
}

1628
1629
/**
 * @brief Build the line graph of this graph.
 *
 * The computation always runs on a COO matrix. A CSR source is converted to
 * COO first, and a CSC source is transposed to CSR and then converted. The
 * result is created as a single-vertex-type COO unit graph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const auto fmt = SelectFormat(ALL_CODE);

  if (fmt == SparseFormat::kCOO) {
    return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));
  }
  if (fmt == SparseFormat::kCSR) {
    const aten::CSRMatrix out_adj = GetCSRMatrix(0);
    const aten::COOMatrix line =
        aten::COOLineGraph(aten::CSRToCOO(out_adj, true), backtracking);
    return CreateFromCOO(1, line);
  }
  if (fmt == SparseFormat::kCSC) {
    const aten::CSRMatrix in_adj = GetCSCMatrix(0);
    const aten::CSRMatrix out_adj = aten::CSRTranspose(in_adj);
    const aten::COOMatrix line =
        aten::COOLineGraph(aten::CSRToCOO(out_adj, true), backtracking);
    return CreateFromCOO(1, line);
  }

  LOG(FATAL) << "None of CSC, CSR, COO exist";
  return nullptr;
}

1655
1656
1657
1658
1659
1660
/**
 * @brief Magic number identifying serialized UnitGraph data. UnitGraph::Load
 *        reads a 64-bit value from the stream first and checks it against
 *        this constant.
 */
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/**
 * @brief Deserialize this UnitGraph from a stream.
 *
 * Stream layout, in read order:
 *   1. uint64 magic number (must equal kDGLSerialize_UnitGraphMagic);
 *   2. int64 code describing which sparse formats are stored in the stream;
 *   3. int64 code describing the formats the graph is restricted to;
 *   4. the COO, out-CSR and in-CSR payloads, each present only if its bit is
 *      set in the code read in step 2.
 *
 * For both codes, a value with any bit at or above bit 32 set is treated as a
 * format-code bitmask stored in the low 32 bits; otherwise the value is
 * interpreted using the legacy encoding handled below.
 *
 * Any malformed input aborts through CHECK / LOG(FATAL).
 *
 * @param fs Stream to read from.
 * @return Always true when the function returns.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code = 0;
  int64_t formats_code = 0;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";

  // Decode which formats were actually written to the stream.
  dgl_format_code_t save_formats = ANY_CODE;
  if (save_format_code >> 32) {
    save_formats =
        static_cast<dgl_format_code_t>(0xffffffff & save_format_code);
  } else {
    // Legacy encoding: the value is a single SparseFormat enum value.
    save_formats =
        SparseFormatsToCode({static_cast<SparseFormat>(save_format_code)});
  }

  // Decode the formats this graph is restricted to.
  if (formats_code >> 32) {
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
      case 0:
        formats_ = ALL_CODE;
        break;
      case 1:
        formats_ = COO_CODE;
        break;
      case 2:
        formats_ = CSR_CODE;
        break;
      case 3:
        formats_ = CSC_CODE;
        break;
      default:
        LOG(FATAL) << "Load graph failed, formats code " << formats_code
                   << " not recognized.";
    }
  }

  // Read the payloads. A failed read means the stream is truncated or
  // corrupt, so fail loudly instead of continuing with partial data.
  if (save_formats & COO_CODE) {
    CHECK(fs->Read(&coo_)) << "Failed to load the COO matrix of UnitGraph";
  }
  if (save_formats & CSR_CODE) {
    CHECK(fs->Read(&out_csr_)) << "Failed to load the CSR matrix of UnitGraph";
  }
  if (save_formats & CSC_CODE) {
    CHECK(fs->Read(&in_csr_)) << "Failed to load the CSC matrix of UnitGraph";
  }
  if (!coo_ && !out_csr_ && !in_csr_) {
    LOG(FATAL) << "unsupported format code";
  }

  // Give every format that was not loaded a default-constructed placeholder
  // so that none of the three members is left null.
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  // The metagraph is not stored separately; take it from the sparse format
  // returned by GetAny().
  meta_graph_ = GetAny()->meta_graph();

  return true;
}

/**
 * @brief Serialize this UnitGraph to a stream.
 *
 * Writes the magic number, the code of the formats being saved, the code of
 * the formats the graph is restricted to, and then one payload per saved
 * format (COO, out-CSR, in-CSR, in that order).
 *
 * @param fs Stream to write to. When it is a dgl::serialize::DGLStream whose
 *        FormatsToSave() is not ANY_CODE, those formats are saved; otherwise
 *        the single format chosen by SelectFormat(ALL_CODE) is saved.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  // Bit 32 is set on both format codes written below.
  constexpr int64_t kFormatCodeFlag = 0x100000000;

  fs->Write(kDGLSerialize_UnitGraphMagic);
  // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
  // sparse matrix

  // Default: save only the format picked by SelectFormat.
  auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
  if (auto* dgl_stream = dynamic_cast<dgl::serialize::DGLStream*>(fs)) {
    // A DGLStream may request an explicit set of formats.
    const auto requested = dgl_stream->FormatsToSave();
    if (requested != ANY_CODE) {
      save_formats = requested;
    }
  }

  fs->Write(static_cast<int64_t>(save_formats | kFormatCodeFlag));
  fs->Write(static_cast<int64_t>(formats_ | kFormatCodeFlag));

  if (save_formats & COO_CODE) {
    fs->Write(GetCOO());
  }
  if (save_formats & CSR_CODE) {
    fs->Write(GetOutCSR());
  }
  if (save_formats & CSC_CODE) {
    fs->Write(GetInCSR());
  }
}

1749
1750
1751
1752
/**
 * @brief Create a UnitGraph with the direction of every edge flipped.
 *
 * The existing CSR objects are reused with their roles swapped (this graph's
 * out-CSR is passed as the in-CSR of the result and vice versa). A COO is
 * produced only when this graph already has one defined, by transposing it.
 *
 * @return The newly constructed graph.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  // Swap the in/out roles of the two CSR structures.
  CSRPtr reversed_in = out_csr_;
  CSRPtr reversed_out = in_csr_;

  // Transpose the COO if present; otherwise leave it null.
  COOPtr reversed_coo =
      coo_->defined()
          ? COOPtr(new COO(
                coo_->meta_graph(), aten::COOTranspose(coo_->adj())))
          : COOPtr(nullptr);

  return UnitGraphPtr(
      new UnitGraph(meta_graph(), reversed_in, reversed_out, reversed_coo));
}

1761
std::tuple<UnitGraphPtr, IdArray, IdArray> UnitGraph::ToSimple() const {
1762
1763
1764
1765
1766
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1767
  auto avail_fmt = SelectFormat(ALL_CODE);
1768
1769
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1770
      auto ret = aten::COOToSimple(GetCOO()->adj());
1771
1772
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1773
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1774
1775
1776
      break;
    }
    case SparseFormat::kCSR: {
1777
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1778
1779
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1780
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1781
1782
1783
      break;
    }
    case SparseFormat::kCSC: {
1784
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1785
1786
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1787
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1788
1789
1790
1791
1792
1793
1794
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

1795
1796
1797
  return std::make_tuple(
      UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
      count, edge_map);
1798
1799
}

1800
}  // namespace dgl