unit_graph.cc 45.5 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <dgl/array.h>
7
#include <dgl/base_heterograph.h>
8
9
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
10
11

#include "../c_api_common.h"
12
#include "./unit_graph.h"
13
14

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
/*!
 * \brief Build the metagraph used when there is a single node type.
 *
 * The metagraph has one vertex and one self-loop edge 0->0.
 */
inline GraphPtr CreateUnitGraphMetaGraph1() {
  const std::vector<int64_t> src_ids{0};
  const std::vector<int64_t> dst_ids{0};
  return ImmutableGraph::CreateFromCOO(
      1, aten::VecToIdArray(src_ids), aten::VecToIdArray(dst_ids));
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
/*!
 * \brief Build the metagraph used when there are two node types.
 *
 * The metagraph has two vertices and a single edge 0->1.
 */
inline GraphPtr CreateUnitGraphMetaGraph2() {
  // an edge 0->1
  std::vector<int64_t> row_vec(1, 0);
  std::vector<int64_t> col_vec(1, 1);
  IdArray row = aten::VecToIdArray(row_vec);
  IdArray col = aten::VecToIdArray(col_vec);
  GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
  return g;
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

/*!
 * \brief Return the shared metagraph for a unit graph.
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \return The cached one-type or two-type metagraph. Any other value is
 *         reported through LOG(FATAL).
 */
inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  // Both metagraphs are built once, on the first call, and then reused.
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  switch (num_vtypes) {
    case 1:
      return mg1;
    case 2:
      return mg2;
    default:
      LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  }
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
62
/*!
 * \brief COO-format backend of UnitGraph.
 *
 * Stores the graph as an aten::COOMatrix whose rows are source vertices and
 * whose columns are destination vertices. The graph is read-only: every
 * mutating method reports LOG(FATAL).
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from source/destination id arrays.
   * \param metagraph Metagraph passed to BaseHeteroGraph.
   * \param num_src Number of source vertices (matrix rows).
   * \param num_dst Number of destination vertices (matrix columns).
   * \param src Source vertex id of each edge.
   * \param dst Destination vertex id of each edge; same length as src.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, IdArray dst)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
  }

  /*!
   * \brief Construct from an existing COO matrix that carries no data array.
   */
  COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {
    // Data index should not be inherited. Edges in COO format are always
    // assigned ids from 0 to num_edges - 1.
    CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
    adj_.data = aten::NullArray();
  }

  /*! \brief Type id of source vertices (always 0). */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Type id of destination vertices: 0 with one vertex type, else 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief Type id of the only edge type (always 0). */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  // Mutation is not supported; each method below only reports a fatal error.

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Data type of the row array. */
  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  /*! \brief Device context of the row array. */
  DLContext Context() const override {
    return adj_.row->ctx;
  }

  /*! \brief Bit width of the row array's data type. */
  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a graph whose id arrays use the given bit width.
   *
   * Returns a copy of this object when the width already matches; otherwise
   * row and col are converted with aten::AsNumBits.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits),
        aten::AsNumBits(adj_.col, bits));
    return ret;
  }

  /*!
   * \brief Return a graph whose arrays live on the given context.
   *
   * Returns a copy of this object when the context already matches.
   */
  COO CopyTo(const DLContext& ctx) const {
    if (Context() == ctx)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        adj_.row.CopyTo(ctx),
        adj_.col.CopyTo(ctx));
    return ret;
  }

  /*! \brief True when aten::COOHasDuplicate reports a duplicate entry. */
  bool IsMultigraph() const override {
    return aten::COOHasDuplicate(adj_);
  }

  bool IsReadonly() const override {
    return true;
  }

  /*!
   * \brief Number of vertices of the given type.
   *
   * Source type maps to the row count and destination type to the column
   * count; any other type is a fatal error.
   */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /*! \brief Number of edges, i.e. the length of the row array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  /*! \brief Predecessors of dst, read from the row of the transposed matrix. */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  /*! \brief Successors of src, read from the matrix row of src. */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetData(adj_, src, dst);
  }

  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  /*! \brief Return the (src, dst) pair stored at position eid of row/col. */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  /*! \brief In-edges of one vertex, computed on the transposed matrix. */
  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    // The destination column is vid repeated once per returned edge.
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  /*! \brief In-edges of several vertices, computed on the transposed matrix. */
  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Map the sliced row indices back to vertex ids through vids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  /*! \brief Out-edges of one vertex. */
  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    // The source column is vid repeated once per returned edge.
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  /*! \brief Out-edges of several vertices. */
  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Map the sliced row indices back to vertex ids through vids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /*!
   * \brief All edges in edge-id order.
   * \param order Must be empty or "eid".
   */
  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // The iterator-based accessors are unsupported on COO. They fail with
  // LOG(FATAL), like every other unsupported operation in this class,
  // instead of logging at INFO level and returning empty iterators.

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*!
   * \brief Adjacency in "coo" format as one stacked array.
   * \param transpose When true, col is stacked before row.
   * \param fmt Must be "coo".
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return aten::CSRMatrix();
  }

  SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return SparseFormat::kAny;
  }

  /*!
   * \brief Subgraph induced by the given vertices.
   * \param vids One id array per vertex type.
   * \return Subgraph built from aten::COOSliceMatrix; induced_vertices is
   *         vids and induced_edges holds the data array of the sliced matrix.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Subgraph induced by the given edges.
   * \param eids Exactly one edge id array.
   * \param preserve_nodes When true the subgraph keeps every vertex of this
   *        graph; otherwise only the endpoints of the selected edges.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
    if (!preserve_nodes) {
      // NOTE(review): the trailing underscore suggests aten::Relabel_ also
      // rewrites new_src/new_dst in place -- confirm against its declaration.
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
    } else {
      // Query by SrcType()/DstType() rather than literal 0/1: with a single
      // vertex type, type 1 is invalid and NumVertices(1) is a fatal error.
      const auto num_src = NumVertices(SrcType());
      const auto num_dst = NumVertices(DstType());
      subg.induced_vertices.emplace_back(aten::Range(0, num_src, NumBits(), Context()));
      subg.induced_vertices.emplace_back(aten::Range(0, num_dst, NumBits(), Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), num_src, num_dst, new_src, new_dst);
    }
    subg.induced_edges = eids;
    return subg;
  }

  /*! \brief Return the internal adjacency matrix. */
  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

  /*! \brief Read the metagraph and then the adjacency matrix from fs. */
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }

  /*! \brief Write the metagraph and then the adjacency matrix to fs. */
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }

 private:
  friend class Serializer;

  // Default construction is private; Serializer is a friend of this class.
  COO() {}

  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
420
class UnitGraph::CSR : public BaseHeteroGraph {
421
 public:
Minjie Wang's avatar
Minjie Wang committed
422
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
423
      IdArray indptr, IdArray indices, IdArray edge_ids)
Minjie Wang's avatar
Minjie Wang committed
424
    : BaseHeteroGraph(metagraph) {
425
426
427
428
429
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
    CHECK(aten::IsValidIdArray(edge_ids));
    CHECK_EQ(indices->shape[0], edge_ids->shape[0])
      << "indices and edge id arrays should have the same length";
430

431
432
433
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
  }

434
  CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
Da Zheng's avatar
Da Zheng committed
435
436
    : BaseHeteroGraph(metagraph), adj_(csr) {
  }
437

Minjie Wang's avatar
Minjie Wang committed
438
439
  inline dgl_type_t SrcType() const {
    return 0;
440
  }
Minjie Wang's avatar
Minjie Wang committed
441
442
443
444
445
446
447

  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
448
449
450
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
451
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
452
453
454
455
456
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
457
    LOG(FATAL) << "UnitGraph graph is not mutable.";
458
459
460
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
461
    LOG(FATAL) << "UnitGraph graph is not mutable.";
462
463
464
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
465
    LOG(FATAL) << "UnitGraph graph is not mutable.";
466
467
468
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
469
    LOG(FATAL) << "UnitGraph graph is not mutable.";
470
471
  }

472
473
474
475
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

476
477
478
479
480
481
482
483
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

484
485
486
487
488
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
489
          meta_graph_,
490
491
492
493
494
495
496
497
498
499
500
501
502
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      return ret;
    }
  }

  CSR CopyTo(const DLContext& ctx) const {
    if (Context() == ctx) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
503
          meta_graph_,
504
505
506
507
508
509
510
511
          adj_.num_rows, adj_.num_cols,
          adj_.indptr.CopyTo(ctx),
          adj_.indices.CopyTo(ctx),
          adj_.data.CopyTo(ctx));
      return ret;
    }
  }

512
  bool IsMultigraph() const override {
513
    return aten::CSRHasDuplicate(adj_);
514
515
516
517
518
519
520
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
521
    if (vtype == SrcType()) {
522
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
523
    } else if (vtype == DstType()) {
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
545
546
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
547
548
549
550
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
551
552
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
553
554
555
556
557
558
559
560
561
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
562
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
563
564
565
566
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
567
568
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
569
570
571
572
    return aten::CSRGetData(adj_, src, dst);
  }

  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
573
574
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
575
576
577
578
579
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
580
    LOG(FATAL) << "Not enabled for CSR graph.";
581
582
583
584
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
585
    LOG(FATAL) << "Not enabled for CSR graph.";
586
587
588
589
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
590
    LOG(FATAL) << "Not enabled for CSR graph.";
591
592
593
594
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
595
    LOG(FATAL) << "Not enabled for CSR graph.";
596
597
598
599
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
600
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
601
602
603
604
605
606
607
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
608
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
    const auto& coo = aten::CSRToCOO(adj_, false);
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
626
    LOG(FATAL) << "Not enabled for CSR graph.";
627
628
629
630
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
631
    LOG(FATAL) << "Not enabled for CSR graph.";
632
633
634
635
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
636
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
637
638
639
640
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
641
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
642
643
644
645
646
647
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
648
    CHECK_EQ(NumBits(), 64);
649
650
651
652
653
654
655
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

656
657
658
659
660
661
662
663
664
665
  DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
    const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
    const int32_t start = indptr_data[vid];
    const int32_t end = indptr_data[vid + 1];
    return DGLIdIters32(indices_data + start, indices_data + end);
  }

666
667
668
  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
669
    CHECK_EQ(NumBits(), 64);
670
671
672
673
674
675
676
677
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
678
    LOG(FATAL) << "Not enabled for CSR graph.";
679
680
681
682
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
683
    LOG(FATAL) << "Not enabled for CSR graph.";
684
685
686
687
688
689
690
691
692
    return {};
  }

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
  aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::COOMatrix();
  }

  aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
    return aten::CSRMatrix();
  }

  aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
    return adj_;
  }

  SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
    LOG(FATAL) << "Not enabled for CSR graph";
709
    return SparseFormat::kAny;
710
711
  }

712
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
Minjie Wang's avatar
Minjie Wang committed
713
714
715
716
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
717
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
718
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
719
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
Minjie Wang's avatar
Minjie Wang committed
720
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
721
722
723
724
725
726
727
728
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
729
    LOG(FATAL) << "Not enabled for CSR graph.";
730
731
732
733
734
735
736
    return {};
  }

  aten::CSRMatrix adj() const {
    return adj_;
  }

737
738
739
740
741
742
743
744
745
746
747
748
749
750
  bool Load(dmlc::Stream* fs) {
    auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
    CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
    meta_graph_ = meta_imgraph;
    CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
    return true;
  }
  void Save(dmlc::Stream* fs) const {
    auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
    fs->Write(meta_graph_ptr);
    fs->Write(adj_);
  }


751
 private:
752
753
754
755
  friend class Serializer;

  CSR() {};

756
757
758
759
760
761
  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;
};

//////////////////////////////////////////////////////////
//
Minjie Wang's avatar
Minjie Wang committed
762
// unit graph implementation
763
764
765
//
//////////////////////////////////////////////////////////

766
767
768
769
/*! \brief Data type reported by whichever format GetAny() returns. */
DLDataType UnitGraph::DataType() const {
  const auto graph = GetAny();
  return graph->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
770
/*! \brief Device context reported by whichever format GetAny() returns. */
DLContext UnitGraph::Context() const {
  return GetAny()->Context();
}

Minjie Wang's avatar
Minjie Wang committed
774
uint8_t UnitGraph::NumBits() const {
775
776
777
  return GetAny()->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
778
bool UnitGraph::IsMultigraph() const {
779
780
781
  return GetAny()->IsMultigraph();
}

Minjie Wang's avatar
Minjie Wang committed
782
/*!
 * \brief Number of vertices of the given type.
 *
 * Delegates to the format chosen by SelectFormat(SparseFormat::kAny). When
 * that format is CSC, source and destination types are swapped before
 * delegating.
 */
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
  const auto ptr = GetFormat(fmt);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  if (fmt == SparseFormat::kCSC)
    vtype = (vtype == SrcType()) ? DstType() : SrcType();
  return ptr->NumVertices(vtype);
}

Minjie Wang's avatar
Minjie Wang committed
792
/*! \brief Edge count reported by whichever format GetAny() returns. */
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  return GetAny()->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
796
/*!
 * \brief Whether vertex vid of type vtype exists.
 *
 * Delegates to the format chosen by SelectFormat(SparseFormat::kAny). When
 * that format is CSC, source and destination types are swapped before
 * delegating.
 */
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    vtype = (vtype == SrcType()) ? DstType() : SrcType();
  return ptr->HasVertex(vtype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
804
/*!
 * \brief Element-wise existence test for an array of vertex ids.
 * \return The result of aten::LT(vids, NumVertices(vtype)).
 */
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  return aten::LT(vids, NumVertices(vtype));
}

Minjie Wang's avatar
Minjie Wang committed
809
/*!
 * \brief Check whether an edge src->dst of the given type exists.
 *
 * Uses the format picked by SelectFormat(kAny); for CSC the endpoints are
 * passed to the backend in reversed order.
 */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kAny);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->HasEdgeBetween(etype, dst, src)
                    : backend->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
818
/*!
 * \brief Batched version of HasEdgeBetween.
 *
 * Uses the format picked by SelectFormat(kAny); for CSC the endpoint arrays
 * are passed to the backend in reversed order.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kAny);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->HasEdgesBetween(etype, dst, src)
                    : backend->HasEdgesBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
828
/*!
 * \brief Return the predecessors of \a dst.
 *
 * CSC is the preferred format. On CSC the answer is obtained by asking the
 * backend for successors; on any other format the backend's own
 * Predecessors() is used.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->Successors(etype, dst)
                    : backend->Predecessors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
837
/*!
 * \brief Return the successors of \a src.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the successors of this graph are the
 * backend's predecessors. Forwarding Successors() unchanged would return the
 * predecessors of \a src instead.
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->Predecessors(etype, src);
  else
    return ptr->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
843
/*!
 * \brief Return the ids of the edges going from \a src to \a dst.
 *
 * CSR is the preferred format; when the selected format is CSC the endpoints
 * are passed to the backend in reversed order.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSR);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->EdgeId(etype, dst, src)
                    : backend->EdgeId(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
852
/*!
 * \brief Batched version of EdgeId.
 *
 * CSR is the preferred format. When the selected format is CSC, the query is
 * issued with reversed endpoints and the src/dst arrays of the result are
 * exchanged back.
 */
EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSR);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return backend->EdgeIds(etype, src, dst);
  EdgeArray flipped = backend->EdgeIds(etype, dst, src);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
863
/*!
 * \brief Return the (src, dst) endpoints of edge \a eid.
 *
 * COO is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so a pair reported by that backend is (dst, src) from
 * this graph's point of view and is exchanged before being returned.
 */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
  const auto ptr = GetFormat(fmt);
  const std::pair<dgl_id_t, dgl_id_t> endpoints = ptr->FindEdge(etype, eid);
  if (fmt == SparseFormat::kCSC)
    return {endpoints.second, endpoints.first};
  else
    return endpoints;
}

Minjie Wang's avatar
Minjie Wang committed
869
/*!
 * \brief Return the endpoints of every edge listed in \a eids.
 *
 * COO is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so the src/dst arrays reported by that backend are
 * exchanged before being returned (the same treatment EdgeIds() applies).
 */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    EdgeArray edges = ptr->FindEdges(etype, eids);
    return EdgeArray{edges.dst, edges.src, edges.id};
  } else {
    return ptr->FindEdges(etype, eids);
  }
}

Minjie Wang's avatar
Minjie Wang committed
875
/*!
 * \brief Return the edges entering vertex \a vid.
 *
 * CSC is the preferred format. On CSC the backend is asked for out-edges and
 * the src/dst arrays of the answer are exchanged; on any other format the
 * backend's own InEdges() is used.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return backend->InEdges(etype, vid);
  const EdgeArray transposed = backend->OutEdges(etype, vid);
  return {transposed.dst, transposed.src, transposed.id};
}

Minjie Wang's avatar
Minjie Wang committed
886
/*!
 * \brief Return the edges entering any vertex in \a vids.
 *
 * CSC is the preferred format. On CSC the backend is asked for out-edges and
 * the src/dst arrays of the answer are exchanged; on any other format the
 * backend's own InEdges() is used.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return backend->InEdges(etype, vids);
  const EdgeArray transposed = backend->OutEdges(etype, vids);
  return {transposed.dst, transposed.src, transposed.id};
}

Minjie Wang's avatar
Minjie Wang committed
897
/*!
 * \brief Return the edges leaving vertex \a vid.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the out-edges of this graph are the
 * backend's in-edges with src/dst exchanged (the mirror image of InEdges()).
 * Forwarding OutEdges() unchanged would return in-edges instead.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    const EdgeArray& ret = ptr->InEdges(etype, vid);
    return {ret.dst, ret.src, ret.id};
  } else {
    return ptr->OutEdges(etype, vid);
  }
}

Minjie Wang's avatar
Minjie Wang committed
903
/*!
 * \brief Return the edges leaving any vertex in \a vids.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the out-edges of this graph are the
 * backend's in-edges with src/dst exchanged (the mirror image of InEdges()).
 * Forwarding OutEdges() unchanged would return in-edges instead.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC) {
    const EdgeArray& ret = ptr->InEdges(etype, vids);
    return {ret.dst, ret.src, ret.id};
  } else {
    return ptr->OutEdges(etype, vids);
  }
}

Minjie Wang's avatar
Minjie Wang committed
909
/*!
 * \brief Return all edges of the given type.
 *
 * \param etype Edge type.
 * \param order "eid" prefers COO, "srcdst" prefers CSR, and an empty string
 *        accepts any format. Any other value is a fatal error.
 * \return The edges; when the selected format is CSC the src/dst arrays of
 *         the backend's answer are exchanged.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  SparseFormat preferred;
  if (order.empty()) {
    // arbitrary order
    preferred = SparseFormat::kAny;
  } else if (order == "eid") {
    preferred = SparseFormat::kCOO;
  } else if (order == "srcdst") {
    preferred = SparseFormat::kCSR;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }

  const SparseFormat chosen = SelectFormat(preferred);
  const auto& listed = GetFormat(chosen)->Edges(etype, order);
  if (chosen != SparseFormat::kCSC)
    return listed;
  return EdgeArray{listed.dst, listed.src, listed.id};
}

Minjie Wang's avatar
Minjie Wang committed
930
/*!
 * \brief Return the in-degree of vertex \a vid.
 *
 * CSC is the preferred format. On CSC the backend's OutDegree() gives the
 * answer; on any other format the backend's own InDegree() is used.
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->OutDegree(etype, vid)
                    : backend->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
939
/*!
 * \brief Return the in-degrees of the vertices in \a vids.
 *
 * CSC is the preferred format. On CSC the backend's OutDegrees() gives the
 * answer; on any other format the backend's own InDegrees() is used.
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  const bool transposed = (chosen == SparseFormat::kCSC);
  return transposed ? backend->OutDegrees(etype, vids)
                    : backend->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
948
/*!
 * \brief Return the out-degree of vertex \a vid.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the out-degree of this graph is the
 * backend's in-degree (the mirror image of InDegree()). Forwarding
 * OutDegree() unchanged would return the in-degree instead.
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->InDegree(etype, vid);
  else
    return ptr->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
954
/*!
 * \brief Return the out-degrees of the vertices in \a vids.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the out-degrees of this graph are the
 * backend's in-degrees (the mirror image of InDegrees()). Forwarding
 * OutDegrees() unchanged would return in-degrees instead.
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->InDegrees(etype, vids);
  else
    return ptr->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
960
/*!
 * \brief Return an iterator range over the successors of \a vid.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the successors of this graph are the
 * backend's predecessors (the mirror image of PredVec()). Forwarding
 * SuccVec() unchanged would iterate over predecessors instead.
 */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->PredVec(etype, vid);
  else
    return ptr->SuccVec(etype, vid);
}

966
967
968
969
970
971
972
/*!
 * \brief 32-bit variant of SuccVec; only available on a CSR-typed backend.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. A CSC backend is also of type CSR, so the cast
 * below would succeed on it; because GetInCSR() builds that structure from
 * the transposed adjacency, the result would silently be the predecessors.
 * That case is therefore rejected explicitly.
 */
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  CHECK(fmt != SparseFormat::kCSC)
    << "SuccVec32 cannot be answered from the CSC format.";
  const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));
  // The cast yields null when the selected backend is not a CSR (e.g. COO).
  CHECK_NOTNULL(ptr);
  return ptr->SuccVec32(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
973
/*!
 * \brief Return an iterator range over the ids of the edges leaving \a vid.
 *
 * CSR is the preferred format, but SelectFormat() returns the restricted
 * format when one is set. GetInCSR() builds the CSC structure from the
 * transposed adjacency, so on CSC the out-edges of this graph are the
 * backend's in-edges (the mirror image of InEdgeVec()). Forwarding
 * OutEdgeVec() unchanged would iterate over in-edge ids instead.
 */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::kCSC)
    return ptr->InEdgeVec(etype, vid);
  else
    return ptr->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
979
/*!
 * \brief Return an iterator range over the predecessors of \a vid.
 *
 * CSC is the preferred format. On CSC the backend's SuccVec() gives the
 * answer; on any other format the backend's own PredVec() is used.
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return backend->PredVec(etype, vid);
  return backend->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
988
/*!
 * \brief Return an iterator range over the ids of the edges entering \a vid.
 *
 * CSC is the preferred format. On CSC the backend's OutEdgeVec() gives the
 * answer; on any other format the backend's own InEdgeVec() is used.
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSC);
  const auto backend = GetFormat(chosen);
  if (chosen != SparseFormat::kCSC)
    return backend->InEdgeVec(etype, vid);
  return backend->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
997
std::vector<IdArray> UnitGraph::GetAdj(
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(etype, false, "csr")
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(etype, !transpose, fmt);
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
1017
/*!
 * \brief Build the subgraph induced by the given vertices.
 *
 * The backend chosen by SelectFormat(kCSR) computes the subgraph; its result
 * is wrapped in a new UnitGraph that holds only that one sparse format.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat chosen = SelectFormat(SparseFormat::kCSR);
  HeteroSubgraph induced = GetFormat(chosen)->VertexSubgraph(vids);
  HeteroSubgraph result;

  // Exactly one of these is filled in, matching the format that was used.
  CSRPtr in_part = nullptr;
  CSRPtr out_part = nullptr;
  COOPtr coo_part = nullptr;
  if (chosen == SparseFormat::kCSR) {
    out_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (chosen == SparseFormat::kCSC) {
    in_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (chosen == SparseFormat::kCOO) {
    coo_part = std::dynamic_pointer_cast<COO>(induced.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(chosen);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), in_part, out_part, coo_part));
  result.induced_vertices = std::move(induced.induced_vertices);
  result.induced_edges = std::move(induced.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1047
/*!
 * \brief Build the subgraph induced by the given edges.
 *
 * The backend chosen by SelectFormat(kCOO) computes the subgraph; its result
 * is wrapped in a new UnitGraph that holds only that one sparse format.
 * \a preserve_nodes is forwarded to the backend unchanged.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::kCOO);
  auto induced = GetFormat(chosen)->EdgeSubgraph(eids, preserve_nodes);
  HeteroSubgraph result;

  // Exactly one of these is filled in, matching the format that was used.
  CSRPtr in_part = nullptr;
  CSRPtr out_part = nullptr;
  COOPtr coo_part = nullptr;
  if (chosen == SparseFormat::kCSR) {
    out_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (chosen == SparseFormat::kCSC) {
    in_part = std::dynamic_pointer_cast<CSR>(induced.graph);
  } else if (chosen == SparseFormat::kCOO) {
    coo_part = std::dynamic_pointer_cast<COO>(induced.graph);
  } else {
    LOG(FATAL) << "[BUG] unsupported format " << static_cast<int>(chosen);
    return result;
  }

  result.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), in_part, out_part, coo_part));
  result.induced_vertices = std::move(induced.induced_vertices);
  result.induced_edges = std::move(induced.induced_edges);
  return result;
}

Minjie Wang's avatar
Minjie Wang committed
1077
/*!
 * \brief Create a unit graph from COO row/col arrays.
 *
 * \a num_vtypes must be 1 or 2; with a single vertex type \a num_src must
 * equal \a num_dst. The resulting graph holds only the COO structure.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_graph(new COO(metagraph, num_src, num_dst, row, col));
  auto* unit = new UnitGraph(metagraph, nullptr, nullptr, coo_graph, restrict_format);
  return HeteroGraphPtr(unit);
}

1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
/*!
 * \brief Create a unit graph from a COO matrix.
 *
 * \a num_vtypes must be 1 or 2; with a single vertex type the matrix must be
 * square. The resulting graph holds only the COO structure.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_graph(new COO(metagraph, mat));
  auto* unit = new UnitGraph(metagraph, nullptr, nullptr, coo_graph, restrict_format);
  return HeteroGraphPtr(unit);
}

Minjie Wang's avatar
Minjie Wang committed
1103
1104
/*!
 * \brief Create a unit graph from CSR arrays.
 *
 * \a num_vtypes must be 1 or 2; with a single vertex type \a num_src must
 * equal \a num_dst. The arrays are stored as the out-CSR of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_graph(new CSR(metagraph, num_src, num_dst, indptr, indices, edge_ids));
  auto* unit = new UnitGraph(metagraph, nullptr, out_graph, nullptr, restrict_format);
  return HeteroGraphPtr(unit);
}

1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
/*!
 * \brief Create a unit graph from a CSR matrix.
 *
 * \a num_vtypes must be 1 or 2; with a single vertex type the matrix must be
 * square. The matrix is stored as the out-CSR of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_graph(new CSR(metagraph, mat));
  auto* unit = new UnitGraph(metagraph, nullptr, out_graph, nullptr, restrict_format);
  return HeteroGraphPtr(unit);
}

1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
/*!
 * \brief Create a unit graph from CSC arrays.
 *
 * \a num_vtypes must be 1 or 2; with a single vertex type \a num_src must
 * equal \a num_dst. The arrays are stored as the in-CSR of the new graph.
 *
 * The in-CSR is the transposed structure: GetInCSR() builds it with
 * adj.num_cols rows and adj.num_rows columns, and NumVertices()/HasVertex()
 * exchange the vertex types before querying it. The sizes are therefore
 * handed to the CSR constructor as (num_dst, num_src). Passing them in
 * (num_src, num_dst) order reports swapped vertex counts whenever the two
 * differ.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, restrict_format));
}

/*!
 * \brief Create a unit graph from a matrix holding the CSC structure.
 *
 * \a num_vtypes must be 1 or 2; with a single vertex type the matrix must be
 * square. The matrix is stored as the in-CSR of the new graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_graph(new CSR(metagraph, mat));
  auto* unit = new UnitGraph(metagraph, in_graph, nullptr, nullptr, restrict_format);
  return HeteroGraphPtr(unit);
}

Minjie Wang's avatar
Minjie Wang committed
1147
/*!
 * \brief Return a graph whose ids use \a bits bits.
 *
 * \a g is returned unchanged when it already has that width. Otherwise \a g
 * must be a UnitGraph; both of its CSR structures are obtained (created if
 * missing) and converted, and the new graph keeps the restricted format of
 * \a g but carries no COO.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;

  // TODO(minjie): since we don't have int32 operations,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);

  CSRPtr converted_in(new CSR(unit->GetInCSR()->AsNumBits(bits)));
  CSRPtr converted_out(new CSR(unit->GetOutCSR()->AsNumBits(bits)));
  auto* result = new UnitGraph(
      g->meta_graph(), converted_in, converted_out, nullptr, unit->restrict_format_);
  return HeteroGraphPtr(result);
}

Minjie Wang's avatar
Minjie Wang committed
1165
/*!
 * \brief Return a copy of \a g on the device context \a ctx.
 *
 * \a g is returned unchanged when it already lives on \a ctx. Otherwise \a g
 * must be a UnitGraph; both of its CSR structures are obtained (created if
 * missing) and copied, and the new graph keeps the restricted format of \a g
 * but carries no COO.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context())
    return g;

  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  auto unit = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit);

  CSRPtr copied_in(new CSR(unit->GetInCSR()->CopyTo(ctx)));
  CSRPtr copied_out(new CSR(unit->GetOutCSR()->CopyTo(ctx)));
  auto* result = new UnitGraph(
      g->meta_graph(), copied_in, copied_out, nullptr, unit->restrict_format_);
  return HeteroGraphPtr(result);
}

1182
1183
/*!
 * \brief Construct a unit graph from any subset of the three sparse formats.
 *
 * At least one of \a in_csr, \a out_csr and \a coo must be non-null.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     SparseFormat restrict_format)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  // If the graph is hypersparse and in COO format, switch the restricted format to COO.
  // If the graph is given as CSR, the indptr array is already materialized so we don't
  // care about restricting conversion anyway (even if it is hypersparse).
  const bool pin_to_coo =
      restrict_format == SparseFormat::kAny && coo && coo->IsHypersparse();
  restrict_format_ = pin_to_coo ? SparseFormat::kCOO : restrict_format;

  CHECK(GetAny()) << "At least one graph structure should exist.";
}

1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
/*!
 * \brief Create a single-vertex-type unit graph from pre-built matrices.
 *
 * Each matrix is wrapped only when its matching has_* flag is set; the other
 * formats are left empty.
 */
HeteroGraphPtr UnitGraph::CreateHomographFrom(
    const aten::CSRMatrix &in_csr,
    const aten::CSRMatrix &out_csr,
    const aten::COOMatrix &coo,
    bool has_in_csr,
    bool has_out_csr,
    bool has_coo,
    SparseFormat restrict_format) {
  auto metagraph = CreateUnitGraphMetaGraph1();

  CSRPtr wrapped_in = nullptr;
  if (has_in_csr) {
    wrapped_in = CSRPtr(new CSR(metagraph, in_csr));
  }

  CSRPtr wrapped_out = nullptr;
  if (has_out_csr) {
    wrapped_out = CSRPtr(new CSR(metagraph, out_csr));
  }

  COOPtr wrapped_coo = nullptr;
  if (has_coo) {
    wrapped_coo = COOPtr(new COO(metagraph, coo));
  }

  auto* unit = new UnitGraph(metagraph, wrapped_in, wrapped_out, wrapped_coo, restrict_format);
  return HeteroGraphPtr(unit);
}

Minjie Wang's avatar
Minjie Wang committed
1222
/* !\brief Return in csr. If not exist, build it from the out csr or the coo.*/
UnitGraph::CSRPtr UnitGraph::GetInCSR() const {
  if (in_csr_)
    return in_csr_;

  // The structure is created on demand and stored, hence the const_cast.
  UnitGraph* self = const_cast<UnitGraph*>(this);
  if (out_csr_) {
    const auto& transposed = aten::CSRTranspose(out_csr_->adj());
    self->in_csr_ = std::make_shared<CSR>(meta_graph(), transposed);
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    const auto& adj = coo_->adj();
    // Exchange rows and columns of the COO before converting it.
    const auto& transposed = aten::COOToCSR(
        aten::COOMatrix{adj.num_cols, adj.num_rows, adj.col, adj.row});
    self->in_csr_ = std::make_shared<CSR>(meta_graph(), transposed);
  }
  return in_csr_;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
Minjie Wang's avatar
Minjie Wang committed
1239
UnitGraph::CSRPtr UnitGraph::GetOutCSR() const {
1240
1241
1242
  if (!out_csr_) {
    if (in_csr_) {
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
Minjie Wang's avatar
Minjie Wang committed
1243
      const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
1244
1245
1246
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      const auto& newadj = aten::COOToCSR(coo_->adj());
Minjie Wang's avatar
Minjie Wang committed
1247
      const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
1248
1249
1250
1251
1252
1253
    }
  }
  return out_csr_;
}

/* !\brief Return coo. If not exist, create from csr.*/
Minjie Wang's avatar
Minjie Wang committed
1254
UnitGraph::COOPtr UnitGraph::GetCOO() const {
1255
1256
  if (!coo_) {
    if (in_csr_) {
1257
1258
      const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
      const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
1259
1260
1261
    } else {
      CHECK(out_csr_) << "Both CSR are missing.";
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
Minjie Wang's avatar
Minjie Wang committed
1262
      const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
1263
1264
1265
1266
1267
    }
  }
  return coo_;
}

1268
/*! \brief Return the adjacency of the in csr (created if missing). \a etype is not used. */
aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
  const CSRPtr in_graph = GetInCSR();
  return in_graph->adj();
}

1272
/*! \brief Return the adjacency of the out csr (created if missing). \a etype is not used. */
aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
  const CSRPtr out_graph = GetOutCSR();
  return out_graph->adj();
}

1276
/*! \brief Return the adjacency of the coo (created if missing). \a etype is not used. */
aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
  const COOPtr coo_graph = GetCOO();
  return coo_graph->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1280
/*!
 * \brief Return a sparse structure that already exists, without creating one.
 *
 * Preference order: in csr, out csr, coo. The result is null when none of
 * them is set.
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_)
    return in_csr_;
  if (out_csr_)
    return out_csr_;
  return coo_;
}

1290
1291
/*!
 * \brief Return the structure for the requested format.
 *
 * kCSR, kCSC and kCOO go through the Get*() accessors, which create the
 * structure when it is missing; kAny returns an existing one via GetAny().
 * Any other code is a fatal error.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::kCSR)
    return GetOutCSR();
  if (format == SparseFormat::kCSC)
    return GetInCSR();
  if (format == SparseFormat::kCOO)
    return GetCOO();
  if (format == SparseFormat::kAny)
    return GetAny();
  LOG(FATAL) << "unsupported format code";
  return nullptr;
}

/*!
 * \brief Decide which sparse format an operation should use.
 *
 * A restricted format always wins. Otherwise a concrete \a preferred_format
 * is honoured. With neither, the first existing structure is chosen in the
 * order CSC, CSR, and COO as the fallback.
 */
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
  if (restrict_format_ != SparseFormat::kAny)
    return restrict_format_;
  if (preferred_format != SparseFormat::kAny)
    return preferred_format;
  if (in_csr_)
    return SparseFormat::kCSC;
  if (out_csr_)
    return SparseFormat::kCSR;
  return SparseFormat::kCOO;
}

1319
1320
1321
1322
1323
1324
/*!
 * \brief Convert this graph into an ImmutableGraph.
 *
 * Only graphs with a single vertex type are accepted. Each sparse structure
 * that already exists is converted; missing ones are passed as null.
 */
GraphPtr UnitGraph::AsImmutableGraph() const {
  CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";

  dgl::CSRPtr homo_in = nullptr;
  if (in_csr_) {
    aten::CSRMatrix csc = GetCSCMatrix(0);
    homo_in = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
  }

  dgl::CSRPtr homo_out = nullptr;
  if (out_csr_) {
    aten::CSRMatrix csr = GetCSRMatrix(0);
    homo_out = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
  }

  dgl::COOPtr homo_coo = nullptr;
  if (coo_) {
    aten::COOMatrix coo = GetCOOMatrix(0);
    IdArray src = coo.row;
    IdArray dst = coo.col;
    if (COOHasData(coo)) {
      // The matrix carries a data array: pass row/col through Scatter with it.
      src = Scatter(coo.row, coo.data);
      dst = Scatter(coo.col, coo.data);
    }
    homo_coo = dgl::COOPtr(new dgl::COO(NumVertices(0), src, dst));
  }

  return GraphPtr(new dgl::ImmutableGraph(homo_in, homo_out, homo_coo));
}

1344
1345
1346
1347
1348
1349
// Magic number written first by UnitGraph::Save and verified by UnitGraph::Load.
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

/*!
 * \brief Restore the graph from a stream written by Save().
 *
 * Reads the magic number, the format code and then the single sparse
 * structure stored for that format. The format code read from the stream
 * becomes the restricted format of this graph.
 *
 * Every read is checked so that a truncated or corrupt stream is reported
 * with a message instead of leaving the structure unset and dereferencing a
 * null pointer when the meta graph is fetched.
 *
 * \return true on success; failures are fatal.
 */
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";

  int64_t format_code = 0;
  CHECK(fs->Read(&format_code)) << "Invalid format";
  restrict_format_ = static_cast<SparseFormat>(format_code);

  switch (restrict_format_) {
    case SparseFormat::kCOO:
      CHECK(fs->Read(&coo_)) << "Invalid COO data";
      break;
    case SparseFormat::kCSR:
      CHECK(fs->Read(&out_csr_)) << "Invalid CSR data";
      break;
    case SparseFormat::kCSC:
      CHECK(fs->Read(&in_csr_)) << "Invalid CSC data";
      break;
    default:
      LOG(FATAL) << "unsupported format code";
      break;
  }

  // The meta graph is not stored separately; it comes from the loaded structure.
  const HeteroGraphPtr loaded = GetAny();
  CHECK(loaded) << "No graph structure was loaded";
  meta_graph_ = loaded->meta_graph();

  return true;
}

1375

1376
1377
/*!
 * \brief Write the graph to a stream.
 *
 * Writes the magic number, the code of the format picked by
 * SelectFormat(kAny), and then the one sparse structure for that format.
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_UnitGraphMagic);

  // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
  // sparse matrix
  const SparseFormat stored = SelectFormat(SparseFormat::kAny);
  fs->Write(static_cast<int64_t>(stored));

  if (stored == SparseFormat::kCOO) {
    fs->Write(GetCOO());
  } else if (stored == SparseFormat::kCSR) {
    fs->Write(GetOutCSR());
  } else if (stored == SparseFormat::kCSC) {
    fs->Write(GetInCSR());
  } else {
    LOG(FATAL) << "unsupported format code";
  }
}

1398
}  // namespace dgl