unit_graph.cc 39.9 KB
Newer Older
1
2
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/unit_graph.cc
 * \brief UnitGraph graph implementation
 */
#include <dgl/array.h>
#include <dgl/lazy.h>
#include <dgl/immutable_graph.h>
9
#include <dgl/base_heterograph.h>
10

Minjie Wang's avatar
Minjie Wang committed
11
#include "./unit_graph.h"
12
13
14
#include "../c_api_common.h"

namespace dgl {
15

16
namespace {
17
18
19

using namespace dgl::aten;

Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
// Build the metagraph used when source and destination share one node type:
// a single node with a self-loop edge 0->0.
inline GraphPtr CreateUnitGraphMetaGraph1() {
  const std::vector<int64_t> src_ids(1, 0);
  const std::vector<int64_t> dst_ids(1, 0);
  IdArray src_arr = aten::VecToIdArray(src_ids);
  IdArray dst_arr = aten::VecToIdArray(dst_ids);
  GraphPtr metagraph = ImmutableGraph::CreateFromCOO(1, src_arr, dst_arr);
  return metagraph;
}
30

Minjie Wang's avatar
Minjie Wang committed
31
32
33
34
35
// Build the metagraph used when source and destination are distinct node
// types: two nodes joined by a single edge 0->1.
inline GraphPtr CreateUnitGraphMetaGraph2() {
  const std::vector<int64_t> src_ids(1, 0);
  const std::vector<int64_t> dst_ids(1, 1);
  IdArray src_arr = aten::VecToIdArray(src_ids);
  IdArray dst_arr = aten::VecToIdArray(dst_ids);
  GraphPtr metagraph = ImmutableGraph::CreateFromCOO(2, src_arr, dst_arr);
  return metagraph;
}
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52

// Return the shared metagraph for a unit graph with the given number of
// vertex types. Both metagraphs are built once, on the first call, and reused.
// Any value other than 1 or 2 is reported through LOG(FATAL).
inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
  static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
  static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
  switch (num_vtypes) {
    case 1:
      return mg1;
    case 2:
      return mg2;
    default:
      LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
  }
  return {};
}
53
54

};  // namespace
55
56
57
58
59
60
61

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////

Minjie Wang's avatar
Minjie Wang committed
62
/*!
 * \brief COO-format backend of UnitGraph.
 *
 * Edges are kept as two parallel id arrays (row = source, col = destination);
 * the position of an edge in those arrays is its edge id. The graph is
 * read-only: every mutating method reports a fatal error.
 */
class UnitGraph::COO : public BaseHeteroGraph {
 public:
  /*!
   * \brief Construct from source/destination id arrays.
   * \param metagraph Metagraph forwarded to BaseHeteroGraph.
   * \param num_src Number of source vertices (rows of the adjacency matrix).
   * \param num_dst Number of destination vertices (columns).
   * \param src Source id of each edge.
   * \param dst Destination id of each edge; must have the same length as src.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, IdArray dst)
    : BaseHeteroGraph(metagraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
  }

  /*!
   * \brief Same as above, but with the multigraph flag supplied by the caller
   * so that it does not have to be computed lazily.
   */
  COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
      IdArray src, IdArray dst, bool is_multigraph)
    : BaseHeteroGraph(metagraph),
      is_multigraph_(is_multigraph) {
    CHECK(aten::IsValidIdArray(src));
    CHECK(aten::IsValidIdArray(dst));
    CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
    adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
  }

  /*! \brief Wrap an existing COO matrix without validating it. */
  explicit COO(GraphPtr metagraph, const aten::COOMatrix& coo)
    : BaseHeteroGraph(metagraph), adj_(coo) {}

  /*! \brief Vertex type id of the source side (always 0). */
  inline dgl_type_t SrcType() const {
    return 0;
  }

  /*! \brief Vertex type id of the destination side: 0 if there is one vertex type, else 1. */
  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  /*! \brief The only edge type id (always 0). */
  inline dgl_type_t EdgeType() const {
    return 0;
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  void Clear() override {
    LOG(FATAL) << "UnitGraph graph is not mutable.";
  }

  /*! \brief Data type of the row id array. */
  DLDataType DataType() const override {
    return adj_.row->dtype;
  }

  /*! \brief Device context of the row id array. */
  DLContext Context() const override {
    return adj_.row->ctx;
  }

  /*! \brief Bit width of the row id array's element type. */
  uint8_t NumBits() const override {
    return adj_.row->dtype.bits;
  }

  /*!
   * \brief Return a graph whose id arrays use the requested bit width.
   * A copy of this graph is returned when the width already matches.
   */
  COO AsNumBits(uint8_t bits) const {
    if (NumBits() == bits)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        aten::AsNumBits(adj_.row, bits),
        aten::AsNumBits(adj_.col, bits));
    // Carry over the (possibly already computed) multigraph flag.
    ret.is_multigraph_ = is_multigraph_;
    return ret;
  }

  /*!
   * \brief Return a graph whose id arrays live on the given context.
   * A copy of this graph is returned when it is already on that context.
   */
  COO CopyTo(const DLContext& ctx) const {
    if (Context() == ctx)
      return *this;

    COO ret(
        meta_graph_,
        adj_.num_rows, adj_.num_cols,
        adj_.row.CopyTo(ctx),
        adj_.col.CopyTo(ctx));
    // Carry over the (possibly already computed) multigraph flag.
    ret.is_multigraph_ = is_multigraph_;
    return ret;
  }

  /*! \brief Whether the graph has duplicate edges; computed on first use and cached. */
  bool IsMultigraph() const override {
    return is_multigraph_.Get([this] () {
        return aten::COOHasDuplicate(adj_);
      });
  }

  bool IsReadonly() const override {
    return true;
  }

  /*! \brief Number of vertices of the given type; fatal error for an unknown type. */
  uint64_t NumVertices(dgl_type_t vtype) const override {
    if (vtype == SrcType()) {
      return adj_.num_rows;
    } else if (vtype == DstType()) {
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  /*! \brief Number of edges, i.e. the length of the row id array. */
  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.row->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
    return aten::COOIsNonZero(adj_, src_ids, dst_ids);
  }

  /*! \brief Sources of the edges entering dst, looked up on the transposed matrix. */
  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
  }

  /*! \brief Destinations of the edges leaving src. */
  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    return aten::COOGetRowDataAndIndices(adj_, src).second;
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
    return aten::COOGetData(adj_, src, dst);
  }

  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
    const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  /*! \brief Endpoints of edge eid, read at position eid of the row/col arrays. */
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
    const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
    const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
    return std::pair<dgl_id_t, dgl_id_t>(src, dst);
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
    return EdgeArray{aten::IndexSelect(adj_.row, eids),
                     aten::IndexSelect(adj_.col, eids),
                     eids};
  }

  /*! \brief Edges entering vid as (src, dst, eid); dst is vid repeated. */
  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_src, ret_eid;
    std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
        aten::COOTranspose(adj_), vid);
    IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
    // Map the sliced row ids back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{coosubmat.col, row, coosubmat.data};
  }

  /*! \brief Edges leaving vid as (src, dst, eid); src is vid repeated. */
  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
    IdArray ret_dst, ret_eid;
    std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    auto coosubmat = aten::COOSliceRows(adj_, vids);
    // Map the sliced row ids back to the requested vertex ids.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  /*! \brief All edges in edge-id order; only the "eid" (or empty) order is accepted. */
  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \""
      << order << "\".";
    IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
    return EdgeArray{adj_.row, adj_.col, rst_eid};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
    return aten::COOGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
    return aten::COOGetRowNNZ(adj_, vids);
  }

  // The four iterator accessors below are not supported by the COO backend.
  // They report a fatal error, like HasVertices above, instead of logging at
  // INFO level and handing an empty range back to the caller.

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(FATAL) << "Not enabled for COO graph.";
    return {};
  }

  /*!
   * \brief Return the adjacency as one stacked array; only fmt == "coo" is accepted.
   * With transpose the column ids are stacked before the row ids.
   */
  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(fmt == "coo") << "Not valid adj format request.";
    if (transpose) {
      return {aten::HStack(adj_.col, adj_.row)};
    } else {
      return {aten::HStack(adj_.row, adj_.col)};
    }
  }

  /*!
   * \brief Subgraph induced by the given vertices (one id array per vertex type).
   * The data array of the sliced matrix is reported as the induced edge ids.
   */
  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
    HeteroSubgraph subg;
    const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
    subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
        submat.row, submat.col);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  /*!
   * \brief Subgraph made of the given edges (exactly one id array expected).
   * \param eids Edge ids to keep.
   * \param preserve_nodes If true the subgraph keeps every vertex of this
   *        graph; otherwise the endpoints are relabeled via aten::Relabel_.
   */
  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
    HeteroSubgraph subg;
    if (!preserve_nodes) {
      IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
      IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
      subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
      subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
      const auto new_nsrc = subg.induced_vertices[0]->shape[0];
      const auto new_ndst = subg.induced_vertices[1]->shape[0];
      subg.graph = std::make_shared<COO>(
          meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
      subg.induced_edges = eids;
    } else {
      IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
      IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
      // Use SrcType()/DstType() rather than literal 0/1: with a single vertex
      // type, type id 1 is invalid and NumVertices(1) would be a fatal error.
      const auto num_src = NumVertices(SrcType());
      const auto num_dst = NumVertices(DstType());
      subg.induced_vertices.emplace_back(aten::Range(0, num_src, NumBits(), Context()));
      subg.induced_vertices.emplace_back(aten::Range(0, num_dst, NumBits(), Context()));
      subg.graph = std::make_shared<COO>(
          meta_graph(), num_src, num_dst, new_src, new_dst);
      subg.induced_edges = eids;
    }
    return subg;
  }

  /*! \brief Copy of the internal adjacency matrix. */
  aten::COOMatrix adj() const {
    return adj_;
  }

  /*!
   * \brief Determines whether the graph is "hypersparse", i.e. having significantly more
   * nodes than edges.
   */
  bool IsHypersparse() const {
    return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
           (NumVertices(SrcType()) > 1000000);
  }

 private:
  /*! \brief internal adjacency matrix. Data array is empty */
  aten::COOMatrix adj_;

  /*! \brief multi-graph flag; mutable because it is filled in lazily by IsMultigraph() */
  mutable Lazy<bool> is_multigraph_;
};

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*! \brief CSR graph */
Minjie Wang's avatar
Minjie Wang committed
396
class UnitGraph::CSR : public BaseHeteroGraph {
397
 public:
Minjie Wang's avatar
Minjie Wang committed
398
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
399
      IdArray indptr, IdArray indices, IdArray edge_ids)
Minjie Wang's avatar
Minjie Wang committed
400
    : BaseHeteroGraph(metagraph) {
401
402
403
404
405
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
    CHECK(aten::IsValidIdArray(edge_ids));
    CHECK_EQ(indices->shape[0], edge_ids->shape[0])
      << "indices and edge id arrays should have the same length";
406
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
Da Zheng's avatar
Da Zheng committed
407
    sorted_ = false;
408
409
  }

Minjie Wang's avatar
Minjie Wang committed
410
  CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
411
      IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
Minjie Wang's avatar
Minjie Wang committed
412
    : BaseHeteroGraph(metagraph), is_multigraph_(is_multigraph) {
413
414
415
416
417
    CHECK(aten::IsValidIdArray(indptr));
    CHECK(aten::IsValidIdArray(indices));
    CHECK(aten::IsValidIdArray(edge_ids));
    CHECK_EQ(indices->shape[0], edge_ids->shape[0])
      << "indices and edge id arrays should have the same length";
418
    adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
Da Zheng's avatar
Da Zheng committed
419
    sorted_ = false;
420
421
  }

Minjie Wang's avatar
Minjie Wang committed
422
  explicit CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
Da Zheng's avatar
Da Zheng committed
423
424
425
    : BaseHeteroGraph(metagraph), adj_(csr) {
    sorted_ = false;
  }
426

Minjie Wang's avatar
Minjie Wang committed
427
428
  inline dgl_type_t SrcType() const {
    return 0;
429
  }
Minjie Wang's avatar
Minjie Wang committed
430
431
432
433
434
435
436

  inline dgl_type_t DstType() const {
    return NumVertexTypes() == 1? 0 : 1;
  }

  inline dgl_type_t EdgeType() const {
    return 0;
437
438
439
  }

  HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
Minjie Wang's avatar
Minjie Wang committed
440
    LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
441
442
443
444
445
      << "The relation graph is simply this graph itself.";
    return {};
  }

  void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
Minjie Wang's avatar
Minjie Wang committed
446
    LOG(FATAL) << "UnitGraph graph is not mutable.";
447
448
449
  }

  void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
Minjie Wang's avatar
Minjie Wang committed
450
    LOG(FATAL) << "UnitGraph graph is not mutable.";
451
452
453
  }

  void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
Minjie Wang's avatar
Minjie Wang committed
454
    LOG(FATAL) << "UnitGraph graph is not mutable.";
455
456
457
  }

  void Clear() override {
Minjie Wang's avatar
Minjie Wang committed
458
    LOG(FATAL) << "UnitGraph graph is not mutable.";
459
460
  }

461
462
463
464
  DLDataType DataType() const override {
    return adj_.indices->dtype;
  }

465
466
467
468
469
470
471
472
  DLContext Context() const override {
    return adj_.indices->ctx;
  }

  uint8_t NumBits() const override {
    return adj_.indices->dtype.bits;
  }

473
474
475
476
477
  CSR AsNumBits(uint8_t bits) const {
    if (NumBits() == bits) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
478
          meta_graph_,
479
480
481
482
483
484
485
486
487
488
489
490
491
492
          adj_.num_rows, adj_.num_cols,
          aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
      ret.is_multigraph_ = is_multigraph_;
      return ret;
    }
  }

  CSR CopyTo(const DLContext& ctx) const {
    if (Context() == ctx) {
      return *this;
    } else {
      CSR ret(
Minjie Wang's avatar
Minjie Wang committed
493
          meta_graph_,
494
495
496
497
498
499
500
501
502
          adj_.num_rows, adj_.num_cols,
          adj_.indptr.CopyTo(ctx),
          adj_.indices.CopyTo(ctx),
          adj_.data.CopyTo(ctx));
      ret.is_multigraph_ = is_multigraph_;
      return ret;
    }
  }

503
504
505
506
507
508
509
510
511
512
513
  bool IsMultigraph() const override {
    return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
        return aten::CSRHasDuplicate(adj_);
      });
  }

  bool IsReadonly() const override {
    return true;
  }

  uint64_t NumVertices(dgl_type_t vtype) const override {
Minjie Wang's avatar
Minjie Wang committed
514
    if (vtype == SrcType()) {
515
      return adj_.num_rows;
Minjie Wang's avatar
Minjie Wang committed
516
    } else if (vtype == DstType()) {
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
      return adj_.num_cols;
    } else {
      LOG(FATAL) << "Invalid vertex type: " << vtype;
      return 0;
    }
  }

  uint64_t NumEdges(dgl_type_t etype) const override {
    return adj_.indices->shape[0];
  }

  bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
    return vid < NumVertices(vtype);
  }

  BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
    LOG(FATAL) << "Not enabled for COO graph";
    return {};
  }

  bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
538
539
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
540
541
542
543
    return aten::CSRIsNonZero(adj_, src, dst);
  }

  BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
544
545
    CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
546
547
548
549
550
551
552
553
554
    return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  }

  IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
Minjie Wang's avatar
Minjie Wang committed
555
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
556
557
558
559
    return aten::CSRGetRowColumnIndices(adj_, src);
  }

  IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
Minjie Wang's avatar
Minjie Wang committed
560
561
    CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
    CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
562
563
564
565
    return aten::CSRGetData(adj_, src, dst);
  }

  EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const override {
566
567
    CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
    const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
    return EdgeArray{arrs[0], arrs[1], arrs[2]};
  }

  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
593
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
594
595
596
597
598
599
600
    IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
    IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
    IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
    return EdgeArray{ret_src, ret_dst, ret_eid};
  }

  EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
601
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
    auto csrsubmat = aten::CSRSliceRows(adj_, vids);
    auto coosubmat = aten::CSRToCOO(csrsubmat, false);
    // Note that the row id in the csr submat is relabled, so
    // we need to recover it using an index select.
    auto row = aten::IndexSelect(vids, coosubmat.row);
    return EdgeArray{row, coosubmat.col, coosubmat.data};
  }

  EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
    CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
    const auto& coo = aten::CSRToCOO(adj_, false);
    return EdgeArray{coo.row, coo.col, coo.data};
  }

  uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
Minjie Wang's avatar
Minjie Wang committed
629
    CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
630
631
632
633
    return aten::CSRGetRowNNZ(adj_, vid);
  }

  DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
634
    CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
    return aten::CSRGetRowNNZ(adj_, vids);
  }

  DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(indices_data + start, indices_data + end);
  }

  DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    // TODO(minjie): This still assumes the data type and device context
    //   of this graph. Should fix later.
    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
    const dgl_id_t start = indptr_data[vid];
    const dgl_id_t end = indptr_data[vid + 1];
    return DGLIdIters(eid_data + start, eid_data + end);
  }

  DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  std::vector<IdArray> GetAdj(
      dgl_type_t etype, bool transpose, const std::string &fmt) const override {
    CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
    return {adj_.indptr, adj_.indices, adj_.data};
  }

  HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
Minjie Wang's avatar
Minjie Wang committed
675
676
677
678
    CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
    auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
    CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
    CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
679
    HeteroSubgraph subg;
Minjie Wang's avatar
Minjie Wang committed
680
    const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
681
    IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
Minjie Wang's avatar
Minjie Wang committed
682
    subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
        submat.indptr, submat.indices, sub_eids);
    subg.induced_vertices = vids;
    subg.induced_edges.emplace_back(submat.data);
    return subg;
  }

  HeteroSubgraph EdgeSubgraph(
      const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
    LOG(INFO) << "Not enabled for CSR graph.";
    return {};
  }

  aten::CSRMatrix adj() const {
    return adj_;
  }

 private:
  /*! \brief internal adjacency matrix. Data array stores edge ids */
  aten::CSRMatrix adj_;

  /*! \brief multi-graph flag */
  Lazy<bool> is_multigraph_;
Da Zheng's avatar
Da Zheng committed
705
706
707

  /*! \brief indicate that the edges are stored in the sorted order. */
  bool sorted_;
708
709
710
711
};

//////////////////////////////////////////////////////////
//
// unit graph implementation
//
//////////////////////////////////////////////////////////

716
717
718
719
// Data type reported by whichever backend GetAny() returns.
DLDataType UnitGraph::DataType() const {
  const auto backend = GetAny();
  return backend->DataType();
}

Minjie Wang's avatar
Minjie Wang committed
720
// Device context reported by whichever backend GetAny() returns.
DLContext UnitGraph::Context() const {
  const auto backend = GetAny();
  return backend->Context();
}

Minjie Wang's avatar
Minjie Wang committed
724
uint8_t UnitGraph::NumBits() const {
725
726
727
  return GetAny()->NumBits();
}

Minjie Wang's avatar
Minjie Wang committed
728
bool UnitGraph::IsMultigraph() const {
729
730
731
  return GetAny()->IsMultigraph();
}

Minjie Wang's avatar
Minjie Wang committed
732
// Number of vertices of type `vtype`, answered by the backend chosen by
// SelectFormat(ANY). An invalid `vtype` is passed through unchanged so that
// the backend reports it.
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
  const auto ptr = GetFormat(fmt);
  // TODO(BarclayII): we have a lot of special handling for CSC.
  // Need to have a UnitGraph::CSC backend instead.
  if (fmt == SparseFormat::CSC) {
    // The CSC backend is queried with the source and destination types
    // swapped. Only swap valid types: the previous unconditional ternary
    // mapped every invalid vtype to SrcType() and so returned a vertex count
    // instead of an error.
    if (vtype == SrcType())
      vtype = DstType();
    else if (vtype == DstType())
      vtype = SrcType();
  }
  return ptr->NumVertices(vtype);
}

Minjie Wang's avatar
Minjie Wang committed
742
// Edge count reported by whichever backend GetAny() returns.
uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
  const auto backend = GetAny();
  return backend->NumEdges(etype);
}

Minjie Wang's avatar
Minjie Wang committed
746
// Whether `vid` is a valid vertex id of type `vtype`, answered by the backend
// chosen by SelectFormat(ANY). An invalid `vtype` is passed through unchanged
// so that the backend reports it.
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
  const auto ptr = GetFormat(fmt);
  if (fmt == SparseFormat::CSC) {
    // The CSC backend is queried with the source and destination types
    // swapped. Only swap valid types: the previous unconditional ternary
    // mapped every invalid vtype to SrcType() and so hid the error.
    if (vtype == SrcType())
      vtype = DstType();
    else if (vtype == DstType())
      vtype = SrcType();
  }
  return ptr->HasVertex(vtype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
754
// Element-wise validity of `vids`: each id is compared against the number of
// vertices of type `vtype`.
BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_vertices = NumVertices(vtype);
  return aten::LT(vids, num_vertices);
}

Minjie Wang's avatar
Minjie Wang committed
759
/*!
 * \brief Check whether an edge src->dst exists.
 *
 * Uses the format chosen by SelectFormat(ANY); for CSC the endpoints are
 * passed in swapped order.
 */
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::ANY);
  const auto graph = GetFormat(chosen);
  return (chosen == SparseFormat::CSC)
    ? graph->HasEdgeBetween(etype, dst, src)
    : graph->HasEdgeBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
768
/*!
 * \brief Batched version of HasEdgeBetween.
 *
 * Uses the format chosen by SelectFormat(ANY); for CSC the endpoint arrays
 * are passed in swapped order.
 */
BoolArray UnitGraph::HasEdgesBetween(
    dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::ANY);
  const auto graph = GetFormat(chosen);
  if (chosen == SparseFormat::CSC) {
    return graph->HasEdgesBetween(etype, dst, src);
  }
  return graph->HasEdgesBetween(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
778
/*!
 * \brief Predecessors of \a dst.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * Successors() on the in-CSR; any other format gets a plain Predecessors() call.
 */
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  if (chosen == SparseFormat::CSC) {
    return graph->Successors(etype, dst);
  }
  return graph->Predecessors(etype, dst);
}

Minjie Wang's avatar
Minjie Wang committed
787
/*!
 * \brief Successors of \a src.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->Successors(etype, src);
}

Minjie Wang's avatar
Minjie Wang committed
793
/*!
 * \brief Ids of the edges from \a src to \a dst.
 *
 * Uses the format chosen by SelectFormat(ANY); for CSC the endpoints are
 * passed in swapped order.
 */
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::ANY);
  const auto graph = GetFormat(chosen);
  if (chosen == SparseFormat::CSC) {
    return graph->EdgeId(etype, dst, src);
  }
  return graph->EdgeId(etype, src, dst);
}

Minjie Wang's avatar
Minjie Wang committed
802
/*!
 * \brief Batched version of EdgeId.
 *
 * Uses the format chosen by SelectFormat(ANY). For CSC the endpoint arrays are
 * passed swapped, and the src/dst fields of the result are swapped back.
 */
EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::ANY);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::CSC) {
    return graph->EdgeIds(etype, src, dst);
  }
  const EdgeArray flipped = graph->EdgeIds(etype, dst, src);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
813
/*!
 * \brief Look up the endpoints of edge \a eid.
 *
 * COO is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::COO);
  const auto graph = GetFormat(chosen);
  return graph->FindEdge(etype, eid);
}

Minjie Wang's avatar
Minjie Wang committed
819
/*!
 * \brief Batched version of FindEdge.
 *
 * COO is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::COO);
  const auto graph = GetFormat(chosen);
  return graph->FindEdges(etype, eids);
}

Minjie Wang's avatar
Minjie Wang committed
825
/*!
 * \brief In-edges of a single vertex.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * OutEdges() on the in-CSR and the src/dst fields of the result are swapped;
 * any other format gets a plain InEdges() call.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::CSC) {
    return graph->InEdges(etype, vid);
  }
  const EdgeArray flipped = graph->OutEdges(etype, vid);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
836
/*!
 * \brief In-edges of a batch of vertices.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * OutEdges() on the in-CSR and the src/dst fields of the result are swapped;
 * any other format gets a plain InEdges() call.
 */
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  if (chosen != SparseFormat::CSC) {
    return graph->InEdges(etype, vids);
  }
  const EdgeArray flipped = graph->OutEdges(etype, vids);
  return EdgeArray{flipped.dst, flipped.src, flipped.id};
}

Minjie Wang's avatar
Minjie Wang committed
847
/*!
 * \brief Out-edges of a single vertex.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->OutEdges(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
853
/*!
 * \brief Out-edges of a batch of vertices.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->OutEdges(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
859
/*!
 * \brief Return all edges in the requested order.
 *
 * The order string decides the preferred format: "eid" -> COO, "" (arbitrary
 * order) -> ANY, "srcdst" -> CSR. Any other string is a fatal error. If the
 * format that ends up selected is CSC, the src/dst fields of the backend
 * result are swapped before returning.
 */
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
  SparseFormat preferred;
  if (order.empty()) {
    // arbitrary order
    preferred = SparseFormat::ANY;
  } else if (order == std::string("eid")) {
    preferred = SparseFormat::COO;
  } else if (order == std::string("srcdst")) {
    preferred = SparseFormat::CSR;
  } else {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }
  const SparseFormat chosen = SelectFormat(preferred);

  const EdgeArray result = GetFormat(chosen)->Edges(etype, order);
  if (chosen != SparseFormat::CSC) {
    return result;
  }
  return EdgeArray{result.dst, result.src, result.id};
}

Minjie Wang's avatar
Minjie Wang committed
880
/*!
 * \brief In-degree of a single vertex.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * OutDegree() on the in-CSR; any other format gets a plain InDegree() call.
 */
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  return (chosen == SparseFormat::CSC)
    ? graph->OutDegree(etype, vid)
    : graph->InDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
889
/*!
 * \brief In-degrees of a batch of vertices.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * OutDegrees() on the in-CSR; any other format gets a plain InDegrees() call.
 */
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  if (chosen == SparseFormat::CSC) {
    return graph->OutDegrees(etype, vids);
  }
  return graph->InDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
898
/*!
 * \brief Out-degree of a single vertex.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->OutDegree(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
904
/*!
 * \brief Out-degrees of a batch of vertices.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->OutDegrees(etype, vids);
}

Minjie Wang's avatar
Minjie Wang committed
910
/*!
 * \brief Iterator range over the successors of \a vid.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->SuccVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
916
/*!
 * \brief Iterator range over the out-edge ids of \a vid.
 *
 * CSR is the preferred format; the call is forwarded to whatever format
 * SelectFormat() returns for that preference.
 */
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSR);
  const auto graph = GetFormat(chosen);
  return graph->OutEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
922
/*!
 * \brief Iterator range over the predecessors of \a vid.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * SuccVec() on the in-CSR; any other format gets a plain PredVec() call.
 */
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  if (chosen == SparseFormat::CSC) {
    return graph->SuccVec(etype, vid);
  }
  return graph->PredVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
931
/*!
 * \brief Iterator range over the in-edge ids of \a vid.
 *
 * CSC is the preferred format. When it is selected, the query is issued as
 * OutEdgeVec() on the in-CSR; any other format gets a plain InEdgeVec() call.
 */
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
  const SparseFormat chosen = SelectFormat(SparseFormat::CSC);
  const auto graph = GetFormat(chosen);
  if (chosen == SparseFormat::CSC) {
    return graph->OutEdgeVec(etype, vid);
  }
  return graph->InEdgeVec(etype, vid);
}

Minjie Wang's avatar
Minjie Wang committed
940
std::vector<IdArray> UnitGraph::GetAdj(
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
    dgl_type_t etype, bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(etype, false, "csr")
      : GetInCSR()->GetAdj(etype, false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(etype, !transpose, fmt);
  } else {
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
  }
}

Minjie Wang's avatar
Minjie Wang committed
960
/*!
 * \brief Build the subgraph induced by the given vertices.
 *
 * Out-CSR is the preferred backend, but SelectFormat() returns the restricted
 * format instead when one is set. The backend result can therefore be an
 * out-CSR, an in-CSR or a COO, and it is stored in the slot matching the format
 * that produced it. (Previously the result was always cast to CSR and stored as
 * the out-CSR, which yields a null structure for COO and a mislabelled in-CSR
 * for CSC.)
 *
 * \param vids Vertex id arrays, forwarded unchanged to the backend.
 * \return The new graph plus the induced vertices/edges reported by the backend.
 */
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
  // We prefer to generate a subgraph from out-csr.
  const SparseFormat fmt = SelectFormat(SparseFormat::CSR);
  HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);

  CSRPtr sub_incsr = nullptr;
  CSRPtr sub_outcsr = nullptr;
  COOPtr sub_coo = nullptr;
  switch (fmt) {
  case SparseFormat::CSR:
    sub_outcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
    break;
  case SparseFormat::CSC:
    // NOTE(review): vids is forwarded as-is to the in-CSR backend. Other CSC paths
    // in this class flip source/destination vertex types first; confirm whether
    // the same flip is needed here for graphs with two vertex types.
    sub_incsr = std::dynamic_pointer_cast<CSR>(sg.graph);
    break;
  case SparseFormat::COO:
    sub_coo = std::dynamic_pointer_cast<COO>(sg.graph);
    break;
  default:
    LOG(FATAL) << "unsupported format code";
    return HeteroSubgraph();
  }
  CHECK(sub_incsr || sub_outcsr || sub_coo)
    << "VertexSubgraph: the selected sparse format did not produce a subgraph.";

  HeteroSubgraph ret;
  ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), sub_incsr, sub_outcsr, sub_coo));
  ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
972
/*!
 * \brief Build the subgraph induced by the given edges.
 *
 * COO is the preferred backend, but SelectFormat() returns the restricted
 * format instead when one is set. The backend result can therefore be a COO,
 * an out-CSR or an in-CSR, and it is stored in the slot matching the format
 * that produced it. (Previously the result was always cast to COO, which
 * yields a null structure whenever the graph is restricted to CSR or CSC.)
 *
 * \param eids Edge id arrays, forwarded unchanged to the backend.
 * \param preserve_nodes Forwarded unchanged to the backend.
 * \return The new graph plus the induced vertices/edges reported by the backend.
 */
HeteroSubgraph UnitGraph::EdgeSubgraph(
    const std::vector<IdArray>& eids, bool preserve_nodes) const {
  const SparseFormat fmt = SelectFormat(SparseFormat::COO);
  auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);

  CSRPtr sub_incsr = nullptr;
  CSRPtr sub_outcsr = nullptr;
  COOPtr sub_coo = nullptr;
  switch (fmt) {
  case SparseFormat::COO:
    sub_coo = std::dynamic_pointer_cast<COO>(sg.graph);
    break;
  case SparseFormat::CSR:
    sub_outcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
    break;
  case SparseFormat::CSC:
    sub_incsr = std::dynamic_pointer_cast<CSR>(sg.graph);
    break;
  default:
    LOG(FATAL) << "unsupported format code";
    return HeteroSubgraph();
  }
  CHECK(sub_incsr || sub_outcsr || sub_coo)
    << "EdgeSubgraph: the selected sparse format did not produce a subgraph.";

  HeteroSubgraph ret;
  ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), sub_incsr, sub_outcsr, sub_coo));
  ret.induced_vertices = std::move(sg.induced_vertices);
  ret.induced_edges = std::move(sg.induced_edges);
  return ret;
}

Minjie Wang's avatar
Minjie Wang committed
984
/*!
 * \brief Create a unit graph backed by a COO structure.
 *
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param num_src Number of source vertices (must equal num_dst when num_vtypes is 1).
 * \param num_dst Number of destination vertices.
 * \param row Row id array handed to the COO constructor.
 * \param col Column id array handed to the COO constructor.
 * \param restrict_format Format restriction handed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row, IdArray col,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  COOPtr coo_graph = std::make_shared<COO>(metagraph, num_src, num_dst, row, col);

  UnitGraph* created = new UnitGraph(metagraph, nullptr, nullptr, coo_graph, restrict_format);
  return HeteroGraphPtr(created);
}

Minjie Wang's avatar
Minjie Wang committed
997
998
/*!
 * \brief Create a unit graph backed by an out-CSR structure.
 *
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param num_src Number of source vertices (must equal num_dst when num_vtypes is 1).
 * \param num_dst Number of destination vertices.
 * \param indptr CSR index pointer array handed to the CSR constructor.
 * \param indices CSR indices array handed to the CSR constructor.
 * \param edge_ids Edge id array handed to the CSR constructor.
 * \param restrict_format Format restriction handed to the UnitGraph constructor.
 */
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(num_src, num_dst);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr out_csr = std::make_shared<CSR>(
      metagraph, num_src, num_dst, indptr, indices, edge_ids);

  // The CSR goes into the out-CSR slot (third constructor argument).
  UnitGraph* created = new UnitGraph(metagraph, nullptr, out_csr, nullptr, restrict_format);
  return HeteroGraphPtr(created);
}

Minjie Wang's avatar
Minjie Wang committed
1008
/*!
 * \brief Return a graph whose ids use the given bit width.
 *
 * The input is returned untouched when it already has that width. Otherwise
 * both the in-CSR and the out-CSR are obtained (created on demand), converted,
 * and wrapped in a new UnitGraph that keeps the original format restriction.
 */
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }
  // TODO(minjie): since we don't have int32 operations,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  auto unit_graph = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit_graph);

  CSRPtr converted_in = CSRPtr(new CSR(unit_graph->GetInCSR()->AsNumBits(bits)));
  CSRPtr converted_out = CSRPtr(new CSR(unit_graph->GetOutCSR()->AsNumBits(bits)));
  UnitGraph* converted = new UnitGraph(
      g->meta_graph(), converted_in, converted_out, nullptr, unit_graph->restrict_format_);
  return HeteroGraphPtr(converted);
}

Minjie Wang's avatar
Minjie Wang committed
1026
/*!
 * \brief Return a copy of the graph on the given device context.
 *
 * The input is returned untouched when it already lives on \a ctx. Otherwise
 * both the in-CSR and the out-CSR are obtained (created on demand), copied,
 * and wrapped in a new UnitGraph that keeps the original format restriction.
 */
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
  const bool already_there = (ctx == g->Context());
  if (already_there) {
    return g;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  auto unit_graph = std::dynamic_pointer_cast<UnitGraph>(g);
  CHECK_NOTNULL(unit_graph);

  CSRPtr copied_in = CSRPtr(new CSR(unit_graph->GetInCSR()->CopyTo(ctx)));
  CSRPtr copied_out = CSRPtr(new CSR(unit_graph->GetOutCSR()->CopyTo(ctx)));
  UnitGraph* copied = new UnitGraph(
      g->meta_graph(), copied_in, copied_out, nullptr, unit_graph->restrict_format_);
  return HeteroGraphPtr(copied);
}

1043
1044
/*!
 * \brief Construct a unit graph from any combination of in-CSR, out-CSR and COO.
 *
 * At least one of the three structures must be non-null.
 */
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
                     SparseFormat restrict_format)
  : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
  // If the graph is hypersparse and in COO format, switch the restricted format to COO.
  // If the graph is given as CSR, the indptr array is already materialized so we don't
  // care about restricting conversion anyway (even if it is hypersparse).
  const bool force_coo =
    (restrict_format == SparseFormat::ANY) && coo && coo->IsHypersparse();
  restrict_format_ = force_coo ? SparseFormat::COO : restrict_format;

  CHECK(GetAny()) << "At least one graph structure should exist.";
}

Minjie Wang's avatar
Minjie Wang committed
1059
/*!
 * \brief Return the in-CSR, creating and caching it on first use.
 *
 * It is derived from the out-CSR by transposition when that exists, otherwise
 * from the COO with rows and columns swapped. The cached member is written
 * through a const_cast because this method is const.
 */
UnitGraph::CSRPtr UnitGraph::GetInCSR() const {
  if (in_csr_) {
    return in_csr_;
  }
  UnitGraph* self = const_cast<UnitGraph*>(this);
  if (out_csr_) {
    const auto& transposed = aten::CSRTranspose(out_csr_->adj());
    self->in_csr_ = std::make_shared<CSR>(meta_graph(), transposed);
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    const auto& coo_adj = coo_->adj();
    // Swap rows and columns so the resulting CSR is indexed by destination.
    const auto& transposed = aten::COOToCSR(
        aten::COOMatrix{coo_adj.num_cols, coo_adj.num_rows, coo_adj.col, coo_adj.row});
    self->in_csr_ = std::make_shared<CSR>(meta_graph(), transposed);
  }
  return in_csr_;
}

/* !\brief Return out csr. If not exist, transpose the other one.*/
Minjie Wang's avatar
Minjie Wang committed
1076
UnitGraph::CSRPtr UnitGraph::GetOutCSR() const {
1077
1078
1079
  if (!out_csr_) {
    if (in_csr_) {
      const auto& newadj = aten::CSRTranspose(in_csr_->adj());
Minjie Wang's avatar
Minjie Wang committed
1080
      const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
1081
1082
1083
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      const auto& newadj = aten::COOToCSR(coo_->adj());
Minjie Wang's avatar
Minjie Wang committed
1084
      const_cast<UnitGraph*>(this)->out_csr_ = std::make_shared<CSR>(meta_graph(), newadj);
1085
1086
1087
1088
1089
1090
    }
  }
  return out_csr_;
}

/* !\brief Return coo. If not exist, create from csr.*/
Minjie Wang's avatar
Minjie Wang committed
1091
UnitGraph::COOPtr UnitGraph::GetCOO() const {
1092
1093
1094
  if (!coo_) {
    if (in_csr_) {
      const auto& newadj = aten::CSRToCOO(in_csr_->adj(), true);
Minjie Wang's avatar
Minjie Wang committed
1095
1096
      const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(
          meta_graph(),
1097
1098
1099
1100
          aten::COOMatrix{newadj.num_cols, newadj.num_rows, newadj.col, newadj.row});
    } else {
      CHECK(out_csr_) << "Both CSR are missing.";
      const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
Minjie Wang's avatar
Minjie Wang committed
1101
      const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
1102
1103
1104
1105
1106
    }
  }
  return coo_;
}

Minjie Wang's avatar
Minjie Wang committed
1107
/*! \brief Adjacency matrix of the in-CSR (created on demand by GetInCSR()). */
aten::CSRMatrix UnitGraph::GetInCSRMatrix() const {
  const auto in_csr = GetInCSR();
  return in_csr->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1111
/*! \brief Adjacency matrix of the out-CSR (created on demand by GetOutCSR()). */
aten::CSRMatrix UnitGraph::GetOutCSRMatrix() const {
  const auto out_csr = GetOutCSR();
  return out_csr->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1115
/*! \brief Adjacency matrix of the COO (created on demand by GetCOO()). */
aten::COOMatrix UnitGraph::GetCOOMatrix() const {
  const auto coo = GetCOO();
  return coo->adj();
}

Minjie Wang's avatar
Minjie Wang committed
1119
/*!
 * \brief Return an already-existing sparse structure without creating one.
 *
 * Lookup order: in-CSR, then out-CSR, then COO (which may be null).
 */
HeteroGraphPtr UnitGraph::GetAny() const {
  if (in_csr_) {
    return in_csr_;
  }
  if (out_csr_) {
    return out_csr_;
  }
  return coo_;
}

1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
/*!
 * \brief Return the sparse structure for the given format code.
 *
 * CSR maps to the out-CSR, CSC to the in-CSR, COO to the COO and ANY to
 * GetAny(). Any other code is a fatal error.
 */
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
  if (format == SparseFormat::CSR) {
    return GetOutCSR();
  }
  if (format == SparseFormat::CSC) {
    return GetInCSR();
  }
  if (format == SparseFormat::COO) {
    return GetCOO();
  }
  if (format == SparseFormat::ANY) {
    return GetAny();
  }
  LOG(FATAL) << "unsupported format code";
  return nullptr;
}

/*!
 * \brief Decide which sparse format an operation should use.
 *
 * Priority: the graph's format restriction (if not ANY), then the caller's
 * preference (if not ANY), then whichever structure already exists, checked in
 * the order in-CSR (CSC), out-CSR (CSR), falling back to COO.
 */
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
  if (restrict_format_ != SparseFormat::ANY) {
    return restrict_format_;
  }
  if (preferred_format != SparseFormat::ANY) {
    return preferred_format;
  }
  if (in_csr_) {
    return SparseFormat::CSC;
  }
  return out_csr_ ? SparseFormat::CSR : SparseFormat::COO;
}

1158
1159
1160
1161
1162
1163
1164
1165
/*!
 * \brief Allocate a unit graph with one vertex type, zero vertices and zero edges.
 *
 * The graph is backed by a COO built from two zero-length id arrays. The
 * returned pointer comes from a plain `new`.
 */
UnitGraph* UnitGraph::EmptyGraph() {
  auto metagraph = CreateUnitGraphMetaGraph(1);
  auto no_src = NewIdArray(0);
  auto no_dst = NewIdArray(0);
  COOPtr empty_coo = std::make_shared<COO>(metagraph, 0, 0, no_src, no_dst);
  return new UnitGraph(metagraph, nullptr, nullptr, empty_coo);
}

1166
1167
// Magic number written first by UnitGraph::Save() and verified by UnitGraph::Load().
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

1168
// Using OurCSR
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
bool UnitGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data";
  uint64_t num_vtypes, num_src, num_dst;
  CHECK(fs->Read(&num_vtypes)) << "Invalid num_vtypes";
  CHECK(fs->Read(&num_src)) << "Invalid num_src";
  CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
  aten::CSRMatrix csr_matrix;
  CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
1180
1181
1182
  CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr,
                     csr_matrix.indices, csr_matrix.data));
  *this = UnitGraph(mg, nullptr, csr, nullptr);
1183
1184
1185
  return true;
}

1186
// Using Out CSR
1187
1188
void UnitGraph::Save(dmlc::Stream* fs) const {
  // Following CreateFromCSR signature
1189
  aten::CSRMatrix csr_matrix = GetOutCSRMatrix();
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
  uint64_t num_vtypes = NumVertexTypes();
  uint64_t num_src = NumVertices(SrcType());
  uint64_t num_dst = NumVertices(DstType());
  fs->Write(kDGLSerialize_UnitGraphMagic);
  fs->Write(num_vtypes);
  fs->Write(num_src);
  fs->Write(num_dst);
  fs->Write(csr_matrix);
}

1200
}  // namespace dgl