/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>

#include "../c_api_common.h"
#include "./unit_graph.h"

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
  // a self-loop edge 0->0
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 0);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
  return g;
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
36
37
38
39
40
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  if (num_vtypes == 1)
    return mg1;
  else if (num_vtypes == 2)
    return mg2;
  else
    LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

/*!
 * \brief COO-backed single-relation graph, one of the storage formats of UnitGraph.
 *
 * Edges are held as parallel row/col arrays in \c adj_. The id of an edge is its
 * position in those arrays, which is why the internal matrix never carries a
 * data array. Every mutating operation is rejected: the graph is read-only.
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from explicit endpoint arrays.
   * \param metagraph Metagraph of the unit graph (one or two vertex types).
   * \param num_src Number of source vertices (matrix rows).
   * \param num_dst Number of destination vertices (matrix columns).
   * \param src Source vertex of every edge.
   * \param dst Destination vertex of every edge; same length as \p src.
   * \param row_sorted Sortedness flag stored on the underlying COO matrix.
   * \param col_sorted Sortedness flag stored on the underlying COO matrix.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
      IdArray dst, bool row_sorted = false, bool col_sorted = false)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
        NullArray(),
        row_sorted, col_sorted};
  }

  /*!
   * \brief Construct from an existing COO matrix.
   *
   * Data index should not be inherited. Edges in COO format are always
   * assigned ids from 0 to num_edges - 1, so \p coo must not carry data.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /*!
   * \brief Construct an undefined placeholder format.
   *
   * num_rows/num_cols are set to the magic value -1 to mark it as undefined;
   * num_rows == 0 and num_cols == 0 means an empty UnitGraph, which is supported.
   */
  COO() {
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  }

  /*! \brief Whether this format holds a real matrix (both dimensions non-negative). */
  bool defined() const {
    return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
  }

  /*! \brief Source vertex type; always 0. */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Destination vertex type: 0 with one vertex type, otherwise 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief The single edge type; always 0. */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  DLContext Context() const override {
    return adj_.row->ctx;
  }

  bool IsPinned() const override {
    return adj_.is_pinned;
  }

  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a copy whose id arrays use \p bits-bit integers.
   *
   * Returns *this when the width already matches. Changing the integer width
   * does not reorder the arrays, so the sortedness flags of the current matrix
   * are carried over instead of being reset to false.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits),
        aten::AsNumBits(adj_.col, bits),
        adj_.row_sorted, adj_.col_sorted);
    return ret;
  }

  /*! \brief Return a copy on device \p ctx, or *this if it is already there. */
  COO CopyTo(const DLContext &ctx) const {
    if (Context() == ctx)
      return *this;
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }

  /*! \brief Pin the adj_: COOMatrix of the COO graph. */
  void PinMemory_() {
    adj_.PinMemory_();
  }

  /*! \brief Unpin the adj_: COOMatrix of the COO graph. */
  void UnpinMemory_() {
    adj_.UnpinMemory_();
  }

  /*! \brief Record stream for the adj_: COOMatrix of the COO graph. */
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

  bool IsMultigraph() const override {
    return aten::COOHasDuplicate(adj_);
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetAllData(adj_, src, dst);
  }

  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  /*! \brief One edge id per (src, dst) pair; validated like EdgeIdsAll. */
  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    return aten::COOGetData(adj_, src, dst);
  }

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    // Edge ids are positions in the row/col arrays.
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    BUG_IF_FAIL(aten::IsNullArray(adj_.data)) <<
      "FindEdges requires the internal COO matrix not having EIDs.";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    // Same range check as InDegree, so a bad id fails with a clear message.
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Rows of the sliced matrix are positions in vids; map them back to vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    // Same range check as OutDegree, so a bad id fails with a clear message.
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Rows of the sliced matrix are positions in vids; map them back to vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for COO graph.";
    return {};
  }

  /*! \brief Return the row/col arrays stacked into one array (col first if \p transpose). */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kCOO;
  }

  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  dgl_format_code_t GetCreatedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
  }

  /*!
   * \brief Subgraph induced by the given source/destination vertex sets.
   *
   * The new COO graph numbers its edges by position; submat.data is reported
   * as induced_edges to map them back to this graph's edge ids.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    // The returned context is not needed. The call is kept for any consistency
    // checks it performs on the input arrays (NOTE(review): confirm it is still
    // wanted). The unused aten::Range edge-id array that used to be built here
    // has been removed.
    (void)aten::GetContextOf(vids);
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Subgraph induced by the edges in \p eids[0].
   * \param eids Exactly one edge id array.
   * \param preserve_nodes If false, vertices are compacted to those touched by
   *        the selected edges. If true, all vertices are kept and the induced
   *        vertex arrays are null.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    // Endpoints of the selected edges, in the order given by eids[0].
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      // Relabel_ rewrites new_src/new_dst to compact ids (they are then used as
      // the new graph's endpoints) and returns the original ids of the kept vertices.
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      subg.induced_vertices.emplace_back(
          aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
      subg.induced_vertices.emplace_back(
          aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return nullptr;
  }

  /*! \brief The internal adjacency matrix (returned by value). */
  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /*! \brief Deserialize the metagraph and adjacency matrix from \p fs. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Serialize the metagraph and adjacency matrix to \p fs. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
470
class UnitGraph::CSR : public BaseHeteroGraph {
471
 public:
Minjie Wang's avatar
Minjie Wang committed
472
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
473
      IdArray indptr, IdArray indices, IdArray edge_ids)
Minjie Wang's avatar
Minjie Wang committed
474
    : BaseHeteroGraph(metagraph) {
475
476
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
477
478
479
    if (aten::IsValidIdArray(edge_ids))
      CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids))
        << "edge id arrays should have the same length as indices if not empty";
480
481
    CHECK_EQ(num_src, indptr->shape[0] - 1)
      << "number of nodes do not match the length of indptr minus 1.";
482

483
484
485
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

486
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
Da Zheng's avatar
Da Zheng committed
487
488
    : BaseHeteroGraph(metagraph), adj_(csr) {
  }
489

490
491
492
493
494
495
496
497
498
499
500
  CSR() {
    // set magic num_rows/num_cols to mark it as undefined
    // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
    adj_.num_rows = -1;
    adj_.num_cols = -1;
  };

  bool defined() const {
    return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
  }

Minjie Wang's avatar
Minjie Wang committed
501
502
  inline dgl_type_t SrcType() const {
    return 0;
503
  }
Minjie Wang's avatar
Minjie Wang committed
504
505
506
507
508
509
510

  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
511
512
513
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
514
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
515
516
517
518
519
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
520
    LOG(FATAL) << "UnitGraph graph is not mutable.";
521
522
523
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
524
    LOG(FATAL) << "UnitGraph graph is not mutable.";
525
526
527
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
528
    LOG(FATAL) << "UnitGraph graph is not mutable.";
529
530
531
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
532
    LOG(FATAL) << "UnitGraph graph is not mutable.";
533
534
  }

535
536
537
538
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

539
540
541
542
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

543
  bool IsPinned() const override {
544
    return adj_.is_pinned;
545
546
  }

547
548
549
550
  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

551
552
553
554
555
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
556
          meta_graph_,
557
558
559
560
561
562
563
564
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

565
  CSR CopyTo(const DLContext &ctx) const {
566
567
568
    if (Context() == ctx) {
      return *this;
    } else {
569
      return CSR(meta_graph_, adj_.CopyTo(ctx));
570
571
572
    }
  }

573
574
575
576
577
578
579
580
581
582
  /*! \brief Pin the adj_: CSRMatrix of the CSR graph. */
  void PinMemory_() {
    adj_.PinMemory_();
  }

  /*! \brief Unpin the adj_: CSRMatrix of the CSR graph. */
  void UnpinMemory_() {
    adj_.UnpinMemory_();
  }

583
584
585
586
587
  /*! \brief Record stream for the adj_: CSRMatrix of the CSR graph. */
  void RecordStream(DGLStreamHandle stream) override {
    adj_.RecordStream(stream);
  }

588
  bool IsMultigraph() const override {
589
    return aten::CSRHasDuplicate(adj_);
590
591
592
593
594
595
596
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
597
    if (vtype == SrcType()) {
598
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
599
    } else if (vtype == DstType()) {
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
621
622
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
623
624
625
626
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
627
628
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
629
630
631
632
633
634
635
636
637
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
638
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
639
640
641
642
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
643
644
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
645
    return aten::CSRGetAllData(adj_, src, dst);
646
647
  }

648
  EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
649
650
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
651
652
653
654
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

655
656
657
658
  IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
    return aten::CSRGetData(adj_, src, dst);
  }

659
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
660
    LOG(FATAL) << "Not enabled for CSR graph.";
661
662
663
664
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
665
    LOG(FATAL) << "Not enabled for CSR graph.";
666
667
668
669
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
670
    LOG(FATAL) << "Not enabled for CSR graph.";
671
672
673
674
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
675
    LOG(FATAL) << "Not enabled for CSR graph.";
676
677
678
679
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
680
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
681
682
683
684
685
686
687
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
688
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
689
690
691
692
693
694
695
696
697
698
699
700
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
701
702
703
704
705
    auto coo = aten::CSRToCOO(adj_, false);
    if (order == std::string("srcdst")) {
      // make sure the coo is sorted if an order is requested
      coo = aten::COOSort(coo, true);
    }
706
707
708
709
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
710
    LOG(FATAL) << "Not enabled for CSR graph.";
711
712
713
714
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
715
    LOG(FATAL) << "Not enabled for CSR graph.";
716
717
718
719
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
720
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
721
722
723
724
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
725
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
726
727
728
729
730
731
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
732
    CHECK_EQ(NumBits(), 64);
733
734
735
736
737
738
739
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

740
741
742
743
744
745
746
747
748
749
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

750
751
752
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
753
    CHECK_EQ(NumBits(), 64);
754
755
756
757
758
759
760
761
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
762
    LOG(FATAL) << "Not enabled for CSR graph.";
763
764
765
766
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
767
    LOG(FATAL) << "Not enabled for CSR graph.";
768
769
770
771
772
773
774
775
776
    return {};
  }

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

777
778
779
780
781
782
783
784
785
786
787
788
789
790
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return adj_;
  }

791
  SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
792
    LOG(FATAL) << "Not enabled for CSR graph";
793
    return SparseFormat::kCSR;
794
795
  }

796
797
798
  dgl_format_code_t GetAllowedFormats() const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return 0;
799
800
  }

801
  dgl_format_code_t GetCreatedFormats() const override {
802
803
804
805
    LOG(FATAL) << "Not enabled for CSR graph";
    return 0;
  }

806
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
Minjie Wang's avatar
Minjie Wang committed
807
808
809
810
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
811
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
812
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
813
814
    DLContext ctx = aten::GetContextOf(vids);
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
Minjie Wang's avatar
Minjie Wang committed
815
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
816
817
818
819
820
821
822
823
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
824
    LOG(FATAL) << "Not enabled for CSR graph.";
825
826
827
    return {};
  }

828
  HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
829
830
831
832
    LOG(FATAL) << "Not enabled for CSR graph.";
    return nullptr;
  }

833
834
835
836
  aten::CSRMatrix adj() const {
    return adj_;
  }

837
838
839
840
841
842
843
844
845
846
847
848
849
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

850
 private:
851
852
  friend class Serializer;

853
854
855
856
857
858
  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
Minjie Wang's avatar
Minjie Wang committed
859
// unit graph implementation
860
861
862
//
//////////////////////////////////////////////////////////

863
864
865
866
/*! \brief Id data type, read from whichever storage GetAny() returns. */
DLDataType UnitGraph::DataType() const {
  const HeteroGraphPtr storage = GetAny();
  return storage->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
867
/*! \brief Device context of the graph, read from whichever storage GetAny() returns. */
DLContext UnitGraph::Context() const {
  const HeteroGraphPtr storage = GetAny();
  return storage->Context();
}

871
872
873
874
bool UnitGraph::IsPinned() const {
  return GetAny()->IsPinned();
}

Minjie Wang's avatar
Minjie Wang committed
875
uint8_t UnitGraph::NumBits() const {
876
877
878
  return GetAny()->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
879
bool UnitGraph::IsMultigraph() const {
880
  const SparseFormat fmt = SelectFormat(CSC_CODE);
881
882
  const auto ptr = GetFormat(fmt);
  return ptr->IsMultigraph();
883
884
}

Minjie Wang's avatar
Minjie Wang committed
885
/*!
 * \brief Number of vertices of type \p vtype.
 *
 * The CSC storage is a CSR over reversed edges, so when it is the selected
 * storage the source/destination vertex types are swapped for the query.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  dgl_type_t storage_vtype = vtype;
  if (fmt == SparseFormat::kCSC)
    storage_vtype = (vtype == SrcType()) ? DstType() : SrcType();
  return storage->NumVertices(storage_vtype);
}

Minjie Wang's avatar
Minjie Wang committed
895
/*! \brief Number of edges, read from whichever storage GetAny() returns. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const HeteroGraphPtr storage = GetAny();
  return storage->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
899
/*!
 * \brief Whether vertex \p vid of type \p vtype exists.
 *
 * The vertex type is swapped when the reversed CSC storage answers the query.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  dgl_type_t storage_vtype = vtype;
  if (fmt == SparseFormat::kCSC)
    storage_vtype = (vtype == SrcType()) ? DstType() : SrcType();
  return storage->HasVertex(storage_vtype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
907
/*!
 * \brief Elementwise existence test for \p vids of type \p vtype.
 *
 * Implemented as vids < NumVertices(vtype); only the upper bound is compared.
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_vertices = NumVertices(vtype);
  return aten::LT(vids, num_vertices);
}

Minjie Wang's avatar
Minjie Wang committed
912
/*!
 * \brief Whether an edge \p src -> \p dst exists (CSC storage preferred).
 *
 * The CSC storage holds reversed edges, so it is queried with swapped endpoints.
 */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool reversed = (fmt == SparseFormat::kCSC);
  return reversed ? storage->HasEdgeBetween(etype, dst, src)
                  : storage->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
921
/*!
 * \brief Edge-existence test for endpoint arrays \p src / \p dst (CSC preferred).
 *
 * The CSC storage holds reversed edges, so it receives the arrays swapped.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->HasEdgesBetween(etype, src, dst);
  return storage->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
931
/*!
 * \brief In-neighbours of \p dst (CSC storage preferred).
 *
 * In the reversed CSC storage, predecessors are obtained as successors.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->Predecessors(etype, dst);
  return storage->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
940
/*!
 * \brief Out-neighbours of \p src.
 *
 * Served by the CSR storage when CSR is allowed. Otherwise SelectFormat falls
 * back to an already-created format, which may be CSC. The CSC storage holds
 * reversed edges, so its "successors" would silently be this graph's
 * predecessors. That case is rejected explicitly, as OutDegree does.
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Successors cannot be computed as only the CSC format is usable "
         "for this graph. Please enable the CSR format.";
  return ptr->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
946
/*!
 * \brief Ids of the edges \p src -> \p dst (CSR storage preferred).
 *
 * The reversed CSC storage is queried with swapped endpoints.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool reversed = (fmt == SparseFormat::kCSC);
  return reversed ? storage->EdgeId(etype, dst, src) : storage->EdgeId(etype, src, dst);
}

955
/*!
 * \brief All edges between the endpoint arrays \p src / \p dst (CSR preferred).
 *
 * For the reversed CSC storage the query uses swapped endpoints, and the
 * returned src/dst arrays are swapped back.
 */
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->EdgeIdsAll(etype, src, dst);
  const EdgeArray reversed = storage->EdgeIdsAll(etype, dst, src);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

/*!
 * \brief One edge id per (src, dst) pair (CSR storage preferred).
 *
 * The reversed CSC storage is queried with swapped endpoints.
 */
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool reversed = (fmt == SparseFormat::kCSC);
  return reversed ? storage->EdgeIdsOne(etype, dst, src)
                  : storage->EdgeIdsOne(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
976
/*! \brief Endpoints of edge \p eid, looked up in the COO-preferred storage. */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
982
/*! \brief Endpoints of every edge in \p eids, looked up in the COO-preferred storage. */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  return GetFormat(SelectFormat(COO_CODE))->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
988
/*!
 * \brief In-edges of \p vid (CSC storage preferred).
 *
 * In the reversed CSC storage these are out-edges, whose endpoints are
 * flipped back before returning.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->InEdges(etype, vid);
  const EdgeArray& reversed = storage->OutEdges(etype, vid);
  return {reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
999
/*!
 * \brief In-edges of every vertex in \p vids (CSC storage preferred).
 *
 * In the reversed CSC storage these are out-edges, whose endpoints are
 * flipped back before returning.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->InEdges(etype, vids);
  const EdgeArray& reversed = storage->OutEdges(etype, vids);
  return {reversed.dst, reversed.src, reversed.id};
}

Minjie Wang's avatar
Minjie Wang committed
1010
/*!
 * \brief Out-edges of \p vid.
 *
 * Served by the CSR storage when CSR is allowed. If SelectFormat falls back
 * to the reversed CSC storage, its out-edges would silently be this graph's
 * in-edges, so that case is rejected explicitly, as OutDegree does.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out edges cannot be computed as only the CSC format is usable "
         "for this graph. Please enable the CSR format.";
  return ptr->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1016
/*!
 * \brief Out-edges of every vertex in \p vids.
 *
 * Served by the CSR storage when CSR is allowed. If SelectFormat falls back
 * to the reversed CSC storage, its out-edges would silently be this graph's
 * in-edges, so that case is rejected explicitly, as OutDegrees does.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out edges cannot be computed as only the CSC format is usable "
         "for this graph. Please enable the CSR format.";
  return ptr->OutEdges(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1022
/*!
 * \brief All edges, enumerated in the requested \p order.
 *
 * \param order "" (any order, any storage), "eid" (COO preferred) or
 *        "srcdst" (CSR preferred). Anything else is a fatal error.
 * \return The edge arrays; a CSC storage's src/dst are swapped back.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  SparseFormat fmt;
  if (order.empty()) {
    // arbitrary order
    fmt = SelectFormat(ALL_CODE);
  } else if (order == std::string("eid")) {
    fmt = SelectFormat(COO_CODE);
  } else if (order == std::string("srcdst")) {
    fmt = SelectFormat(CSR_CODE);
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const auto& edges = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC)
    return edges;
  return EdgeArray{edges.dst, edges.src, edges.id};
}

Minjie Wang's avatar
Minjie Wang committed
1043
/*!
 * \brief In-degree of \p vid. Requires the CSC or COO storage.
 *
 * For the reversed CSC storage, in-degree is computed as out-degree.
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool from_csc = (fmt == SparseFormat::kCSC);
  CHECK(from_csc || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (from_csc)
    return storage->OutDegree(etype, vid);
  return storage->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1053
/*!
 * \brief In-degrees of \p vids. Requires the CSC or COO storage.
 *
 * For the reversed CSC storage, in-degrees are computed as out-degrees.
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool from_csc = (fmt == SparseFormat::kCSC);
  CHECK(from_csc || fmt == SparseFormat::kCOO)
      << "In degree cannot be computed as neither CSC nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  if (from_csc)
    return storage->OutDegrees(etype, vids);
  return storage->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1063
/*! \brief Out-degree of \p vid. Requires the CSR or COO storage. */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool usable = (fmt == SparseFormat::kCSR) || (fmt == SparseFormat::kCOO);
  CHECK(usable)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return storage->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1072
/*! \brief Out-degrees of \p vids. Requires the CSR or COO storage. */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  const bool usable = (fmt == SparseFormat::kCSR) || (fmt == SparseFormat::kCOO);
  CHECK(usable)
      << "Out degree cannot be computed as neither CSR nor COO format is "
         "allowed for this graph. Please enable one of them at least.";
  return storage->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
1081
/*!
 * \brief Iterator range over the successors of \p vid.
 *
 * Rejects the reversed CSC storage, which SelectFormat can fall back to when
 * CSR is not allowed; it would otherwise silently yield predecessors.
 */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Successors cannot be computed as only the CSC format is usable "
         "for this graph. Please enable the CSR format.";
  return ptr->SuccVec(etype, vid);
}

1087
/*!
 * \brief 32-bit iterator range over the successors of \p vid.
 *
 * Only the out-edge CSR storage provides this. The reversed CSC storage is
 * also a CSR object, so the null check after the cast would not catch it;
 * it is rejected explicitly, since it would silently yield predecessors.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  CHECK(fmt == SparseFormat::kCSR)
      << "SuccVec32 requires the CSR format. Please enable the CSR format "
         "for this graph.";
  const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1094
/*!
 * \brief Iterator range over the out-edge ids of \p vid.
 *
 * Rejects the reversed CSC storage, which SelectFormat can fall back to when
 * CSR is not allowed; it would otherwise silently yield in-edge ids.
 */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSR_CODE);
  const auto ptr = GetFormat(fmt);
  CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO)
      << "Out edges cannot be computed as only the CSC format is usable "
         "for this graph. Please enable the CSR format.";
  return ptr->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1100
/*!
 * \brief Iterator range over the predecessors of \p vid (CSC preferred).
 *
 * In the reversed CSC storage they are read as successors.
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->PredVec(etype, vid);
  return storage->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1109
/*!
 * \brief Iterator range over the in-edge ids of \p vid (CSC preferred).
 *
 * In the reversed CSC storage they are read as out-edge ids.
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(CSC_CODE);
  const HeteroGraphPtr storage = GetFormat(fmt);
  if (fmt != SparseFormat::kCSC)
    return storage->InEdgeVec(etype, vid);
  return storage->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
1118
std::vector<IdArray> UnitGraph::GetAdj(
1119
1120
1121
1122
1123
1124
1125
1126
1127
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
1128
    return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
1129
1130
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
1131
    return GetCOO()->GetAdj(etype, transpose, fmt);
1132
1133
1134
1135
1136
1137
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
1138
/*!
 * \brief Vertex-induced subgraph.
 *
 * The subgraph is sliced from the storage SelectFormat prefers (out-CSR
 * first). It is wrapped in a new UnitGraph with that storage placed in the
 * matching slot.
 *
 * \param vids One id array per vertex type of this graph.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(CSR_CODE);

  // The CSC storage is a CSR over reversed edges: its rows are this graph's
  // destination vertices (cf. NumVertices and CreateFromCSC), yet it slices
  // rows with vids[SrcType()]. With two vertex types it must therefore be
  // given the id arrays swapped, or rows are sliced with the wrong vertex set.
  const bool swap_types = (fmt == SparseFormat::kCSC) && (SrcType() != DstType());
  std::vector<IdArray> storage_vids = vids;
  if (swap_types) {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    storage_vids[SrcType()] = vids[DstType()];
    storage_vids[DstType()] = vids[SrcType()];
  }
  HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(storage_vids);
  HeteroSubgraph ret;

  CSRPtr subcsr = nullptr;
  CSRPtr subcsc = nullptr;
  COOPtr subcoo = nullptr;
  switch (fmt) {
    case SparseFormat::kCSR:
      subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
      break;
    case SparseFormat::kCSC:
      subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
      break;
    case SparseFormat::kCOO:
      subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
      break;
    default:
      LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
      return ret;
  }

  ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  // Report induced vertices in this graph's own vertex-type order.
  if (swap_types)
    std::swap(ret.induced_vertices[SrcType()], ret.induced_vertices[DstType()]);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
1168
/*!
 * \brief Edge-induced subgraph (COO storage preferred).
 *
 * The storage-level subgraph is wrapped in a new UnitGraph, with that storage
 * placed in the slot matching the selected format.
 *
 * \param eids One edge id array per edge type.
 * \param preserve_nodes Forwarded to the storage's EdgeSubgraph.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(COO_CODE);
  auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);

  CSRPtr subcsc = nullptr;
  CSRPtr subcsr = nullptr;
  COOPtr subcoo = nullptr;
  if (fmt == SparseFormat::kCSR) {
    subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCSC) {
    subcsc = std::dynamic_pointer_cast<CSR>(sg.graph);
  } else if (fmt == SparseFormat::kCOO) {
    subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return HeteroSubgraph();
  }

  HeteroSubgraph result;
  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
  result.induced_vertices = std::move(sg.induced_vertices);
  result.induced_edges = std::move(sg.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1198
/*!
 * \brief Build a UnitGraph whose only initial storage is a COO of (\p row, \p col).
 *
 * \param num_vtypes 1 (then \p num_src must equal \p num_dst) or 2.
 * \param row_sorted / col_sorted Forwarded to the COO storage.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col,
    bool row_sorted, bool col_sorted,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo(new COO(mg, num_src, num_dst, row, col, row_sorted, col_sorted));
  return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));
}

1214
1215
/*!
 * \brief Build a UnitGraph whose only initial storage is the COO matrix \p mat.
 *
 * \param num_vtypes 1 (then \p mat must be square) or 2.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo(new COO(mg, mat));
  return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1227
1228
/*!
 * \brief Build a UnitGraph whose only initial storage is an out-edge CSR.
 *
 * \param num_vtypes 1 (then \p num_src must equal \p num_dst) or 2.
 * \param indptr / indices / edge_ids CSR arrays with \p num_src rows.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_csr(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(mg, nullptr, out_csr, nullptr, formats));
}

1238
1239
/*!
 * \brief Build a UnitGraph whose only initial storage is the out-edge CSR \p mat.
 *
 * \param num_vtypes 1 (then \p mat must be square) or 2.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_csr(new CSR(mg, mat));
  return HeteroGraphPtr(new UnitGraph(mg, nullptr, out_csr, nullptr, formats));
}

1249
1250
/*!
 * \brief Build a UnitGraph whose only initial storage is an in-edge CSR (CSC).
 *
 * The storage is a CSR with \p num_dst rows and \p num_src columns, i.e. over
 * reversed edges.
 *
 * \param num_vtypes 1 (then \p num_src must equal \p num_dst) or 2.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  const auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_csr(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(mg, in_csr, nullptr, nullptr, formats));
}

/*!
 * \brief Build a UnitGraph whose only initial storage is the in-edge CSR (CSC) \p mat.
 *
 * \param num_vtypes 1 (then \p mat must be square) or 2.
 * \param formats Allowed sparse formats of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    dgl_format_code_t formats) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(mat.num_rows, mat.num_cols);
  const auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_csr(new CSR(mg, mat));
  return HeteroGraphPtr(new UnitGraph(mg, in_csr, nullptr, nullptr, formats));
}

Minjie Wang's avatar
Minjie Wang committed
1271
/*!
 * \brief Return \p g with ids stored as \p bits-bit integers.
 *
 * If the width already matches, \p g itself is returned. Otherwise every
 * materialized storage is converted, and a new UnitGraph with the same
 * metagraph and allowed formats is built. \p g must be a UnitGraph.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;
  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  CSRPtr new_incsr = nullptr;
  if (bg->in_csr_->defined())
    new_incsr = CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)));
  CSRPtr new_outcsr = nullptr;
  if (bg->out_csr_->defined())
    new_outcsr = CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)));
  COOPtr new_coo = nullptr;
  if (bg->coo_->defined())
    new_coo = COOPtr(new COO(bg->coo_->AsNumBits(bits)));

  return HeteroGraphPtr(
      new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}

1288
/*!
 * \brief Return \p g on device context \p ctx.
 *
 * If \p g is already on \p ctx it is returned as-is. Otherwise every
 * materialized storage is copied, and a new UnitGraph with the same metagraph
 * and allowed formats is built. \p g must be a UnitGraph.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx) {
  if (ctx == g->Context())
    return g;
  auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(bg);

  CSRPtr new_incsr = nullptr;
  if (bg->in_csr_->defined())
    new_incsr = CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx)));
  CSRPtr new_outcsr = nullptr;
  if (bg->out_csr_->defined())
    new_outcsr = CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx)));
  COOPtr new_coo = nullptr;
  if (bg->coo_->defined())
    new_coo = COOPtr(new COO(bg->coo_->CopyTo(ctx)));

  return HeteroGraphPtr(
      new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}

1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
void UnitGraph::PinMemory_() {
  if (this->in_csr_->defined())
    this->in_csr_->PinMemory_();
  if (this->out_csr_->defined())
    this->out_csr_->PinMemory_();
  if (this->coo_->defined())
    this->coo_->PinMemory_();
}

void UnitGraph::UnpinMemory_() {
  if (this->in_csr_->defined())
    this->in_csr_->UnpinMemory_();
  if (this->out_csr_->defined())
    this->out_csr_->UnpinMemory_();
  if (this->coo_->defined())
    this->coo_->UnpinMemory_();
}

1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
/*!
 * \brief Record \p stream on every materialized storage and remember it.
 *
 * Remembered streams are replayed on storages that GetInCSR / GetOutCSR /
 * GetCOO later create in place.
 */
void UnitGraph::RecordStream(DGLStreamHandle stream) {
  if (in_csr_->defined()) {
    in_csr_->RecordStream(stream);
  }
  if (out_csr_->defined()) {
    out_csr_->RecordStream(stream);
  }
  if (coo_->defined()) {
    coo_->RecordStream(stream);
  }
  recorded_streams.emplace_back(stream);
}

1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
void UnitGraph::InvalidateCSR() {
  this->out_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCSC() {
  this->in_csr_ = CSRPtr(new CSR());
}

void UnitGraph::InvalidateCOO() {
  this->coo_ = COOPtr(new COO());
}

1348
/*!
 * \brief Construct from any subset of the three storages.
 *
 * Null storages are replaced by empty placeholders, so every slot can be
 * tested with defined(). Every storage supplied must belong to the allowed
 * \p formats, otherwise construction fails fatally.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     dgl_format_code_t formats)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  if (!in_csr_)
    in_csr_ = CSRPtr(new CSR());
  if (!out_csr_)
    out_csr_ = CSRPtr(new CSR());
  if (!coo_)
    coo_ = COOPtr(new COO());

  formats_ = formats;
  const dgl_format_code_t created = GetCreatedFormats();
  const bool compatible = ((formats | created) == formats);
  if (!compatible) {
    LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) <<
      ", which is not compatible with available formats: " << CodeToStr(formats);
  }
  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1368
1369
/*!
 * \brief Build a UnitGraph from explicitly provided storages.
 *
 * Each storage is built from its matrix when the corresponding has_* flag is
 * set, and is otherwise left as an empty placeholder.
 */
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
    int num_vtypes,
    const aten::CSRMatrix &in_csr,
    const aten::CSRMatrix &out_csr,
    const aten::COOMatrix &coo,
    bool has_in_csr,
    bool has_out_csr,
    bool has_coo,
    dgl_format_code_t formats) {
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_csr_ptr = has_in_csr ? CSRPtr(new CSR(mg, in_csr)) : CSRPtr(new CSR());
  CSRPtr out_csr_ptr = has_out_csr ? CSRPtr(new CSR(mg, out_csr)) : CSRPtr(new CSR());
  COOPtr coo_ptr = has_coo ? COOPtr(new COO(mg, coo)) : COOPtr(new COO());
  return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}

1399
1400
/*!
 * \brief Return the in-edge CSR (CSC) storage, building it if necessary.
 *
 * A missing CSC is derived from the COO when present (preferred), otherwise
 * by transposing the out-edge CSR.
 *
 * \param inplace If true, the built matrix is stored into this graph's
 *        in_csr_ (CSC must then be an allowed format). It is pinned if the
 *        graph was pinned, and registered with every previously recorded
 *        stream. If false, a standalone storage is returned and the graph is
 *        left unchanged.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
  if (inplace && !(formats_ & CSC_CODE)) {
    LOG(FATAL) << "The graph have restricted sparse format " <<
      CodeToStr(formats_) << ", cannot create CSC matrix.";
  }
  CSRPtr ret = in_csr_;
  if (!in_csr_->defined()) {
    CHECK(coo_->defined() || out_csr_->defined()) << "None of CSR, COO exist";
    // Sample the pinned state *before* materializing. IsPinned() consults
    // GetAny(), which returns in_csr_ first as soon as it is defined. Asking
    // afterwards would inspect the freshly built, never-pinned matrix, and a
    // pinned graph would end up with an unpinned CSC.
    const bool pin_new = inplace && IsPinned();
    // Prefers converting from COO since it is parallelized.
    // TODO(BarclayII): need benchmarking.
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj()));
      if (inplace)
        *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
    } else {
      const auto& newadj = aten::CSRTranspose(out_csr_->adj());
      if (inplace)
        *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
    }
    if (inplace) {
      if (pin_new)
        in_csr_->PinMemory_();
      for (auto stream : recorded_streams)
        in_csr_->RecordStream(stream);
    }
  }
  return ret;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
1436
1437
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
  if (inplace)
1438
    if (!(formats_ & CSR_CODE))
1439
1440
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create CSR matrix.";
1441
  CSRPtr ret = out_csr_;
1442
1443
  // Prefers converting from COO since it is parallelized.
  // TODO(BarclayII): need benchmarking.
1444
  if (!out_csr_->defined()) {
1445
1446
    if (coo_->defined()) {
      const auto& newadj = aten::COOToCSR(coo_->adj());
1447

1448
      if (inplace)
1449
1450
1451
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1452
    } else {
1453
1454
      CHECK(in_csr_->defined()) << "None of CSR, COO exist";
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
1455

1456
      if (inplace)
1457
1458
1459
        *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
      else
        ret = std::make_shared<CSR>(meta_graph(), newadj);
1460
    }
1461
1462
1463
1464
1465
1466
    if (inplace) {
      if (IsPinned())
        out_csr_->PinMemory_();
      for (auto stream : recorded_streams)
        out_csr_->RecordStream(stream);
    }
1467
  }
1468
  return ret;
1469
1470
1471
}

/* !\brief Return coo. If not exist, create from csr.*/
1472
1473
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
  if (inplace)
1474
    if (!(formats_ & COO_CODE))
1475
1476
      LOG(FATAL) << "The graph have restricted sparse format " <<
        CodeToStr(formats_) << ", cannot create COO matrix.";
1477
  COOPtr ret = coo_;
1478
1479
  if (!coo_->defined()) {
    if (in_csr_->defined()) {
1480
      const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
1481

1482
      if (inplace)
1483
1484
1485
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1486
    } else {
1487
      CHECK(out_csr_->defined()) << "Both CSR are missing.";
1488
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
1489

1490
      if (inplace)
1491
1492
1493
        *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
      else
        ret = std::make_shared<COO>(meta_graph(), newadj);
1494
    }
1495
1496
1497
1498
1499
1500
    if (inplace) {
      if (IsPinned())
        coo_->PinMemory_();
      for (auto stream : recorded_streams)
        coo_->RecordStream(stream);
    }
1501
  }
1502
  return ret;
1503
1504
}

1505
/*! \brief In-edge CSR (CSC) matrix, built on demand by GetInCSR(); \p etype is unused. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr csc_storage = GetInCSR();
  return csc_storage->adj();
}

1509
/*! \brief Out-edge CSR matrix, built on demand by GetOutCSR(); \p etype is unused. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr csr_storage = GetOutCSR();
  return csr_storage->adj();
}

1513
/*! \brief COO matrix, built on demand by GetCOO(); \p etype is unused. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo_storage = GetCOO();
  return coo_storage->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1517
/*!
 * \brief First defined storage in the order in-CSR, out-CSR.
 *
 * Falls back to coo_ without checking whether it is defined.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_->defined())
    return in_csr_;
  if (out_csr_->defined())
    return out_csr_;
  return coo_;
}

1527
/*! \brief Bitmask (COO_CODE | CSR_CODE | CSC_CODE) of the storages currently materialized. */
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
  dgl_format_code_t created = 0;
  if (coo_->defined())
    created |= COO_CODE;
  if (out_csr_->defined())
    created |= CSR_CODE;
  if (in_csr_->defined())
    created |= CSC_CODE;
  return created;
}

1538
1539
1540
1541
/*! \brief Bitmask of the sparse formats this graph is allowed to hold (formats_). */
dgl_format_code_t UnitGraph::GetAllowedFormats() const {
  return formats_;
}

1542
1543
/*!
 * \brief Storage for \p format, built on demand via the matching getter.
 *
 * kCSR maps to GetOutCSR(), kCSC to GetInCSR(), and every other value
 * (including kCOO) to GetCOO().
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR)
    return GetOutCSR();
  if (format == SparseFormat::kCSC)
    return GetInCSR();
  return GetCOO();
}

1553
/*!
 * \brief Return a graph whose allowed formats are \p formats.
 *
 * For ALL_CODE, a new UnitGraph is built from copies of every materialized
 * storage. Otherwise the graph is rebuilt from a single storage: COO if
 * requested, else CSR, else CSC. That storage is fetched with inplace=false,
 * so this graph is not modified.
 */
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
  if (formats == ALL_CODE) {
    // TODO(xiangsx) Make it as graph storage.Clone()
    CSRPtr in_copy = in_csr_->defined() ? CSRPtr(new CSR(*in_csr_)) : nullptr;
    CSRPtr out_copy = out_csr_->defined() ? CSRPtr(new CSR(*out_csr_)) : nullptr;
    COOPtr coo_copy = coo_->defined() ? COOPtr(new COO(*coo_)) : nullptr;
    return HeteroGraphPtr(new UnitGraph(meta_graph_, in_copy, out_copy, coo_copy, formats));
  }
  const int64_t num_vtypes = NumVertexTypes();
  if (formats & COO_CODE)
    return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
  if (formats & CSR_CODE)
    return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
  return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}

/*!
 * \brief Choose the storage format to serve a query.
 *
 * Priority:
 *  1. a preferred, allowed format that is already materialized;
 *  2. a preferred, allowed format (GetFormat builds it if absent);
 *  3. otherwise, a materialized format.
 * When several bits qualify, DecodeFormat picks among them.
 */
SparseFormat UnitGraph::SelectFormat(dgl_format_code_t preferred_formats) const {
  const dgl_format_code_t wanted = preferred_formats & formats_;
  const dgl_format_code_t materialized = GetCreatedFormats();
  const dgl_format_code_t ready = wanted & materialized;
  if (ready)
    return DecodeFormat(ready);

  // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on COO have
  // not been implemented yet.
  // if (coo_->defined() && coo_->IsHypersparse())  // only allow coo for hypersparse graph.
  //   return SparseFormat::kCOO;
  return DecodeFormat(wanted ? wanted : materialized);
}

1591
1592
1593
1594
/*!
 * \brief Convert this homogeneous graph into an ImmutableGraph.
 *
 * Only storages that are already materialized are carried over.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr in_csr_ptr = nullptr;
  dgl::CSRPtr out_csr_ptr = nullptr;
  dgl::COOPtr coo_ptr = nullptr;
  if (in_csr_->defined()) {
    const aten::CSRMatrix csc = GetCSCMatrix(0);
    in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_->defined()) {
    const aten::CSRMatrix csr = GetCSRMatrix(0);
    out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_->defined()) {
    const aten::COOMatrix coo = GetCOOMatrix(0);
    // The ImmutableGraph COO built here takes no edge-id array. When this COO
    // has one, row/col are scattered by it first (presumably into edge-id
    // order; confirm against aten::Scatter).
    IdArray src = coo.row;
    IdArray dst = coo.col;
    if (COOHasData(coo)) {
      src = Scatter(coo.row, coo.data);
      dst = Scatter(coo.col, coo.data);
    }
    coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), src, dst));
  }
  return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
}

1616
1617
/*!
 * \brief Build the line graph of this graph.
 *
 * Every branch turns the selected storage into a COO matrix, runs
 * aten::COOLineGraph on it, and wraps the result as a single-vertex-type COO
 * graph.
 *
 * \param backtracking Forwarded to aten::COOLineGraph.
 */
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
  // TODO(xiangsx) currently we only support homogeneous graph
  const SparseFormat fmt = SelectFormat(ALL_CODE);
  switch (fmt) {
    case SparseFormat::kCOO: {
      // Go through the accessor, as the CSR/CSC branches do, rather than
      // reading coo_ directly. SelectFormat may return an allowed format that
      // is not materialized yet, and GetCOOMatrix builds it in that case.
      const aten::COOMatrix coo = aten::COOLineGraph(GetCOOMatrix(0), backtracking);
      return CreateFromCOO(1, coo);
    }
    case SparseFormat::kCSR: {
      const aten::CSRMatrix csr = GetCSRMatrix(0);
      const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
      return CreateFromCOO(1, coo);
    }
    case SparseFormat::kCSC: {
      const aten::CSRMatrix csc = GetCSCMatrix(0);
      const aten::CSRMatrix csr = aten::CSRTranspose(csc);
      const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
      return CreateFromCOO(1, coo);
    }
    default:
      LOG(FATAL) << "None of CSC, CSR, COO exist";
      break;
  }
  return nullptr;
}

1641
1642
1643
1644
1645
1646
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Deserialize a UnitGraph from \p fs.
 *
 * Stream layout, in read order:
 *  1. the UnitGraph magic number;
 *  2. an int64 SparseFormat code naming the single matrix that was stored;
 *  3. an int64 allowed-formats code. If any bit at position 32 or above is
 *     set, the low 32 bits are the format code itself. Otherwise it is the
 *     legacy encoding 0..3 = ALL / COO / CSR / CSC;
 *  4. the stored sparse matrix.
 * The two matrices that were not stored are replaced with default-constructed
 * CSR/COO objects.
 *
 * \param fs Input stream.
 * \return true. Malformed or truncated input is reported through CHECK /
 *         LOG(FATAL) instead of a false return.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magic_num = 0;
  CHECK(fs->Read(&magic_num)) << "Invalid Magic Number";
  CHECK_EQ(magic_num, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t save_format_code = 0;
  int64_t formats_code = 0;
  CHECK(fs->Read(&save_format_code)) << "Invalid format";
  CHECK(fs->Read(&formats_code)) << "Invalid format";
  auto save_format = static_cast<SparseFormat>(save_format_code);

  // UnitGraph::Save ORs in 0x100000000 to mark the current encoding.
  if (formats_code >> 32) {
    formats_ = static_cast<dgl_format_code_t>(0xffffffff & formats_code);
  } else {
    // NOTE(zihao): to be compatible with old formats.
    switch (formats_code & 0xffffffff) {
      case 0:
        formats_ = ALL_CODE;
        break;
      case 1:
        formats_ = COO_CODE;
        break;
      case 2:
        formats_ = CSR_CODE;
        break;
      case 3:
        formats_ = CSC_CODE;
        break;
      default:
        LOG(FATAL) << "Load graph failed, formats code " << formats_code
                   << " not recognized.";
        break;
    }
  }

  // Check the matrix read like the header reads above, so a truncated stream
  // is reported instead of silently producing a partially loaded graph.
  switch (save_format) {
    case SparseFormat::kCOO:
      CHECK(fs->Read(&coo_)) << "Invalid COO matrix";
      break;
    case SparseFormat::kCSR:
      CHECK(fs->Read(&out_csr_)) << "Invalid out-CSR matrix";
      break;
    case SparseFormat::kCSC:
      CHECK(fs->Read(&in_csr_)) << "Invalid in-CSR matrix";
      break;
    default:
      LOG(FATAL) << "unsupported format code";
      break;
  }

  // Fill in placeholders for the formats that were not serialized.
  if (!in_csr_) {
    in_csr_ = CSRPtr(new CSR());
  }
  if (!out_csr_) {
    out_csr_ = CSRPtr(new CSR());
  }
  if (!coo_) {
    coo_ = COOPtr(new COO());
  }

  meta_graph_ = GetAny()->meta_graph();

  return true;
}

1705

1706
1707
/*!
 * \brief Serialize this graph into \p fs.
 *
 * Stream layout, in write order:
 *  1. the UnitGraph magic number;
 *  2. the int64 code of the single sparse format chosen by
 *     SelectFormat(ALL_CODE);
 *  3. the int64 allowed-formats code with bit 32 set, which tells the loader
 *     this is not the legacy 0..3 encoding;
 *  4. the matrix of the chosen format.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  // Bit 32 marks the formats code as using the current encoding.
  constexpr int64_t kFormatsTag = int64_t{1} << 32;

  fs->Write(kDGLSerialize_UnitGraphMagic);
  // meta_graph_ is not written on purpose: it is included in the underlying
  // sparse matrix.
  const SparseFormat chosen = SelectFormat(ALL_CODE);
  fs->Write(static_cast<int64_t>(chosen));
  fs->Write(static_cast<int64_t>(formats_) | kFormatsTag);

  if (chosen == SparseFormat::kCOO) {
    fs->Write(GetCOO());
  } else if (chosen == SparseFormat::kCSR) {
    fs->Write(GetOutCSR());
  } else if (chosen == SparseFormat::kCSC) {
    fs->Write(GetInCSR());
  } else {
    LOG(FATAL) << "unsupported format code";
  }
}

1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
/*!
 * \brief Return a new UnitGraph with the direction of every edge flipped.
 *
 * The two CSR pointers trade places and are shared with this graph, not
 * copied. A defined COO is transposed into a new COO object. An undefined COO
 * is passed on as a null pointer.
 */
UnitGraphPtr UnitGraph::Reverse() const {
  // Reversing the edges makes the old outgoing CSR the new incoming one, and
  // vice versa.
  CSRPtr reversed_in = out_csr_;
  CSRPtr reversed_out = in_csr_;
  COOPtr reversed_coo = coo_->defined()
      ? COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())))
      : COOPtr();

  return UnitGraphPtr(
      new UnitGraph(meta_graph(), reversed_in, reversed_out, reversed_coo));
}

1739
1740
1741
1742
1743
1744
1745
std::tuple<UnitGraphPtr, IdArray, IdArray>
UnitGraph::ToSimple() const {
  CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
  COOPtr new_coo = nullptr;
  IdArray count;
  IdArray edge_map;

1746
  auto avail_fmt = SelectFormat(ALL_CODE);
1747
1748
  switch (avail_fmt) {
    case SparseFormat::kCOO: {
1749
      auto ret = aten::COOToSimple(GetCOO()->adj());
1750
1751
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1752
      new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret)));
1753
1754
1755
      break;
    }
    case SparseFormat::kCSR: {
1756
      auto ret = aten::CSRToSimple(GetOutCSR()->adj());
1757
1758
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1759
      new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1760
1761
1762
      break;
    }
    case SparseFormat::kCSC: {
1763
      auto ret = aten::CSRToSimple(GetInCSR()->adj());
1764
1765
      count = std::get<1>(ret);
      edge_map = std::get<2>(ret);
1766
      new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret)));
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
      break;
    }
    default:
      LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist.";
      break;
  }

  return std::make_tuple(UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
                         count,
                         edge_map);
}

1779
}  // namespace dgl