unit_graph.cc 44.7 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <dgl/array.h>
7
#include <dgl/base_heterograph.h>
8
9
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
10
11

#include "../c_api_common.h"
12
#include "./unit_graph.h"
13
14

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
  // a self-loop edge 0->0
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 0);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
  return g;
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
36
37
38
39
40
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  if (num_vtypes == 1)
    return mg1;
  else if (num_vtypes == 2)
    return mg2;
  else
    LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
62
/*!
 * \brief COO backend of a unit graph.
 *
 * Edges are kept in an aten::COOMatrix whose rows are source vertex ids and
 * whose columns are destination vertex ids. Edge ids are implicit: the i-th
 * (row, col) entry is edge i, so the matrix carries no data array.
 * The graph is read-only; every mutation method aborts with LOG(FATAL).
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from separate source and destination id arrays.
   * \param metagraph Metagraph passed on to BaseHeteroGraph.
   * \param num_src Number of rows of the adjacency matrix.
   * \param num_dst Number of columns of the adjacency matrix.
   * \param src Source vertex id of each edge.
   * \param dst Destination vertex id of each edge; must be as long as \a src.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, IdArray dst)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
  }

  /*!
   * \brief Construct from an existing COO matrix.
   * \param metagraph Metagraph passed on to BaseHeteroGraph.
   * \param coo Adjacency matrix; it must not carry a data array.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /*! \brief Vertex type id of the source side (always 0). */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Vertex type id of the destination side: 0 with one vertex type, otherwise 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief Edge type id (always 0). */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Data type of the row id array. */
  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  /*! \brief Device context of the row id array. */
  DLContext Context() const override {
    return adj_.row->ctx;
  }

  /*! \brief Bit width of the row id array's data type. */
  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a graph whose id arrays use the given bit width.
   * \return A copy of this graph if the width already matches, otherwise a new
   *         graph built from the converted row/column arrays.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits),
        aten::AsNumBits(adj_.col, bits));
    return ret;
  }

  /*!
   * \brief Return a graph whose id arrays live on the given device context.
   * \return A copy of this graph if the context already matches, otherwise a
   *         new graph built from the copied row/column arrays.
   */
  COO CopyTo(const DLContext& ctx) const {
    if (Context() == ctx)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        adj_.row.CopyTo(ctx),
        adj_.col.CopyTo(ctx));
    return ret;
  }

  /*! \brief Whether the adjacency matrix contains duplicate entries. */
  bool IsMultigraph() const override {
    return aten::COOHasDuplicate(adj_);
  }

  /*! \brief Always true. */
  bool IsReadonly() const override {
    return true;
  }

  /*!
   * \brief Number of vertices of the given type.
   * \return Row count for the source type, column count for the destination
   *         type; any other type aborts with LOG(FATAL).
   */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /*! \brief Number of edges, i.e. the length of the row array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  /*! \brief Whether \a vid is smaller than the number of vertices of \a vtype. */
  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  /*! \brief Whether the adjacency matrix has an entry at (src, dst). */
  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  /*! \brief Batched version of HasEdgeBetween over id arrays. */
  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  /*! \brief Indices of row \a dst in the transposed adjacency matrix. */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  /*! \brief Indices of row \a src in the adjacency matrix. */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  /*! \brief Edge ids stored at (src, dst). */
  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetData(adj_, src, dst);
  }

  /*! \brief Batched edge lookup; packs the three arrays from COOGetDataAndIndices. */
  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  /*! \brief Endpoints (src, dst) of edge \a eid, read from the row/column arrays. */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  /*! \brief Endpoints of the edges \a eids; the third array is \a eids itself. */
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  /*! \brief Edges entering \a vid as (src, dst, eid) arrays. */
  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    // Validate the id like the other per-vertex queries do.
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  /*! \brief Edges entering any vertex of \a vids as (src, dst, eid) arrays. */
  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Map the sliced matrix's row ids back through vids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  /*! \brief Edges leaving \a vid as (src, dst, eid) arrays. */
  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    // Validate the id like the other per-vertex queries do.
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  /*! \brief Edges leaving any vertex of \a vids as (src, dst, eid) arrays. */
  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Map the sliced matrix's row ids back through vids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /*!
   * \brief All edges as (src, dst, eid) arrays.
   * \param order Must be empty or "eid"; anything else fails the CHECK.
   */
  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  /*! \brief Number of entries in row \a vid of the transposed adjacency matrix. */
  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  /*! \brief Batched version of InDegree. */
  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  /*! \brief Number of entries in row \a vid of the adjacency matrix. */
  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  /*! \brief Batched version of OutDegree. */
  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  /*!
   * \brief Not supported; aborts with LOG(FATAL).
   *
   * Fatal like every other unsupported operation of this class, so a caller
   * cannot mistake an empty iterator range for "no successors".
   */
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*!
   * \brief Adjacency arrays in COO form.
   * \param transpose If true the column array is stacked before the row array.
   * \param fmt Must be "coo".
   * \return A single array produced by aten::HStack of the two id arrays.
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  /*! \brief The internal adjacency matrix. */
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kAny;
  }

  /*!
   * \brief Subgraph induced by the given vertices.
   * \param vids One id array per vertex type.
   * \return Subgraph whose induced_vertices is \a vids and whose single
   *         induced_edges entry is the data array of the sliced matrix.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    // No explicit edge-id array is built: a COO graph numbers its edges by position.
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Subgraph induced by the given edges.
   * \param eids Exactly one id array (there is a single edge type).
   * \param preserve_nodes If false, the endpoints are relabeled in place by
   *        aten::Relabel_ and induced_vertices holds its results; if true, all
   *        vertices are kept and induced_vertices holds full id ranges.
   * \return Subgraph whose induced_edges equals \a eids.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      // Use SrcType()/DstType() rather than literal 0/1: with a single vertex
      // type, NumVertices(1) would abort with "Invalid vertex type".
      const auto num_src = NumVertices(SrcType());
      const auto num_dst = NumVertices(DstType());
      subg.induced_vertices.emplace_back(aten::Range(0, num_src, NumBits(), Context()));
      subg.induced_vertices.emplace_back(aten::Range(0, num_dst, NumBits(), Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), num_src, num_dst, new_src, new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  /*! \brief The internal adjacency matrix. */
  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /*!
   * \brief Read the metagraph and then the adjacency matrix from a stream.
   * \return true; a failed read trips a CHECK instead of returning false.
   */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Write the metagraph (as an ImmutableGraph) and then the adjacency matrix. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /*! \brief Default constructor, private; Serializer is a friend. */
  COO() {}

  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
420
class UnitGraph::CSR : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from raw CSR arrays.
   * \param metagraph Metagraph passed on to BaseHeteroGraph.
   * \param num_src Number of rows of the adjacency matrix.
   * \param num_dst Number of columns of the adjacency matrix.
   * \param indptr CSR row pointer array.
   * \param indices CSR column index array.
   * \param edge_ids Edge id of each entry; must be as long as \a indices.
   */
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
      IdArray indptr, IdArray indices, IdArray edge_ids)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
    CHECK(aten::IsValidIdArray(edge_ids));
    CHECK_EQ(indices->shape[0], edge_ids->shape[0])
      << "indices and edge id arrays should have the same length";

    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

  /*! \brief Construct from an existing CSR matrix, which is stored as-is. */
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
    : BaseHeteroGraph(metagraph), adj_(csr) {
  }

  /*! \brief Vertex type id of the source side (always 0). */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Vertex type id of the destination side: 0 with one vertex type, otherwise 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief Edge type id (always 0). */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Not supported (graph is immutable); aborts with LOG(FATAL). */
  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Data type of the column index array. */
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

  /*! \brief Device context of the column index array. */
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

  /*! \brief Bit width of the column index array's data type. */
  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

  /*!
   * \brief Return a graph whose arrays use the given bit width.
   * \return A copy of this graph if the width already matches, otherwise a new
   *         graph built from the converted indptr/indices/data arrays.
   */
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
          meta_graph_,
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

  /*!
   * \brief Return a graph whose arrays live on the given device context.
   * \return A copy of this graph if the context already matches, otherwise a
   *         new graph built from the copied indptr/indices/data arrays.
   */
  CSR CopyTo(const DLContext& ctx) const {
    if (Context() == ctx) {
      return *this;
    } else {
      CSR ret(
          meta_graph_,
          adj_.num_rows, adj_.num_cols,
          adj_.indptr.CopyTo(ctx),
          adj_.indices.CopyTo(ctx),
          adj_.data.CopyTo(ctx));
      return ret;
    }
  }

  /*! \brief Whether the adjacency matrix contains duplicate entries. */
  bool IsMultigraph() const override {
    return aten::CSRHasDuplicate(adj_);
  }

  /*! \brief Always true. */
  bool IsReadonly() const override {
    return true;
  }

  /*!
   * \brief Number of vertices of the given type.
   * \return Row count for the source type, column count for the destination
   *         type; any other type aborts with LOG(FATAL).
   */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /*! \brief Number of edges, i.e. the length of the column index array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  /*! \brief Whether \a vid is smaller than the number of vertices of \a vtype. */
  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    // The message names this backend (it previously said "COO graph").
    LOG(FATAL) << "Not enabled for CSR graph";
    return {};
  }

  /*! \brief Whether the adjacency matrix has an entry at (src, dst). */
  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  /*! \brief Batched version of HasEdgeBetween over id arrays. */
  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  /*!
   * \brief Not supported; aborts with LOG(FATAL).
   *
   * Fatal like the other unsupported operations of this class, so a caller
   * cannot mistake a null array for "no predecessors".
   */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Column indices of row \a src. */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  /*! \brief Edge ids stored at (src, dst). */
  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::CSRGetData(adj_, src, dst);
  }

  /*! \brief Batched edge lookup; packs the three arrays from CSRGetDataAndIndices. */
  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Edges leaving \a vid as (src, dst, eid) arrays. */
  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  /*! \brief Edges leaving any vertex of \a vids as (src, dst, eid) arrays. */
  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabeled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /*!
   * \brief All edges as (src, dst, eid) arrays, obtained via CSRToCOO.
   * \param order Must be empty or "srcdst"; anything else fails the CHECK.
   */
  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
    const auto& coo = aten::CSRToCOO(adj_, false);
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Number of entries in row \a vid. */
  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  /*! \brief Batched version of OutDegree. */
  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  /*!
   * \brief Iterator range over the column indices of row \a vid.
   *
   * Reads the raw indptr/indices buffers as dgl_id_t; \a vid is not
   * bounds-checked here.
   */
  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

  /*!
   * \brief Iterator range over the edge ids of row \a vid.
   *
   * Reads the raw indptr/data buffers as dgl_id_t; \a vid is not
   * bounds-checked here.
   */
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*!
   * \brief Adjacency arrays in CSR form.
   * \param transpose Must be false.
   * \param fmt Must be "csr".
   * \return {indptr, indices, data}.
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

  /*! \brief The internal adjacency matrix. */
  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return SparseFormat::kAny;
  }

  /*!
   * \brief Subgraph induced by the given vertices.
   * \param vids One id array per vertex type.
   * \return Subgraph whose edges are renumbered 0..n-1, whose induced_vertices
   *         is \a vids and whose single induced_edges entry is the data array
   *         of the sliced matrix.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
    // The subgraph gets fresh consecutive edge ids; the sliced data array is
    // reported separately as the induced edges.
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*! \brief Not supported; aborts with LOG(FATAL). */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    LOG(FATAL) << "Not enabled for CSR graph.";
    return {};
  }

  /*! \brief The internal adjacency matrix. */
  aten::CSRMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Read the metagraph and then the adjacency matrix from a stream.
   * \return true; a failed read trips a CHECK instead of returning false.
   */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Write the metagraph (as an ImmutableGraph) and then the adjacency matrix. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  /*! \brief Default constructor, private; Serializer is a friend. */
  CSR() {}

  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
Minjie Wang's avatar
Minjie Wang committed
750
// unit graph implementation
751
752
753
//
//////////////////////////////////////////////////////////

754
755
756
757
/*! \brief Data type reported by the graph that GetAny() returns. */
DLDataType UnitGraph::DataType() const {
  const auto& any_graph = GetAny();
  return any_graph->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
758
/*! \brief Device context reported by the graph that GetAny() returns. */
DLContext UnitGraph::Context() const {
  const auto& any_graph = GetAny();
  return any_graph->Context();
}

Minjie Wang's avatar
Minjie Wang committed
762
uint8_t UnitGraph::NumBits() const {
763
764
765
  return GetAny()->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
766
bool UnitGraph::IsMultigraph() const {
767
768
769
  return GetAny()->IsMultigraph();
}

Minjie Wang's avatar
Minjie Wang committed
770
/*!
 * \brief Number of vertices of the given type.
 *
 * When the selected format is CSC the vertex type is swapped (source <->
 * destination) before the backend is queried.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kAny);
  const auto backend = GetFormat(chosen);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  const dgl_type_t backend_vtype = (chosen == SparseFormat::kCSC)
      ? ((vtype == SrcType()) ? DstType() : SrcType())
      : vtype;
  return backend->NumVertices(backend_vtype);
}

Minjie Wang's avatar
Minjie Wang committed
780
/*! \brief Number of edges reported by the graph that GetAny() returns. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const auto& any_graph = GetAny();
  return any_graph->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
784
/*!
 * \brief Whether vertex \a vid of type \a vtype exists.
 *
 * When the selected format is CSC the vertex type is swapped (source <->
 * destination) before the backend is queried.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kAny);
  const auto backend = GetFormat(chosen);
  const dgl_type_t backend_vtype = (chosen == SparseFormat::kCSC)
      ? ((vtype == SrcType()) ? DstType() : SrcType())
      : vtype;
  return backend->HasVertex(backend_vtype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
792
/*!
 * \brief Element-wise existence test for a batch of vertex ids.
 * \return Result of aten::LT(vids, NumVertices(vtype)).
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const uint64_t num_nodes = NumVertices(vtype);
  return aten::LT(vids, num_nodes);
}

Minjie Wang's avatar
Minjie Wang committed
797
/*!
 * \brief Whether an edge src -> dst exists.
 *
 * The CSC backend is queried with the two endpoints exchanged.
 */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kAny);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->HasEdgeBetween(etype, dst, src)
                    : backend->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
806
/*!
 * \brief Batched form of HasEdgeBetween.
 *
 * The CSC backend is queried with the two endpoint arrays exchanged.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kAny);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->HasEdgesBetween(etype, src, dst);
  }
  return backend->HasEdgesBetween(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
816
/*!
 * \brief Predecessors of \a dst.
 *
 * Prefers CSC, where the query is answered by Successors() on the transposed
 * structure; any other selected format receives Predecessors() directly.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->Predecessors(etype, dst);
  }
  return backend->Successors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
825
/*!
 * \brief Successors of \a src.
 *
 * Prefers the out-CSR. SelectFormat returns CSC when the graph is restricted
 * to that format; GetFormat then hands back the transposed structure (see
 * GetInCSR), whose Successors() are this graph's predecessors. The query is
 * therefore mirrored to Predecessors(), the same way UnitGraph::Predecessors
 * maps onto Successors() for CSC.
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->Predecessors(etype, src);
  else
    return ptr->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
831
/*!
 * \brief Ids of the edges src -> dst.
 *
 * Prefers the out-CSR; the CSC backend is queried with the endpoints exchanged.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSR);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeId(etype, src, dst);
  }
  return backend->EdgeId(etype, dst, src);
}

Minjie Wang's avatar
Minjie Wang committed
840
/*!
 * \brief Batched form of EdgeId.
 *
 * For the CSC backend the endpoint arrays are exchanged on the way in and the
 * src/dst fields of the result are exchanged on the way out.
 */
EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSR);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->EdgeIds(etype, src, dst);
  }
  EdgeArray flipped = backend->EdgeIds(etype, dst, src);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
851
/*!
 * \brief Endpoints (src, dst) of edge \a eid.
 *
 * Prefers COO. When the graph is restricted to CSC, the selected structure is
 * the transposed adjacency, so the pair it reports is (dst, src) and is
 * swapped back before returning.
 * NOTE(review): whether the CSC backend implements FindEdge at all is decided
 * by that backend, which is not visible here.
 */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    const auto transposed = ptr->FindEdge(etype, eid);
    return std::make_pair(transposed.second, transposed.first);
  }
  return ptr->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
857
/*!
 * \brief Endpoints of the edges in \a eids.
 *
 * Prefers COO. When the graph is restricted to CSC, the selected structure is
 * the transposed adjacency, so the src/dst fields of its answer are swapped
 * back, matching what EdgeIds and InEdges do for CSC.
 * NOTE(review): whether the CSC backend implements FindEdges at all is decided
 * by that backend, which is not visible here.
 */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    const EdgeArray& ret = ptr->FindEdges(etype, eids);
    return EdgeArray{ret.dst, ret.src, ret.id};
  }
  return ptr->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
863
/*!
 * \brief In-edges of vertex \a vid.
 *
 * Prefers CSC, where the answer is the backend's OutEdges() with the src/dst
 * fields exchanged; other formats receive InEdges() directly.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdges(etype, vid);
  }
  const EdgeArray& flipped = backend->OutEdges(etype, vid);
  return {flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
874
/*!
 * \brief In-edges of every vertex in \a vids.
 *
 * Prefers CSC, where the answer is the backend's OutEdges() with the src/dst
 * fields exchanged; other formats receive InEdges() directly.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdges(etype, vids);
  }
  const EdgeArray& flipped = backend->OutEdges(etype, vids);
  return {flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
885
/*!
 * \brief Out-edges of vertex \a vid.
 *
 * Prefers the out-CSR. When the graph is restricted to CSC the selected
 * structure is the transposed adjacency, whose OutEdges() are this graph's
 * in-edges. The query is therefore mirrored to InEdges() and the src/dst
 * fields are swapped back, the same way UnitGraph::InEdges handles CSC.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    const EdgeArray& ret = ptr->InEdges(etype, vid);
    return {ret.dst, ret.src, ret.id};
  } else {
    return ptr->OutEdges(etype, vid);
  }
}

Minjie Wang's avatar
Minjie Wang committed
891
/*!
 * \brief Out-edges of every vertex in \a vids.
 *
 * Prefers the out-CSR. When the graph is restricted to CSC the selected
 * structure is the transposed adjacency, whose OutEdges() are this graph's
 * in-edges. The query is therefore mirrored to InEdges() and the src/dst
 * fields are swapped back, the same way UnitGraph::InEdges handles CSC.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    const EdgeArray& ret = ptr->InEdges(etype, vids);
    return {ret.dst, ret.src, ret.id};
  } else {
    return ptr->OutEdges(etype, vids);
  }
}

Minjie Wang's avatar
Minjie Wang committed
897
/*!
 * \brief All edges, in the requested order.
 *
 * \param etype Edge type, forwarded to the backend.
 * \param order "eid" prefers COO, "srcdst" prefers the out-CSR, and an empty
 *        string accepts any format. Anything else is a fatal error.
 * \return The backend's edges; for the CSC backend the src/dst fields are
 *         exchanged before returning.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  SparseFormat preferred = SparseFormat::kAny;
  if (order.empty()) {
    // arbitrary order
    preferred = SparseFormat::kAny;
  } else if (order == "eid") {
    preferred = SparseFormat::kCOO;
  } else if (order == "srcdst") {
    preferred = SparseFormat::kCSR;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const SparseFormat fmt = SelectFormat(preferred);
  const auto& edges = GetFormat(fmt)->Edges(etype, order);
  if (fmt != SparseFormat::kCSC) {
    return edges;
  }
  return EdgeArray{edges.dst, edges.src, edges.id};
}

Minjie Wang's avatar
Minjie Wang committed
918
/*!
 * \brief In-degree of vertex \a vid.
 *
 * Prefers CSC, where the answer is the backend's OutDegree(); other formats
 * receive InDegree() directly.
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->OutDegree(etype, vid)
                    : backend->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
927
/*!
 * \brief In-degrees of the vertices in \a vids.
 *
 * Prefers CSC, where the answer is the backend's OutDegrees(); other formats
 * receive InDegrees() directly.
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InDegrees(etype, vids);
  }
  return backend->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
936
/*!
 * \brief Out-degree of vertex \a vid.
 *
 * Prefers the out-CSR. When the graph is restricted to CSC the selected
 * structure is the transposed adjacency, whose OutDegree() is this graph's
 * in-degree, so the query is mirrored to InDegree() (the counterpart of what
 * UnitGraph::InDegree does for CSC).
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->InDegree(etype, vid);
  else
    return ptr->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
942
/*!
 * \brief Out-degrees of the vertices in \a vids.
 *
 * Prefers the out-CSR. When the graph is restricted to CSC the selected
 * structure is the transposed adjacency, whose OutDegrees() are this graph's
 * in-degrees, so the query is mirrored to InDegrees() (the counterpart of what
 * UnitGraph::InDegrees does for CSC).
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->InDegrees(etype, vids);
  else
    return ptr->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
948
/*!
 * \brief Iterator range over the successors of \a vid.
 *
 * Prefers the out-CSR. When the graph is restricted to CSC the selected
 * structure is the transposed adjacency, so the query is mirrored to
 * PredVec() (the counterpart of what UnitGraph::PredVec does for CSC).
 */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->PredVec(etype, vid);
  else
    return ptr->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
954
/*!
 * \brief Iterator range over the out-edge ids of \a vid.
 *
 * Prefers the out-CSR. When the graph is restricted to CSC the selected
 * structure is the transposed adjacency, so the query is mirrored to
 * InEdgeVec() (the counterpart of what UnitGraph::InEdgeVec does for CSC).
 */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->InEdgeVec(etype, vid);
  else
    return ptr->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
960
/*!
 * \brief Iterator range over the predecessors of \a vid.
 *
 * Prefers CSC, where the answer is the backend's SuccVec(); other formats
 * receive PredVec() directly.
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->PredVec(etype, vid);
  }
  return backend->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
969
/*!
 * \brief Iterator range over the in-edge ids of \a vid.
 *
 * Prefers CSC, where the answer is the backend's OutEdgeVec(); other formats
 * receive InEdgeVec() directly.
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC) {
    return backend->InEdgeVec(etype, vid);
  }
  return backend->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
978
std::vector<IdArray> UnitGraph::GetAdj(
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(etype, false, "csr")
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(etype, !transpose, fmt);
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
998
/*!
 * \brief Vertex-induced subgraph.
 *
 * The subgraph is extracted from the selected backend (out-CSR preferred) and
 * re-wrapped in a UnitGraph that holds only that one structure.
 * NOTE(review): when the selected format is CSC, \a vids is forwarded to the
 * transposed structure unchanged - confirm the per-type ordering is intended.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  HeteroSubgraph induced = GetFormat(fmt)->VertexSubgraph(vids);
  HeteroSubgraph result;

  // Exactly one of these is filled in, depending on the selected format.
  CSRPtr out_part = nullptr;
  CSRPtr in_part = nullptr;
  COOPtr coo_part = nullptr;
  if (fmt == SparseFormat::kCSR) {
    out_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (fmt == SparseFormat::kCSC) {
    in_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (fmt == SparseFormat::kCOO) {
    coo_part = std::dynamic_pointer_cast<COO>(induced.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), in_part, out_part, coo_part));
  result.induced_vertices = std::move(induced.induced_vertices);
  result.induced_edges = std::move(induced.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1028
/*!
 * \brief Edge-induced subgraph.
 *
 * The subgraph is extracted from the selected backend (COO preferred) and
 * re-wrapped in a UnitGraph that holds only that one structure.
 * NOTE(review): when the selected format is CSC, the induced vertex lists come
 * from the transposed structure - confirm their per-type ordering is intended.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
  auto induced = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph result;

  // Exactly one of these is filled in, depending on the selected format.
  CSRPtr out_part = nullptr;
  CSRPtr in_part = nullptr;
  COOPtr coo_part = nullptr;
  if (fmt == SparseFormat::kCSR) {
    out_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (fmt == SparseFormat::kCSC) {
    in_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (fmt == SparseFormat::kCOO) {
    coo_part = std::dynamic_pointer_cast<COO>(induced.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(fmt);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), in_part, out_part, coo_part));
  result.induced_vertices = std::move(induced.induced_vertices);
  result.induced_edges = std::move(induced.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1058
/*!
 * \brief Build a UnitGraph that initially holds only a COO structure.
 *
 * \param num_vtypes Must be 1 or 2; with 1, num_src must equal num_dst.
 * \param num_src,num_dst Vertex counts passed to the COO structure.
 * \param row,col Index arrays passed to the COO structure.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_part = COOPtr(new COO(metagraph, num_src, num_dst, row, col));
  auto* graph = new UnitGraph(metagraph, nullptr, nullptr, coo_part, restrict_format);
  return HeteroGraphPtr(graph);
}

1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
/*!
 * \brief Build a UnitGraph that initially holds only the given COO matrix.
 *
 * \param num_vtypes Must be 1 or 2; with 1, the matrix must be square.
 * \param mat COO matrix wrapped by the new graph.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_part = COOPtr(new COO(metagraph, mat));
  auto* graph = new UnitGraph(metagraph, nullptr, nullptr, coo_part, restrict_format);
  return HeteroGraphPtr(graph);
}

Minjie Wang's avatar
Minjie Wang committed
1084
1085
/*!
 * \brief Build a UnitGraph that initially holds only an out-CSR structure.
 *
 * \param num_vtypes Must be 1 or 2; with 1, num_src must equal num_dst.
 * \param num_src,num_dst Vertex counts passed to the CSR structure.
 * \param indptr,indices,edge_ids Arrays passed to the CSR structure.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_part = CSRPtr(new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  auto* graph = new UnitGraph(metagraph, nullptr, out_part, nullptr, restrict_format);
  return HeteroGraphPtr(graph);
}

1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
/*!
 * \brief Build a UnitGraph that initially holds only the given matrix as out-CSR.
 *
 * \param num_vtypes Must be 1 or 2; with 1, the matrix must be square.
 * \param mat CSR matrix wrapped by the new graph.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_part = CSRPtr(new CSR(metagraph, mat));
  auto* graph = new UnitGraph(metagraph, nullptr, out_part, nullptr, restrict_format);
  return HeteroGraphPtr(graph);
}

1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
/*!
 * \brief Build a UnitGraph that initially holds only an in-CSR (CSC) structure.
 *
 * The in-CSR is the transposed adjacency: GetInCSR builds it with
 * CSRTranspose, and NumVertices exchanges source and destination types before
 * querying it. Its row count is therefore the number of destination nodes and
 * its column count the number of source nodes, so the two counts are passed to
 * the CSR structure in (num_dst, num_src) order.
 *
 * \param num_vtypes Must be 1 or 2; with 1, num_src must equal num_dst.
 * \param num_src Number of source nodes of the graph.
 * \param num_dst Number of destination nodes of the graph.
 * \param indptr,indices,edge_ids Arrays of the CSC structure.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, restrict_format));
}

/*!
 * \brief Build a UnitGraph that initially holds only the given matrix as in-CSR (CSC).
 *
 * \param num_vtypes Must be 1 or 2; with 1, the matrix must be square.
 * \param mat Matrix wrapped as the in-CSR structure of the new graph.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_part = CSRPtr(new CSR(metagraph, mat));
  auto* graph = new UnitGraph(metagraph, in_part, nullptr, nullptr, restrict_format);
  return HeteroGraphPtr(graph);
}

Minjie Wang's avatar
Minjie Wang committed
1128
/*!
 * \brief Return a graph whose ids use \a bits bits.
 *
 * Returns \a g itself when it already has that width. Otherwise \a g must be a
 * UnitGraph; both CSR structures are materialized and converted, and the
 * result carries no COO structure.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }
  // TODO(minjie): since we don't have int32 operations,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);

  CSRPtr converted_in = CSRPtr(new CSR(unit->GetInCSR()->AsNumBits(bits)));
  CSRPtr converted_out = CSRPtr(new CSR(unit->GetOutCSR()->AsNumBits(bits)));
  auto* graph = new UnitGraph(
      g->meta_graph(), converted_in, converted_out, nullptr, unit->restrict_format_);
  return HeteroGraphPtr(graph);
}

Minjie Wang's avatar
Minjie Wang committed
1146
/*!
 * \brief Return a copy of the graph on context \a ctx.
 *
 * Returns \a g itself when it already lives on \a ctx. Otherwise \a g must be
 * a UnitGraph; both CSR structures are materialized and copied, and the result
 * carries no COO structure.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);

  CSRPtr copied_in = CSRPtr(new CSR(unit->GetInCSR()->CopyTo(ctx)));
  CSRPtr copied_out = CSRPtr(new CSR(unit->GetOutCSR()->CopyTo(ctx)));
  auto* graph = new UnitGraph(
      g->meta_graph(), copied_in, copied_out, nullptr, unit->restrict_format_);
  return HeteroGraphPtr(graph);
}

1163
1164
/*!
 * \brief Construct from any combination of in-CSR, out-CSR and COO structures.
 *
 * At least one structure must be non-null.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     SparseFormat restrict_format)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  // If the graph is hypersparse and in COO format, switch the restricted format to COO.
  // If the graph is given as CSR, the indptr array is already materialized so we don't
  // care about restricting conversion anyway (even if it is hypersparse).
  const bool unrestricted = (restrict_format == SparseFormat::kAny);
  const bool hypersparse_coo = unrestricted && coo && coo->IsHypersparse();
  restrict_format_ = hypersparse_coo ? SparseFormat::kCOO : restrict_format;

  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
/*!
 * \brief Build a single-node-type UnitGraph from any subset of the three formats.
 *
 * Uses the cached one-node-type metagraph from CreateUnitGraphMetaGraph(1),
 * like the other Create* factories, instead of allocating a new metagraph on
 * every call.
 *
 * \param in_csr,out_csr,coo Matrices to wrap; each is only read when its
 *        matching has_* flag is true.
 * \param has_in_csr,has_out_csr,has_coo Which of the matrices are present.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 */
HeteroGraphPtr UnitGraph::CreateHomographFrom(
    const aten::CSRMatrix &in_csr,
    const aten::CSRMatrix &out_csr,
    const aten::COOMatrix &coo,
    bool has_in_csr,
    bool has_out_csr,
    bool has_coo,
    SparseFormat restrict_format) {
  auto mg = CreateUnitGraphMetaGraph(1);

  CSRPtr in_csr_ptr = nullptr;
  CSRPtr out_csr_ptr = nullptr;
  COOPtr coo_ptr = nullptr;

  if (has_in_csr)
    in_csr_ptr = CSRPtr(new CSR(mg, in_csr));
  if (has_out_csr)
    out_csr_ptr = CSRPtr(new CSR(mg, out_csr));
  if (has_coo)
    coo_ptr = COOPtr(new COO(mg, coo));

  return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format));
}

Minjie Wang's avatar
Minjie Wang committed
1203
/*!
 * \brief Return in csr. If not exist, transpose the out csr or convert the coo.
 *
 * The result is cached in in_csr_ through a const_cast, so later calls return
 * the same structure.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR() const {
  if (!in_csr_) {
    if (out_csr_) {
      const auto& newadj = aten::CSRTranspose(out_csr_->adj());
      const_cast<UnitGraph*>(this)->in_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      // Transpose with COOTranspose (as GetCOO does) rather than rebuilding the
      // matrix from row/col by hand, which would drop the COO's data array.
      const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj()));
      const_cast<UnitGraph*>(this)->in_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
    }
  }
  return in_csr_;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
Minjie Wang's avatar
Minjie Wang committed
1220
UnitGraph::CSRPtr UnitGraph::GetOutCSR() const {
1221
1222
1223
  if (!out_csr_) {
    if (in_csr_) {
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
Minjie Wang's avatar
Minjie Wang committed
1224
      const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
1225
1226
1227
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      const auto& newadj = aten::COOToCSR(coo_->adj());
Minjie Wang's avatar
Minjie Wang committed
1228
      const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
1229
1230
1231
1232
1233
1234
    }
  }
  return out_csr_;
}

/* !\brief Return coo. If not exist, create from csr.*/
Minjie Wang's avatar
Minjie Wang committed
1235
UnitGraph::COOPtr UnitGraph::GetCOO() const {
1236
1237
  if (!coo_) {
    if (in_csr_) {
1238
1239
      const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
      const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
1240
1241
1242
    } else {
      CHECK(out_csr_) << "Both CSR are missing.";
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
Minjie Wang's avatar
Minjie Wang committed
1243
      const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
1244
1245
1246
1247
1248
    }
  }
  return coo_;
}

1249
/*! \brief Adjacency of the in-CSR structure, materializing it if needed; \a etype is unused. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr in_part = GetInCSR();
  return in_part->adj();
}

1253
/*! \brief Adjacency of the out-CSR structure, materializing it if needed; \a etype is unused. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr out_part = GetOutCSR();
  return out_part->adj();
}

1257
/*! \brief Adjacency of the COO structure, materializing it if needed; \a etype is unused. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo_part = GetCOO();
  return coo_part->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1261
/*!
 * \brief Return a structure that already exists, without converting anything.
 *
 * Order of preference: in-CSR, out-CSR, COO. The result is null only when all
 * three are missing.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_) {
    return in_csr_;
  }
  if (out_csr_) {
    return out_csr_;
  }
  return coo_;
}

1271
1272
/*!
 * \brief Return the structure for \a format.
 *
 * kCSR, kCSC and kCOO go through the Get* accessors, which create the
 * structure if it does not exist yet; kAny returns an existing one. Any other
 * code is a fatal error.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR) {
    return GetOutCSR();
  }
  if (format == SparseFormat::kCSC) {
    return GetInCSR();
  }
  if (format == SparseFormat::kCOO) {
    return GetCOO();
  }
  if (format == SparseFormat::kAny) {
    return GetAny();
  }
  LOG(FATAL) << "unsupported format code";
  return nullptr;
}

/*!
 * \brief Decide which sparse format an operation should use.
 *
 * Precedence: a format restriction on the graph, then the caller's preference,
 * then whichever structure already exists (in-CSR, out-CSR, otherwise COO).
 */
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
  if (restrict_format_ != SparseFormat::kAny) {
    return restrict_format_;
  }
  if (preferred_format != SparseFormat::kAny) {
    return preferred_format;
  }
  if (in_csr_) {
    return SparseFormat::kCSC;
  }
  if (out_csr_) {
    return SparseFormat::kCSR;
  }
  return SparseFormat::kCOO;
}

1300
1301
1302
1303
1304
1305
/*!
 * \brief Convert a graph with a single vertex type into an ImmutableGraph.
 *
 * Only the structures that already exist are carried over. A COO matrix with a
 * data array has its row/col arrays passed through Scatter with that array
 * before the plain COO structure is built.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
  dgl::CSRPtr imm_in = nullptr;
  dgl::CSRPtr imm_out = nullptr;
  dgl::COOPtr imm_coo = nullptr;

  if (in_csr_) {
    aten::CSRMatrix csc = GetCSCMatrix(0);
    imm_in = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }
  if (out_csr_) {
    aten::CSRMatrix csr = GetCSRMatrix(0);
    imm_out = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }
  if (coo_) {
    aten::COOMatrix coo = GetCOOMatrix(0);
    IdArray src = coo.row;
    IdArray dst = coo.col;
    if (COOHasData(coo)) {
      src = Scatter(coo.row, coo.data);
      dst = Scatter(coo.col, coo.data);
    }
    imm_coo = dgl::COOPtr(new dgl::COO(NumVertices(0), src, dst));
  }

  auto* graph = new dgl::ImmutableGraph(imm_in, imm_out, imm_coo);
  return GraphPtr(graph);
}

1325
1326
1327
1328
1329
1330
// Magic number that UnitGraph::Save writes first and UnitGraph::Load checks
// first, so that a stream not produced by Save is rejected.
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Restore the graph from a stream written by UnitGraph::Save.
 *
 * Reads the magic number, the format code and then the single structure stored
 * in that format. The stored format code becomes restrict_format_, and the
 * metagraph is taken from the loaded structure. Every read is checked, so a
 * truncated or corrupt stream fails with a message instead of dereferencing a
 * missing structure.
 *
 * \param fs Stream to read from.
 * \return true on success; failures are fatal.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t format_code = 0;
  CHECK(fs->Read(&format_code)) << "Invalid format";
  restrict_format_ = static_cast<SparseFormat>(format_code);

  switch (restrict_format_) {
    case SparseFormat::kCOO:
      CHECK(fs->Read(&coo_)) << "Invalid COO data";
      break;
    case SparseFormat::kCSR:
      CHECK(fs->Read(&out_csr_)) << "Invalid CSR data";
      break;
    case SparseFormat::kCSC:
      CHECK(fs->Read(&in_csr_)) << "Invalid CSC data";
      break;
    default:
      LOG(FATAL) << "unsupported format code";
      break;
  }

  // Save does not store the metagraph separately; recover it from the
  // structure that was just read.
  const HeteroGraphPtr loaded = GetAny();
  CHECK(loaded) << "No graph structure was loaded";
  meta_graph_ = loaded->meta_graph();

  return true;
}

1356

1357
1358
/*!
 * \brief Serialize the graph: magic number, format code, then one structure.
 *
 * The format written is the one SelectFormat picks with no preference, i.e.
 * the restricted format or a structure that already exists.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);
  // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
  // sparse matrix
  const auto stored_fmt = SelectFormat(SparseFormat::kAny);
  fs->Write(static_cast<int64_t>(stored_fmt));
  if (stored_fmt == SparseFormat::kCOO) {
    fs->Write(GetCOO());
  } else if (stored_fmt == SparseFormat::kCSR) {
    fs->Write(GetOutCSR());
  } else if (stored_fmt == SparseFormat::kCSC) {
    fs->Write(GetInCSR());
  } else {
    LOG(FATAL) << "unsupported format code";
  }
}

1379
}  // namespace dgl