/**
 *  Copyright (c) 2019 by Contributors
 * @file graph/unit_graph.cc
 * @brief UnitGraph graph implementation
 */
6
7
#include "./unit_graph.h"

8
#include <dgl/array.h>
9
#include <dgl/base_heterograph.h>
10
11
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
12
13

#include "../c_api_common.h"
14
#include "./serialize/dglstream.h"
15
16

namespace dgl {
17

18
namespace {
19
20
21

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
22
23
24
25
26
27
28
29
30
31
/**
 * @brief Build the metagraph used when there is a single node type.
 * @return A one-vertex graph holding exactly one self-loop edge 0->0.
 */
inline GraphPtr CreateUnitGraphMetaGraph1() {
  const std::vector<int64_t> src_ids{0};
  const std::vector<int64_t> dst_ids{0};
  IdArray src = aten::VecToIdArray(src_ids);
  IdArray dst = aten::VecToIdArray(dst_ids);
  GraphPtr metagraph = ImmutableGraph::CreateFromCOO(1, src, dst);
  return metagraph;
}
32

Minjie Wang's avatar
Minjie Wang committed
33
34
35
36
37
/**
 * @brief Build the metagraph used when there are two node types.
 * @return A two-vertex graph holding exactly one edge 0->1.
 */
inline GraphPtr CreateUnitGraphMetaGraph2() {
  const std::vector<int64_t> src_ids{0};
  const std::vector<int64_t> dst_ids{1};
  IdArray src = aten::VecToIdArray(src_ids);
  IdArray dst = aten::VecToIdArray(dst_ids);
  GraphPtr metagraph = ImmutableGraph::CreateFromCOO(2, src, dst);
  return metagraph;
}
Minjie Wang's avatar
Minjie Wang committed
43
44
45
46
47
48
49
50
51
52
53
54

/**
 * @brief Return the shared metagraph for the given number of vertex types.
 *
 * Both metagraphs are function-local statics, so each is built once and the
 * same pointer is handed out on every later call.
 *
 * @param num_vtypes Number of vertex types; only 1 and 2 are accepted.
 * @return The cached metagraph. Any other value triggers LOG(FATAL).
 */
inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr single_type_mg = CreateUnitGraphMetaGraph1();
  static GraphPtr two_type_mg = CreateUnitGraphMetaGraph2();
  switch (num_vtypes) {
    case 1:
      return single_type_mg;
    case 2:
      return two_type_mg;
    default:
      LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  }
  return {};
}
55
56

};  // namespace
57
58
59
60
61
62
63

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
64
/**
 * @brief COO-format backend of UnitGraph.
 *
 * Wraps a single aten::COOMatrix. The matrix never carries a data array:
 * edge ids are the positions of the entries in the row/col arrays. Mutation
 * and the format-selection queries are rejected with LOG(FATAL).
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /**
   * @brief Construct from source/destination id arrays.
   * @param metagraph Metagraph passed to BaseHeteroGraph.
   * @param num_src Number of rows of the adjacency matrix.
   * @param num_dst Number of columns of the adjacency matrix.
   * @param src Row indices, one per edge.
   * @param dst Column indices, one per edge; must be as long as src.
   * @param row_sorted Sortedness flag forwarded to the COOMatrix.
   * @param col_sorted Sortedness flag forwarded to the COOMatrix.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
      : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0])
        << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src,     num_dst,    src,       dst,
                           NullArray(), row_sorted, col_sorted};
  }

  /**
   * @brief Construct from an existing COO matrix that must not carry data.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
      : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /** @brief Construct an undefined COO (see defined()). */
  COO() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
    // supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /** @brief Whether both matrix dimensions are non-negative. */
  bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); }

  /** @brief Source vertex type; always 0. */
  inline dgl_type_t SrcType() const { return 0; }

  /** @brief Destination vertex type; 0 with one vertex type, else 1. */
  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }

  /** @brief Edge type; always 0. */
  inline dgl_type_t EdgeType() const { return 0; }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
               << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }

  /** @brief Data type of the row index array. */
  DGLDataType DataType() const override { return adj_.row->dtype; }

  /** @brief Device context of the row index array. */
  DGLContext Context() const override { return adj_.row->ctx; }

  bool IsPinned() const override { return adj_.is_pinned; }

  /** @brief Bit width of the row index array's dtype. */
  uint8_t NumBits() const override { return adj_.row->dtype.bits; }

  /**
   * @brief Return a graph whose index arrays use the given bit width.
   * @return A copy of this graph if the width already matches; otherwise a
   *   new COO built from the converted row/col arrays. The sortedness flags
   *   are not forwarded, so the result uses the constructor defaults.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) return *this;

    COO ret(
        meta_graph_, adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.col, bits));
    return ret;
  }

  /** @brief Return this graph on the given context, copying only if needed. */
  COO CopyTo(const DGLContext& ctx) const {
    if (Context() == ctx) return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }

  /**
   * @brief Copy the adj_ to pinned memory.
   * @return A COO graph backed by the pinned matrix, or a copy of this graph
   *   if adj_ is already pinned.
   */
  COO PinMemory() {
    if (adj_.is_pinned) return *this;
    return COO(meta_graph_, adj_.PinMemory());
  }

  /** @brief Pin the adj_: COOMatrix of the COO graph. */
  void PinMemory_() { adj_.PinMemory_(); }

  /** @brief Unpin the adj_: COOMatrix of the COO graph. */
  void UnpinMemory_() { adj_.UnpinMemory_(); }

  /** @brief Record stream for the adj_: COOMatrix of the COO graph. */
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

  bool IsMultigraph() const override { return aten::COOHasDuplicate(adj_); }

  bool IsReadonly() const override { return true; }

  /**
   * @brief Number of vertices of the given type.
   * @return num_rows for the source type, num_cols for the destination type;
   *   any other type triggers LOG(FATAL).
   */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /** @brief Number of edges, i.e. the length of the row array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  /** @brief Predecessors of dst, read from a row of the transposed matrix. */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::COOGetData(adj_, src, dst);
  }

  /** @brief Look up the (src, dst) endpoints of a single edge id. */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  /** @brief Look up the endpoints of a batch of edge ids. */
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    // Indexing row/col by edge id is only valid when ids are positions.
    BUG_IF_FAIL(aten::IsNullArray(adj_.data))
        << "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{
        aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
        eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) =
        aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid);
    IdArray ret_dst =
        aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Map the sliced row indices back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src =
        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Map the sliced row indices back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /**
   * @brief Return all edges; only the empty order and "eid" are accepted.
   */
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
        << "COO only support Edges of order \"eid\", but got \"" << order
        << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // NOTE(review): the four *Vec methods below log at INFO and return an empty
  // iterator pair, while the other unsupported methods of this class use
  // LOG(FATAL). Confirm whether the softer severity is intentional.
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  /**
   * @brief Return the adjacency as a single stacked array.
   * @param transpose If true, the column array is stacked before the row
   *   array; otherwise row comes first.
   * @param fmt Must be "coo".
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { return adj_; }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /**
   * @brief Extract the subgraph induced by the given vertices.
   * @param vids One id array per vertex type.
   * @return Subgraph whose induced_vertices are vids and whose induced_edges
   *   hold the data array of the sliced matrix.
   */
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes())
        << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    // The call is kept in case GetContextOf validates its inputs (not visible
    // here -- TODO confirm). The edge-id range that used to be built from its
    // result was never used: a COO graph stores no edge-id array.
    static_cast<void>(aten::GetContextOf(vids));
    subg.graph = std::make_shared<COO>(
        meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /**
   * @brief Extract the subgraph formed by the given edges.
   * @param eids Exactly one edge id array (there is a single edge type).
   * @param preserve_nodes If false, endpoints are relabeled and the induced
   *   vertices are the arrays returned by Relabel_. If true, the vertex counts
   *   of this graph are kept and the induced vertex arrays are null arrays.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    // Endpoints of the selected edges; needed by both modes.
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      subg.induced_vertices.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
      subg.induced_vertices.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src,
          new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  /** @brief Return the internal COO matrix by value. */
  aten::COOMatrix adj() const { return adj_; }

  /**
   * @brief Determines whether the graph is "hypersparse", i.e. having
   * significantly more nodes than edges.
   *
   * True when the source vertex count exceeds 1000000 and one eighth of it
   * (integer division) exceeds the edge count.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /** @brief Read the metagraph and then the adjacency matrix from a stream. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /** @brief Write the metagraph and then the adjacency matrix to a stream. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /** @brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

459
/** @brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
460
class UnitGraph::CSR : public BaseHeteroGraph {
461
 public:
462
463
464
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr,
      IdArray indices, IdArray edge_ids)
      : BaseHeteroGraph(metagraph) {
465
466
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
467
    if (aten::IsValidIdArray(edge_ids))
468
469
470
471
472
      CHECK(
          (indices->shape[0] == edge_ids->shape[0]) ||
          aten::IsNullArray(edge_ids))
          << "edge id arrays should have the same length as indices if not "
             "empty";
473
    CHECK_EQ(num_src, indptr->shape[0] - 1)
474
        << "number of nodes do not match the length of indptr minus 1.";
475

476
477
478
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

479
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
480
      : BaseHeteroGraph(metagraph), adj_(csr) {}
481

482
483
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
484
485
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
    // supported
486
487
488
489
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

490
  bool defined() const { return (adj_.num_rows >= 0) || (adj_.num_cols >= 0); }
491

492
  inline dgl_type_t SrcType() const { return 0; }
Minjie Wang's avatar
Minjie Wang committed
493

494
  inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
Minjie Wang's avatar
Minjie Wang committed
495

496
  inline dgl_type_t EdgeType() const { return 0; }
497
498

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
499
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
500
               << "The relation graph is simply this graph itself.";
501
502
503
504
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
505
    LOG(FATAL) << "UnitGraph graph is not mutable.";
506
507
508
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
509
    LOG(FATAL) << "UnitGraph graph is not mutable.";
510
511
512
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
513
    LOG(FATAL) << "UnitGraph graph is not mutable.";
514
515
  }

516
  void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
517

518
  DGLDataType DataType() const override { return adj_.indices->dtype; }
519

520
  DGLContext Context() const override { return adj_.indices->ctx; }
521

522
  bool IsPinned() const override { return adj_.is_pinned; }
523

524
  uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
525

526
527
528
529
530
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
531
          meta_graph_, adj_.num_rows, adj_.num_cols,
532
533
534
535
536
537
538
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

539
  CSR CopyTo(const DGLContext& ctx) const {
540
541
542
    if (Context() == ctx) {
      return *this;
    } else {
543
      return CSR(meta_graph_, adj_.CopyTo(ctx));
544
545
546
    }
  }

547
548
549
550
551
552
553
554
555
  /**
   * @brief Copy the adj_ to pinned memory.
   * @return CSRMatrix of the CSR graph.
   */
  CSR PinMemory() {
    if (adj_.is_pinned) return *this;
    return CSR(meta_graph_, adj_.PinMemory());
  }

556
  /** @brief Pin the adj_: CSRMatrix of the CSR graph. */
557
  void PinMemory_() { adj_.PinMemory_(); }
558

559
  /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */
560
  void UnpinMemory_() { adj_.UnpinMemory_(); }
561

562
  /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
563
564
565
566
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

567
  bool IsMultigraph() const override { return aten::CSRHasDuplicate(adj_); }
568

569
  bool IsReadonly() const override { return true; }
570
571

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
572
    if (vtype == SrcType()) {
573
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
574
    } else if (vtype == DstType()) {
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

595
596
  bool HasEdgeBetween(
      dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
597
598
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
599
600
601
    return aten::CSRIsNonZero(adj_, src, dst);
  }

602
603
  BoolArray HasEdgesBetween(
      dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
604
605
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
606
607
608
609
610
611
612
613
614
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
615
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
616
617
618
619
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
620
621
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
622
    return aten::CSRGetAllData(adj_, src, dst);
623
624
  }

625
626
  EdgeArray EdgeIdsAll(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
627
628
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
629
630
631
632
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

633
634
  IdArray EdgeIdsOne(
      dgl_type_t etype, IdArray src, IdArray dst) const override {
635
636
637
    return aten::CSRGetData(adj_, src, dst);
  }

638
639
  std::pair<dgl_id_t, dgl_id_t> FindEdge(
      dgl_type_t etype, dgl_id_t eid) const override {
640
    LOG(FATAL) << "Not enabled for CSR graph.";
641
642
643
644
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
645
    LOG(FATAL) << "Not enabled for CSR graph.";
646
647
648
649
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
650
    LOG(FATAL) << "Not enabled for CSR graph.";
651
652
653
654
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
655
    LOG(FATAL) << "Not enabled for CSR graph.";
656
657
658
659
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
660
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
661
662
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
663
664
    IdArray ret_src =
        aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
665
666
667
668
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
669
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
670
671
672
673
674
675
676
677
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

678
679
  EdgeArray Edges(
      dgl_type_t etype, const std::string& order = "") const override {
680
    CHECK(order.empty() || order == std::string("srcdst"))
681
682
        << "CSR only support Edges of order \"srcdst\","
        << " but got \"" << order << "\".";
683
684
685
686
687
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
688
689
690
691
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
692
    LOG(FATAL) << "Not enabled for CSR graph.";
693
694
695
696
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
697
    LOG(FATAL) << "Not enabled for CSR graph.";
698
699
700
701
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
702
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
703
704
705
706
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
707
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
708
709
710
711
712
713
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
714
    CHECK_EQ(NumBits(), 64);
715
716
717
718
719
720
721
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

722
723
724
725
726
727
728
729
730
731
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

732
733
734
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
735
    CHECK_EQ(NumBits(), 64);
736
737
738
739
740
741
742
743
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
744
    LOG(FATAL) << "Not enabled for CSR graph.";
745
746
747
748
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
749
    LOG(FATAL) << "Not enabled for CSR graph.";
750
751
752
753
    return {};
  }

  std::vector<IdArray> GetAdj(
754
      dgl_type_t etype, bool transpose, const std::string& fmt) const override {
755
756
757
758
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

759
760
761
762
763
764
765
766
767
768
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

769
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { return adj_; }
770

771
772
  SparseFormat SelectFormat(
      dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
773
    LOG(FATAL) << "Not enabled for CSR graph";
774
    return SparseFormat::kCSR;
775
776
  }

777
778
779
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
780
781
  }

782
  dgl_format_code_t GetCreatedFormats() const override {
783
784
785
786
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

787
788
789
790
  HeteroSubgraph VertexSubgraph(
      const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes())
        << "Number of vertex types mismatch";
Minjie Wang's avatar
Minjie Wang committed
791
792
793
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
794
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
795
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
796
    DGLContext ctx = aten::GetContextOf(vids);
797
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
798
799
800
    subg.graph = std::make_shared<CSR>(
        meta_graph(), submat.num_rows, submat.num_cols, submat.indptr,
        submat.indices, sub_eids);
801
802
803
804
805
806
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
807
808
      const std::vector<IdArray>& eids,
      bool preserve_nodes = false) const override {
809
    LOG(FATAL) << "Not enabled for CSR graph.";
810
811
812
    return {};
  }

813
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
814
815
816
817
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

818
  aten::CSRMatrix adj() const { return adj_; }
819

820
821
822
823
824
825
826
827
828
829
830
831
832
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

833
 private:
834
835
  friend class Serializer;

836
  /** @brief internal adjacency matrix. Data array stores edge ids */
837
838
839
840
841
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// unit graph implementation
//
//////////////////////////////////////////////////////////

846
/** @brief Forward the DataType() query to the backend returned by GetAny(). */
DGLDataType UnitGraph::DataType() const {
  const HeteroGraphPtr any_format = GetAny();
  return any_format->DataType();
}
847

848
/** @brief Forward the Context() query to the backend returned by GetAny(). */
DGLContext UnitGraph::Context() const {
  const HeteroGraphPtr any_format = GetAny();
  return any_format->Context();
}
849

850
bool UnitGraph::IsPinned() const { return GetAny()->IsPinned(); }
851

852
uint8_t UnitGraph::NumBits() const { return GetAny()->NumBits(); }
853

Minjie Wang's avatar
Minjie Wang committed
854
bool UnitGraph::IsMultigraph() const {
855
  const SparseFormat fmt = SelectFormat(CSC_CODE);
856
857
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
858
859
}

Minjie Wang's avatar
Minjie Wang committed
860
/**
 * @brief Number of vertices of the given type.
 *
 * Any format may answer. When the CSC backend is chosen the source and
 * destination vertex types are exchanged before the query is forwarded.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  if (chosen == SparseFormat::kCSC) {
    if (vtype == SrcType()) {
      vtype = DstType();
    } else {
      vtype = SrcType();
    }
  }
  return backend->NumVertices(vtype);
}

Minjie Wang's avatar
Minjie Wang committed
870
/** @brief Forward the NumEdges() query to the backend returned by GetAny(). */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const HeteroGraphPtr any_format = GetAny();
  return any_format->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
874
/**
 * @brief Whether vertex `vid` of type `vtype` exists.
 *
 * Any format may answer. When the CSC backend is chosen the source and
 * destination vertex types are exchanged before the query is forwarded.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen == SparseFormat::kCSC) {
    if (vtype == SrcType()) {
      vtype = DstType();
    } else {
      vtype = SrcType();
    }
  }
  return backend->HasVertex(vtype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
882
/**
 * @brief Element-wise existence test for a batch of vertex ids.
 *
 * Computed as `vids < NumVertices(vtype)` after validating the id array.
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const uint64_t vertex_count = NumVertices(vtype);
  return aten::LT(vids, vertex_count);
}

887
888
/**
 * @brief Whether an edge src->dst exists.
 *
 * CSC is the preferred format; on the CSC backend the endpoints are passed
 * in swapped order.
 */
bool UnitGraph::HasEdgeBetween(
    dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->HasEdgeBetween(etype, src, dst);
  }
  return backend->HasEdgeBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
897
/**
 * @brief Batched version of HasEdgeBetween.
 *
 * CSC is the preferred format; on the CSC backend the endpoint arrays are
 * passed in swapped order.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->HasEdgesBetween(etype, src, dst);
  }
  return backend->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
907
/**
 * @brief Predecessors of `dst`.
 *
 * CSC is the preferred format; on the CSC backend the answer is obtained
 * through that backend's Successors() query.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->Predecessors(etype, dst);
  }
  return backend->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
916
/** @brief Successors of `src`, answered by the backend chosen with CSR preferred. */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  return GetFormat(SelectFormat(CSR_CODE))->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
922
/**
 * @brief Ids of the edges between `src` and `dst`.
 *
 * CSR is the preferred format; if the CSC backend ends up being selected the
 * endpoints are passed in swapped order.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeId(etype, src, dst);
  }
  return backend->EdgeId(etype, dst, src);
}

931
932
/**
 * @brief All edges between the given endpoint pairs.
 *
 * CSR is the preferred format. If the CSC backend ends up being selected the
 * query is issued with swapped endpoints and the src/dst arrays of the result
 * are swapped back.
 */
EdgeArray UnitGraph::EdgeIdsAll(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeIdsAll(etype, src, dst);
  }
  const EdgeArray flipped = backend->EdgeIdsAll(etype, dst, src);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

943
944
/**
 * @brief Forward EdgeIdsOne to the backend chosen with CSR preferred.
 *
 * If the CSC backend ends up being selected the endpoint arrays are passed in
 * swapped order.
 */
IdArray UnitGraph::EdgeIdsOne(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeIdsOne(etype, src, dst);
  }
  return backend->EdgeIdsOne(etype, dst, src);
}

954
955
/** @brief Endpoints of edge `eid`, answered by the backend chosen with COO preferred. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(
    dgl_type_t etype, dgl_id_t eid) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
961
/** @brief Batched FindEdge, answered by the backend chosen with COO preferred. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
967
/**
 * @brief In-edges of a single vertex.
 *
 * CSC is the preferred format. On the CSC backend the result comes from that
 * backend's OutEdges() with the src/dst arrays swapped.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdges(etype, vid);
  }
  const EdgeArray& flipped = backend->OutEdges(etype, vid);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
978
/**
 * @brief In-edges of a batch of vertices.
 *
 * CSC is the preferred format. On the CSC backend the result comes from that
 * backend's OutEdges() with the src/dst arrays swapped.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdges(etype, vids);
  }
  const EdgeArray& flipped = backend->OutEdges(etype, vids);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
989
/** @brief Out-edges of a single vertex, answered by the backend chosen with CSR preferred. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
995
/** @brief Out-edges of a batch of vertices, answered by the backend chosen with CSR preferred. */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdges(etype, vids);
}

1001
/**
 * @brief Return all edges in the requested order.
 *
 * Supported orders and the format each one prefers:
 *   - ""       : arbitrary order, any format
 *   - "eid"    : COO
 *   - "srcdst" : CSR
 * Any other value is a fatal error. If the CSC backend ends up serving the
 * request, the src/dst arrays of its result are swapped before returning.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string& order) const {
  dgl_format_code_t preferred;
  if (order.empty()) {
    // arbitrary order
    preferred = ALL_CODE;
  } else if (order == "eid") {
    preferred = COO_CODE;
  } else if (order == "srcdst") {
    preferred = CSR_CODE;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }
  const SparseFormat fmt = SelectFormat(preferred);

  const EdgeArray found = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC) {
    return found;
  }
  return EdgeArray{found.dst, found.src, found.id};
}

Minjie Wang's avatar
Minjie Wang committed
1022
/**
 * @brief In-degree of a single vertex.
 *
 * CSC is the preferred format. Only CSC or COO may serve the query; any
 * other selection is fatal. On the CSC backend the answer is that backend's
 * OutDegree().
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (fmt == SparseFormat::kCSC) {
    return ptr->OutDegree(etype, vid);
  }
  return ptr->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1032
/**
 * @brief In-degrees of a batch of vertices.
 *
 * CSC is the preferred format. Only CSC or COO may serve the query; any
 * other selection is fatal. On the CSC backend the answer is that backend's
 * OutDegrees().
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  SparseFormat fmt = SelectFormat(CSC_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (fmt == SparseFormat::kCSC) {
    return ptr->OutDegrees(etype, vids);
  }
  return ptr->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1042
/**
 * @brief Out-degree of a single vertex.
 *
 * CSR is the preferred format. Only CSR or COO may serve the query; any
 * other selection is fatal.
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  const uint64_t degree = ptr->OutDegree(etype, vid);
  return degree;
}

Minjie Wang's avatar
Minjie Wang committed
1051
/**
 * @brief Out-degrees of a batch of vertices.
 *
 * CSR is the preferred format. Only CSR or COO may serve the query; any
 * other selection is fatal.
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  DegreeArray degrees = ptr->OutDegrees(etype, vids);
  return degrees;
}

Minjie Wang's avatar
Minjie Wang committed
1060
/** @brief Forward SuccVec to the backend chosen with CSR preferred. */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->SuccVec(etype, vid);
}

1066
/**
 * @brief Forward SuccVec32 to the backend chosen with CSR preferred.
 *
 * The selected backend must be a CSR object; a failed downcast is fatal.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSR_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  const auto ptr = std::dynamic_pointer_cast<CSR>(backend);
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1073
/** @brief Forward OutEdgeVec to the backend chosen with CSR preferred. */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  return GetFormat(SelectFormat(CSR_CODE))->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1079
/**
 * @brief Forward PredVec to the backend chosen with CSC preferred.
 *
 * On the CSC backend the answer is that backend's SuccVec().
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->PredVec(etype, vid);
  }
  return backend->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1088
/**
 * @brief Forward InEdgeVec to the backend chosen with CSC preferred.
 *
 * On the CSC backend the answer is that backend's OutEdgeVec().
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(CSC_CODE);
  const HeteroGraphPtr backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdgeVec(etype, vid);
  }
  return backend->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1097
std::vector<IdArray> UnitGraph::GetAdj(
1098
1099
1100
1101
1102
1103
1104
1105
1106
    dgl_type_t etype, bool transpose, const std::string& fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst
  // nodes and col for src nodes. Therefore, we need to flip the transpose flag.
  // For example,
  //   transpose=False is equal to in edge CSR. We have this behavior because
  //   previously we use framework's SPMM and we don't cache reverse adj. This
  //   is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
  //   change the behavior and make row for src and col for dst.
1107
  if (fmt == std::string("csr")) {
1108
    return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
1109
                      : GetInCSR()->GetAdj(etype, false, "csr");
1110
  } else if (fmt == std::string("coo")) {
1111
    return GetCOO()->GetAdj(etype, transpose, fmt);
1112
1113
1114
1115
1116
1117
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

1118
1119
/**
 * @brief Vertex-induced subgraph.
 *
 * The subgraph is built by the backend chosen with CSR preferred and then
 * wrapped in a new UnitGraph that holds only that one format. The induced
 * vertex/edge arrays are moved over from the backend's result.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(
    const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
  HeteroSubgraph ret;

  // Exactly one of these is filled in, depending on which backend served
  // the request.
  CSRPtr subcsr;
  CSRPtr subcsc;
  COOPtr subcoo;
  if (fmt == SparseFormat::kCSR) {
    subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCSC) {
    subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCOO) {
    subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return ret;
  }

  ret.graph =
      HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
1150
/**
 * @brief Edge-induced subgraph.
 *
 * The subgraph is built by the backend chosen with COO preferred and then
 * wrapped in a new UnitGraph that holds only that one format. The induced
 * vertex/edge arrays are moved over from the backend's result.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph ret;

  // Exactly one of these is filled in, depending on which backend served
  // the request.
  CSRPtr subcsr;
  CSRPtr subcsc;
  COOPtr subcoo;
  if (fmt == SparseFormat::kCSR) {
    subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCSC) {
    subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCOO) {
    subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return ret;
  }

  ret.graph =
      HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
1181
/**
 * @brief Build a UnitGraph that initially holds only a COO structure made
 *        from row/col arrays.
 *
 * @param num_vtypes Must be 1 or 2; with 1, num_src must equal num_dst.
 * @param formats Allowed-format code passed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
    IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr coo_backend = COOPtr(
      new COO(metagraph, num_src, num_dst, row, col, row_sorted, col_sorted));

  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, nullptr, coo_backend, formats));
}

1192
/**
 * @brief Build a UnitGraph that initially holds only a COO structure made
 *        from an existing COO matrix.
 *
 * @param num_vtypes Must be 1 or 2; with 1, the matrix must be square.
 * @param formats Allowed-format code passed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const COOPtr coo_backend = COOPtr(new COO(metagraph, mat));

  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, nullptr, coo_backend, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1202
/**
 * @brief Build a UnitGraph that initially holds only an out-CSR structure
 *        made from indptr/indices/edge-id arrays.
 *
 * @param num_vtypes Must be 1 or 2; with 1, num_src must equal num_dst.
 * @param formats Allowed-format code passed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_backend = CSRPtr(
      new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, out_backend, nullptr, formats));
}

1212
/**
 * @brief Build a UnitGraph that initially holds only an out-CSR structure
 *        made from an existing CSR matrix.
 *
 * @param num_vtypes Must be 1 or 2; with 1, the matrix must be square.
 * @param formats Allowed-format code passed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr out_backend = CSRPtr(new CSR(metagraph, mat));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, nullptr, out_backend, nullptr, formats));
}

1221
/**
 * @brief Build a UnitGraph that initially holds only an in-CSR (CSC)
 *        structure made from indptr/indices/edge-id arrays.
 *
 * The backing CSR object is constructed with num_dst and num_src in that
 * order, i.e. swapped relative to CreateFromCSR.
 *
 * @param num_vtypes Must be 1 or 2; with 1, num_src must equal num_dst.
 * @param formats Allowed-format code passed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_backend = CSRPtr(
      new CSR(metagraph, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, in_backend, nullptr, nullptr, formats));
}

/**
 * @brief Build a UnitGraph that initially holds only an in-CSR (CSC)
 *        structure made from an existing matrix.
 *
 * @param num_vtypes Must be 1 or 2; with 1, the matrix must be square.
 * @param formats Allowed-format code passed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const GraphPtr metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  const CSRPtr in_backend = CSRPtr(new CSR(metagraph, mat));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, in_backend, nullptr, nullptr, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1240
/**
 * @brief Return a graph whose id width is `bits`.
 *
 * If `g` already reports that width it is returned unchanged. Otherwise `g`
 * must be a UnitGraph; every defined format is converted with AsNumBits and
 * a new UnitGraph with the same allowed formats is returned.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  // Formats that are not defined on the input stay null here.
  CSRPtr new_incsr;
  if (bg->in_csr_->defined()) {
    new_incsr = CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)));
  }
  CSRPtr new_outcsr;
  if (bg->out_csr_->defined()) {
    new_outcsr = CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)));
  }
  COOPtr new_coo;
  if (bg->coo_->defined()) {
    new_coo = COOPtr(new COO(bg->coo_->AsNumBits(bits)));
  }
  return HeteroGraphPtr(new UnitGraph(
      g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}

1260
/**
 * @brief Return a graph that lives on context `ctx`.
 *
 * If `g` already reports that context it is returned unchanged. Otherwise
 * `g` must be a UnitGraph; every defined format is copied with CopyTo and a
 * new UnitGraph with the same allowed formats is returned.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }

  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  // Formats that are not defined on the input stay null here.
  CSRPtr new_incsr;
  if (bg->in_csr_->defined()) {
    new_incsr = CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)));
  }
  CSRPtr new_outcsr;
  if (bg->out_csr_->defined()) {
    new_outcsr = CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)));
  }
  COOPtr new_coo;
  if (bg->coo_->defined()) {
    new_coo = COOPtr(new COO(bg->coo_->CopyTo(ctx)));
  }
  return HeteroGraphPtr(new UnitGraph(
      g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}

1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
/**
 * @brief Return a new UnitGraph built from pinned versions of the formats.
 *
 * For each defined format: if it already reports IsPinned() the existing
 * object is shared, otherwise a new object is built from its PinMemory()
 * result. Undefined formats are passed on as null.
 */
HeteroGraphPtr UnitGraph::PinMemory() {
  CSRPtr pinned_in_csr;
  if (in_csr_->defined()) {
    pinned_in_csr =
        in_csr_->IsPinned() ? in_csr_ : CSRPtr(new CSR(in_csr_->PinMemory()));
  }

  CSRPtr pinned_out_csr;
  if (out_csr_->defined()) {
    pinned_out_csr = out_csr_->IsPinned()
                         ? out_csr_
                         : CSRPtr(new CSR(out_csr_->PinMemory()));
  }

  COOPtr pinned_coo;
  if (coo_->defined()) {
    pinned_coo = coo_->IsPinned() ? coo_ : COOPtr(new COO(coo_->PinMemory()));
  }

  return HeteroGraphPtr(new UnitGraph(
      meta_graph(), pinned_in_csr, pinned_out_csr, pinned_coo, formats_));
}

1311
void UnitGraph::PinMemory_() {
1312
1313
1314
  if (this->in_csr_->defined()) this->in_csr_->PinMemory_();
  if (this->out_csr_->defined()) this->out_csr_->PinMemory_();
  if (this->coo_->defined()) this->coo_->PinMemory_();
1315
1316
1317
}

void UnitGraph::UnpinMemory_() {
1318
1319
1320
  if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_();
  if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_();
  if (this->coo_->defined()) this->coo_->UnpinMemory_();
1321
1322
}

1323
/**
 * @brief Record `stream` on every defined format and remember it.
 *
 * The stream is appended to recorded_streams so that formats created later
 * in place can have it recorded as well.
 */
void UnitGraph::RecordStream(DGLStreamHandle stream) {
  if (in_csr_->defined()) {
    in_csr_->RecordStream(stream);
  }
  if (out_csr_->defined()) {
    out_csr_->RecordStream(stream);
  }
  if (coo_->defined()) {
    coo_->RecordStream(stream);
  }
  recorded_streams.push_back(stream);
}

1330
void UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); }
1331

1332
void UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); }
1333

1334
void UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); }
1335

1336
1337
1338
1339
1340
1341
1342
/**
 * @brief Construct a UnitGraph from up to three sparse structures.
 *
 * Null format pointers are replaced with default-constructed placeholders so
 * the members are never null. It is a fatal error if a created format is not
 * contained in the allowed `formats`.
 *
 * @param metagraph Meta graph passed to the base class.
 * @param in_csr In-CSR (CSC) structure, may be null.
 * @param out_csr Out-CSR structure, may be null.
 * @param coo COO structure, may be null.
 * @param formats Code of the formats this graph is allowed to hold.
 */
UnitGraph::UnitGraph(
    GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
    dgl_format_code_t formats)
    : BaseHeteroGraph(metagraph),
      in_csr_(in_csr),
      out_csr_(out_csr),
      coo_(coo) {
  if (in_csr_ == nullptr) {
    in_csr_.reset(new CSR());
  }
  if (out_csr_ == nullptr) {
    out_csr_.reset(new CSR());
  }
  if (coo_ == nullptr) {
    coo_.reset(new COO());
  }

  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  // Every created format must be one of the allowed formats.
  const bool created_is_allowed = ((formats | created) == formats);
  if (!created_is_allowed) {
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created)
               << ", which is not compatible with available formats: "
               << CodeToStr(formats);
  }
  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1361
/**
 * @brief Build a UnitGraph from raw matrices plus flags saying which of
 *        them are present.
 *
 * For each format whose flag is false a default-constructed placeholder is
 * used instead of wrapping the corresponding matrix.
 */
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
    int num_vtypes, const aten::CSRMatrix& in_csr,
    const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr,
    bool has_out_csr, bool has_coo, dgl_format_code_t formats) {
  const GraphPtr mg = CreateUnitGraphMetaGraph(num_vtypes);

  const CSRPtr in_csr_ptr =
      has_in_csr ? CSRPtr(new CSR(mg, in_csr)) : CSRPtr(new CSR());
  const CSRPtr out_csr_ptr =
      has_out_csr ? CSRPtr(new CSR(mg, out_csr)) : CSRPtr(new CSR());
  const COOPtr coo_ptr = has_coo ? COOPtr(new COO(mg, coo)) : COOPtr(new COO());

  return HeteroGraphPtr(
      new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}

1388
1389
/**
 * @brief Return the in-CSR (CSC). If it does not exist, build it from the
 *        COO (preferred) or by transposing the out-CSR.
 *
 * @param inplace If true, the CSC format must be allowed and the newly built
 *        structure is stored into this graph (and pinned / stream-recorded
 *        to match). If false, a detached object is returned and this graph
 *        is left untouched.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace && !(formats_ & CSC_CODE)) {
    LOG(FATAL) << "The graph have restricted sparse format "
               << CodeToStr(formats_) << ", cannot create CSC matrix.";
  }
  if (in_csr_->defined()) {
    return in_csr_;
  }

  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
  const aten::CSRMatrix newadj = [this]() -> aten::CSRMatrix {
    if (coo_->defined()) {
      return aten::COOToCSR(aten::COOTranspose(coo_->adj()));
    }
    CHECK(out_csr_->defined()) << "None of CSR, COO exist";
    return aten::CSRTranspose(out_csr_->adj());
  }();

  if (!inplace) {
    return std::make_shared<CSR>(meta_graph(), newadj);
  }

  // Overwrite the placeholder object so existing holders of the pointer see
  // the new structure too.
  *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
  if (IsPinned()) {
    in_csr_->PinMemory_();
  }
  for (auto stream : recorded_streams) {
    in_csr_->RecordStream(stream);
  }
  return in_csr_;
}

1421
/** @brief Return out csr. If not exist, transpose the other one.*/
1422
1423
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1424
    if (!(formats_ & CSR_CODE))
1425
1426
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create CSR matrix.";
1427
  CSRPtr ret = out_csr_;
1428
1429
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1430
  if (!out_csr_->defined()) {
1431
1432
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1433

1434
      if (inplace)
1435
1436
1437
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1438
    } else {
1439
1440
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1441

1442
      if (inplace)
1443
1444
1445
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1446
    }
1447
    if (inplace) {
1448
1449
      if (IsPinned()) out_csr_->PinMemory_();
      for (auto stream : recorded_streams) out_csr_->RecordStream(stream);
1450
    }
1451
  }
1452
  return ret;
1453
1454
}

1455
/** @brief Return coo. If not exist, create from csr.*/
1456
1457
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1458
    if (!(formats_ & COO_CODE))
1459
1460
      LOG(FATAL) << "The graph have restricted sparse format "
                 << CodeToStr(formats_) << ", cannot create COO matrix.";
1461
  COOPtr ret = coo_;
1462
1463
  if (!coo_->defined()) {
    if (in_csr_->defined()) {
1464
1465
      const auto& newadj =
          aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1466

1467
      if (inplace)
1468
1469
1470
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1471
    } else {
1472
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1473
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1474

1475
      if (inplace)
1476
1477
1478
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1479
    }
1480
    if (inplace) {
1481
1482
      if (IsPinned()) coo_->PinMemory_();
      for (auto stream : recorded_streams) coo_->RecordStream(stream);
1483
    }
1484
  }
1485
  return ret;
1486
1487
}

1488
/** @brief Adjacency matrix of the in-CSR (CSC); `etype` is not used. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr in_backend = GetInCSR();
  return in_backend->adj();
}

1492
/** @brief Adjacency matrix of the out-CSR; `etype` is not used. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr out_backend = GetOutCSR();
  return out_backend->adj();
}

1496
/** @brief Adjacency matrix of the COO; `etype` is not used. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo_backend = GetCOO();
  return coo_backend->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1500
/**
 * @brief Return one of the stored formats.
 *
 * Priority: in-CSR if defined, else out-CSR if defined, else the COO object
 * (returned without checking whether it is defined).
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined()) {
    return in_csr_;
  }
  if (out_csr_->defined()) {
    return out_csr_;
  }
  return coo_;
}

1510
/** @brief Bitwise OR of the codes of all formats that are currently defined. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t created = 0;
  if (in_csr_->defined()) {
    created |= CSC_CODE;
  }
  if (out_csr_->defined()) {
    created |= CSR_CODE;
  }
  if (coo_->defined()) {
    created |= COO_CODE;
  }
  return created;
}

1518
/** @brief Code of the formats this graph is allowed to hold. */
dgl_format_code_t UnitGraph::GetAllowedFormats() const {
  return formats_;
}
1519

1520
1521
/**
 * @brief Return the backend for the given sparse format.
 *
 * kCSR maps to the out-CSR and kCSC to the in-CSR; every other value maps to
 * the COO.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR) {
    return GetOutCSR();
  }
  if (format == SparseFormat::kCSC) {
    return GetInCSR();
  }
  return GetCOO();
}

1531
/**
 * @brief Return a UnitGraph restricted to the given allowed formats.
 *
 * If some requested formats are already created, exactly those are retained.
 * Otherwise one format is created, trying COO, then CSR, then CSC. This
 * graph itself is not modified (all getters are called with inplace=false).
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  // Requested formats that already exist on this graph.
  const auto reusable = formats & GetCreatedFormats();

  if (reusable != 0) {
    // Retain every format in the intersection; the rest stay null.
    COOPtr coo_ptr;
    if (COO_CODE & reusable) {
      coo_ptr = GetCOO(false);
    }
    CSRPtr in_csr_ptr;
    if (CSC_CODE & reusable) {
      in_csr_ptr = GetInCSR(false);
    }
    CSRPtr out_csr_ptr;
    if (CSR_CODE & reusable) {
      out_csr_ptr = GetOutCSR(false);
    }
    return HeteroGraphPtr(
        new UnitGraph(meta_graph_, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
  }

  // Nothing reusable: create a format in the order of COO -> CSR -> CSC.
  const int64_t num_vtypes = NumVertexTypes();
  if (COO_CODE & formats) {
    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
  }
  if (CSR_CODE & formats) {
    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
  }
  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}

1558
1559
/**
 * @brief Pick the sparse format that should serve a query.
 *
 * Selection order:
 *   1. a preferred format that is allowed and already created;
 *   2. a preferred format that is allowed (it may have to be created);
 *   3. whatever format has been created.
 */
SparseFormat UnitGraph::SelectFormat(
    dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t common = preferred_formats & formats_;
  const dgl_format_code_t created = GetCreatedFormats();
  const dgl_format_code_t common_and_created = common & created;
  if (common_and_created) {
    return DecodeFormat(common_and_created);
  }

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on
  // COO have not been implemented yet.
  // if (coo_->defined() && coo_->IsHypersparse())
  //   // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  if (common) {
    return DecodeFormat(common);
  }
  return DecodeFormat(created);
}

1572
1573
1574
1575
/**
 * @brief Convert this graph to an ImmutableGraph.
 *
 * Only graphs with a single vertex type are accepted. Each defined format is
 * carried over; formats that are not defined are passed as null. When the
 * COO matrix has a data array, row and col are passed through Scatter with
 * that array before the COO is built.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr in_csr_ptr;
  dgl::CSRPtr out_csr_ptr;
  dgl::COOPtr coo_ptr;
  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    in_csr_ptr.reset(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    out_csr_ptr.reset(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    if (COOHasData(coo)) {
      IdArray new_src = Scatter(coo.row, coo.data);
      IdArray new_dst = Scatter(coo.col, coo.data);
      coo_ptr.reset(new dgl::COO(NumVertices(0), new_src, new_dst));
    } else {
      coo_ptr.reset(new dgl::COO(NumVertices(0), coo.row, coo.col));
    }
  }
  return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
}

1597
1598
/**
 * @brief Build the line graph of this graph as a new COO-based UnitGraph
 *        with one vertex type.
 *
 * Whatever format is selected is first brought to COO form (the CSC matrix
 * is transposed to CSR on the way) and then passed to aten::COOLineGraph.
 *
 * @param backtracking Forwarded to aten::COOLineGraph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const auto fmt = SelectFormat(ALL_CODE);

  if (fmt == SparseFormat::kCOO) {
    return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));
  }
  if (fmt == SparseFormat::kCSR || fmt == SparseFormat::kCSC) {
    const aten::CSRMatrix csr = (fmt == SparseFormat::kCSR)
                                    ? GetCSRMatrix(0)
                                    : aten::CSRTranspose(GetCSCMatrix(0));
    const aten::COOMatrix coo =
        aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
    return CreateFromCOO(1, coo);
  }

  LOG(FATAL) << "None of CSC, CSR, COO exist";
  return nullptr;
}

1624
1625
1626
1627
1628
1629
/**
 * @brief Magic number expected at the start of a serialized UnitGraph.
 *
 * UnitGraph::Load reads a 64-bit value first and rejects the stream if it
 * differs from this constant.
 */
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/**
 * @brief Deserialize this UnitGraph from a stream.
 *
 * Reads, in order: the magic number, the code describing which sparse formats
 * are stored in the stream, the code assigned to formats_, and then the
 * stored COO / CSR / CSC structures. Any structure that was not stored is
 * replaced by a default-constructed one so the pointers are never null
 * afterwards.
 *
 * @param fs Stream to read from.
 * @return Always true; every failure is reported through CHECK / LOG(FATAL).
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code = 0, formats_code = 0;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";

  // Save() ORs 0x100000000 into both codes, so a non-zero upper half means the
  // lower 32 bits already hold a format bit mask. Otherwise the value is
  // interpreted as a single SparseFormat enum value.
  dgl_format_code_t save_formats = ANY_CODE;
  if (save_format_code >> 32) {
    save_formats =
        static_cast<dgl_format_code_t>(0xffffffff & save_format_code);
  } else {
    save_formats =
        SparseFormatsToCode({static_cast<SparseFormat>(save_format_code)});
  }

  if (formats_code >> 32) {
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
      case 0:
        formats_ = ALL_CODE;
        break;
      case 1:
        formats_ = COO_CODE;
        break;
      case 2:
        formats_ = CSR_CODE;
        break;
      case 3:
        formats_ = CSC_CODE;
        break;
      default:
        LOG(FATAL) << "Load graph failed, formats code " << formats_code
                   << " not recognized.";
    }
  }

  // The payload order must mirror Save(): COO, then CSR, then CSC. A failed
  // read is fatal instead of being silently ignored, so a truncated or
  // corrupt stream cannot produce a partially loaded graph.
  if (save_formats & COO_CODE) {
    CHECK(fs->Read(&coo_)) << "Invalid COO data";
  }
  if (save_formats & CSR_CODE) {
    CHECK(fs->Read(&out_csr_)) << "Invalid CSR data";
  }
  if (save_formats & CSC_CODE) {
    CHECK(fs->Read(&in_csr_)) << "Invalid CSC data";
  }
  if (!coo_ && !out_csr_ && !in_csr_) {
    LOG(FATAL) << "unsupported format code";
  }

  // Fill in default-constructed objects for the formats that were not stored.
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  // The metagraph is not serialized on its own; take it from whichever
  // structure GetAny() returns.
  meta_graph_ = GetAny()->meta_graph();

  return true;
}

/**
 * @brief Serialize this UnitGraph to a stream.
 *
 * Layout written: magic number, code of the formats being stored, the value
 * of formats_, then the COO / CSR / CSC structures selected by the first
 * code (in that order). Both codes are OR-ed with 0x100000000 before being
 * written as int64_t.
 *
 * @param fs Stream to write to. When it is a dgl::serialize::DGLStream, its
 *           FormatsToSave() decides which formats are stored unless it
 *           returns ANY_CODE.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);

  // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
  // sparse matrix
  auto stored_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
  if (auto dgl_stream = dynamic_cast<dgl::serialize::DGLStream*>(fs)) {
    auto requested = dgl_stream->FormatsToSave();
    if (requested == ANY_CODE) {
      stored_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
    } else {
      stored_formats = requested;
    }
  }

  const int64_t stored_word = static_cast<int64_t>(stored_formats | 0x100000000);
  const int64_t allowed_word = static_cast<int64_t>(formats_ | 0x100000000);
  fs->Write(stored_word);
  fs->Write(allowed_word);

  if (stored_formats & COO_CODE) {
    fs->Write(GetCOO());
  }
  if (stored_formats & CSR_CODE) {
    fs->Write(GetOutCSR());
  }
  if (stored_formats & CSC_CODE) {
    fs->Write(GetInCSR());
  }
}

1718
1719
1720
1721
/**
 * @brief Create a new UnitGraph from this one with the in/out CSR pointers
 *        exchanged and the COO adjacency transposed.
 *
 * The CSR objects are passed on as-is (no copy of the underlying structure is
 * made here); the COO is rebuilt from aten::COOTranspose only when this graph
 * has a defined COO.
 *
 * @return The newly created graph.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  // The current out-CSR is passed in the in-CSR position and vice versa.
  CSRPtr flipped_in = out_csr_;
  CSRPtr flipped_out = in_csr_;

  COOPtr flipped_coo = nullptr;
  if (coo_->defined()) {
    flipped_coo =
        COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
  }

  UnitGraph* reversed =
      new UnitGraph(meta_graph(), flipped_in, flipped_out, flipped_coo);
  return UnitGraphPtr(reversed);
}

1730
std::tuple<UnitGraphPtr, IdArray, IdArray> UnitGraph::ToSimple() const {
1731
1732
1733
1734
1735
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1736
  auto avail_fmt = SelectFormat(ALL_CODE);
1737
1738
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1739
      auto ret = aten::COOToSimple(GetCOO()->adj());
1740
1741
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1742
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1743
1744
1745
      break;
    }
    case SparseFormat::kCSR: {
1746
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1747
1748
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1749
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1750
1751
1752
      break;
    }
    case SparseFormat::kCSC: {
1753
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1754
1755
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1756
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1757
1758
1759
1760
1761
1762
1763
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

1764
1765
1766
  return std::make_tuple(
      UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
      count, edge_map);
1767
1768
}

1769
}  // namespace dgl