immutable_graph.cc 25.6 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

#include <dgl/immutable_graph.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
10
#include <dgl/base_heterograph.h>
11
12
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
13
14
15
16
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
17
18

#include "../c_api_common.h"
19
20
#include "heterograph.h"
#include "unit_graph.h"
21

22
23
using namespace dgl::runtime;

24
namespace dgl {
25
namespace {
26
27
28
29
/*!
 * \brief Derive the shared-memory segment name used for one CSR direction.
 * \param name Base shared-memory name of the graph.
 * \param edge_dir Edge direction tag appended to the base name.
 * \return `name` and `edge_dir` joined with a single underscore.
 */
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string joined(name);
  joined += '_';
  joined += edge_dir;
  return joined;
}

30
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
31
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
32
#ifndef _WIN32
33
34
35
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
36
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
37
38
39
40
41
42
43
44
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
45
#else
46
47
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
48
49
#endif  // _WIN32
}
50
51
52
53
54
55
56
57
58
59
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*!
 * \brief Allocate a square CSR with room for `num_vertices` rows and
 *        `num_edges` non-zeros. Array contents are not set here.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  // A graph with no vertices cannot have edges.
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray indptr = aten::NewIdArray(num_vertices + 1);
  IdArray indices = aten::NewIdArray(num_edges);
  IdArray eids = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, indptr, indices, eids};
  adj_.sorted = false;
}

68
/*!
 * \brief Wrap existing CSR arrays. The multigraph flag is left to be
 *        computed lazily by IsMultigraph().
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // One edge id per column index.
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  // indptr carries one offset per row plus a trailing end offset.
  const int64_t num_rows = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_rows, num_rows, indptr, indices, edge_ids};
  // No ordering of column indices within a row is assumed.
  adj_.sorted = false;
}

/*!
 * \brief Wrap existing CSR arrays with a caller-supplied multigraph flag.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // One edge id per column index.
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  // indptr carries one offset per row plus a trailing end offset.
  const int64_t num_rows = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_rows, num_rows, indptr, indices, edge_ids};
  // No ordering of column indices within a row is assumed.
  adj_.sorted = false;
}

/*!
 * \brief Build a CSR whose arrays live in a newly created shared-memory
 *        segment named `shared_mem_name`, filled from the given arrays.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_verts = indptr->shape[0] - 1;
  const int64_t num_edges = indices->shape[0];

  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  // is_create = true: this constructor creates the segment.
  const auto views = MapFromSharedMemory(shared_mem_name, num_verts, num_edges, true);
  adj_.indptr = std::get<0>(views);
  adj_.indices = std::get<1>(views);
  adj_.data = std::get<2>(views);

  // Populate the shared buffers with the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Build a CSR in a newly created shared-memory segment, filled from
 *        the given arrays, with a caller-supplied multigraph flag.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
         const std::string &shared_mem_name): is_multigraph_(is_multigraph),
         shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_verts = indptr->shape[0] - 1;
  const int64_t num_edges = indices->shape[0];

  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  // is_create = true: this constructor creates the segment.
  const auto views = MapFromSharedMemory(shared_mem_name, num_verts, num_edges, true);
  adj_.indptr = std::get<0>(views);
  adj_.indices = std::get<1>(views);
  adj_.data = std::get<2>(views);

  // Populate the shared buffers with the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Attach to an existing shared-memory segment that already holds a
 *        CSR with `num_verts` vertices and `num_edges` edges.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  // is_create = false: map the existing segment rather than creating one.
  const auto views = MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
  adj_.indptr = std::get<0>(views);
  adj_.indices = std::get<1>(views);
  adj_.data = std::get<2>(views);
  adj_.sorted = false;
}

bool CSR::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
142
      return aten::CSRHasDuplicate(adj_);
143
144
    });
}
145

146
/*! \brief Return all edges leaving vertex `vid` as (src, dst, eid) arrays. */
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  // Row `vid` stores the destinations and edge ids of `vid`'s out-edges.
  IdArray dst = aten::CSRGetRowColumnIndices(adj_, vid);
  IdArray eid = aten::CSRGetRowData(adj_, vid);
  // Every returned edge has `vid` as its source.
  IdArray src = aten::Full(vid, dst->shape[0], NumBits(), dst->ctx);
  return EdgeArray{src, dst, eid};
}

154
/*! \brief Return all edges leaving any vertex in `vids`. */
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto sliced = aten::CSRToCOO(aten::CSRSliceRows(adj_, vids), false);
  // Rows of the slice are renumbered by position in `vids`, so map them
  // back to the original vertex ids.
  IdArray src = aten::IndexSelect(vids, sliced.row);
  return EdgeArray{src, sliced.col, sliced.data};
}

164
/*! \brief Out-degree of each vertex in `vids`. */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  // A vertex's out-degree is the number of non-zeros in its row.
  DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

169
170
171
/*! \brief Whether at least one edge src -> dst exists. */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  // The edge exists iff entry (src, dst) of the adjacency is non-zero.
  const bool found = aten::CSRIsNonZero(adj_, src, dst);
  return found;
}

/*! \brief Elementwise existence test for the pairs (src_ids[i], dst_ids[i]). */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  BoolArray exists = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return exists;
}

181
/*! \brief Destinations of `vid`'s out-edges; only radius 1 is supported. */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  IdArray succ = aten::CSRGetRowColumnIndices(adj_, vid);
  return succ;
}

187
188
189
/*!
 * \brief Ids of the edges src -> dst. The result is an array because there
 *        may be more than one such edge.
 */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  IdArray eids = aten::CSRGetData(adj_, src, dst);
  return eids;
}

193
/*!
 * \brief Find the edges between each pair (src_ids[i], dst_ids[i]).
 * \return EdgeArray of (src, dst, eid), taken from the three arrays returned
 *         by aten::CSRGetDataAndIndices.
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  // Validate the inputs like the other array-based queries in this class do.
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
197

198
/*!
 * \brief All edges of the graph. Only the empty order or "srcdst" is accepted.
 */
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\","
    << " but got \"" << order << "\".";
  // Expanding the CSR row by row groups edges by source vertex.
  const aten::COOMatrix coo = aten::CSRToCOO(adj_, false);
  return EdgeArray{coo.row, coo.col, coo.data};
}

206
/*!
 * \brief Subgraph induced by the vertices `vids`.
 *
 * Edges of the subgraph are numbered 0..nnz-1. Their ids in this graph are
 * reported through `induced_edges`.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
  IdArray fresh_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
  CSRPtr child(new CSR(submat.indptr, submat.indices, fresh_eids));
  // The slice inherits this graph's sortedness flag.
  child->adj_.sorted = this->adj_.sorted;

  Subgraph result;
  result.graph = child;
  result.induced_vertices = vids;
  result.induced_edges = submat.data;
  return result;
}

/*! \brief A new CSR holding the transpose (all edges reversed). */
CSRPtr CSR::Transpose() const {
  const aten::CSRMatrix flipped = aten::CSRTranspose(adj_);
  return std::make_shared<CSR>(flipped.indptr, flipped.indices, flipped.data);
}

/*!
 * \brief Convert to a COO graph with the same vertex count. This calls
 *        CSRToCOO with `true`; Edges() calls it with `false`.
 */
COOPtr CSR::ToCOO() const {
  const aten::COOMatrix coo = aten::CSRToCOO(adj_, true);
  return std::make_shared<COO>(NumVertices(), coo.row, coo.col);
}

229
230
231
232
/*!
 * \brief Copy this CSR to device context `ctx`.
 * \return *this if already on `ctx`; otherwise a copy that keeps the cached
 *         multigraph flag and the sortedness flag.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  CSR ret(adj_.indptr.CopyTo(ctx),
          adj_.indices.CopyTo(ctx),
          adj_.data.CopyTo(ctx));
  ret.is_multigraph_ = is_multigraph_;
  // Copying does not reorder column indices. Without this line the
  // constructor's sorted=false would silently discard the flag.
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

241
242
243
244
245
/*!
 * \brief Copy this CSR into a shared-memory segment called `name`.
 * \return *this if it already lives in shared memory, in which case `name`
 *         must match its segment name. Otherwise returns a shared copy.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (IsSharedMem()) {
    CHECK(name == shared_mem_name_);
    return *this;
  }
  CSR ret(adj_.indptr, adj_.indices, adj_.data, name);
  // The shared-memory constructor sets sorted=false and does not take the
  // multigraph flag, so carry both over explicitly. Copying does not change
  // the element order.
  ret.is_multigraph_ = is_multigraph_;
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

251
252
253
254
/*!
 * \brief Convert the id arrays to `bits`-bit integers.
 * \return *this if already `bits` wide; otherwise a converted copy that
 *         keeps the cached multigraph flag and the sortedness flag.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  CSR ret(aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
  ret.is_multigraph_ = is_multigraph_;
  // A dtype cast does not reorder column indices.
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
/*!
 * \brief Iterator range over the destinations of `vid`'s out-edges.
 *
 * This reads the raw buffers directly. It assumes they hold dgl_id_t values
 * in host-accessible memory, and it performs no bounds check on `vid`.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const auto *row_ptr = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const auto *cols = static_cast<const dgl_id_t*>(adj_.indices->data);
  return DGLIdIters(cols + row_ptr[vid], cols + row_ptr[vid + 1]);
}

/*!
 * \brief Iterator range over the edge ids of `vid`'s out-edges.
 *
 * Same raw-buffer assumptions as SuccVec: dgl_id_t values in host memory,
 * and no bounds check on `vid`.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const auto *offsets = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const auto *eids = static_cast<const dgl_id_t*>(adj_.data->data);
  const dgl_id_t first = offsets[vid];
  const dgl_id_t last = offsets[vid + 1];
  return DGLIdIters(eids + first, eids + last);
}

283
284
285
286
287
288
289
290
291
/*!
 * \brief Deserialize the CSR matrix from `fs`.
 * \return The result of the stream read. The original code always returned
 *         true, even when the read failed.
 */
bool CSR::Load(dmlc::Stream *fs) {
  // Load() is non-const, so &adj_ is already a mutable pointer; no cast needed.
  return fs->Read(&adj_);
}

/*! \brief Serialize the CSR matrix to `fs`. */
void CSR::Save(dmlc::Stream *fs) const {
  fs->Write(adj_);
}

292
293
294
295
296
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
297
/*!
 * \brief Square COO over `num_vertices` vertices. Edge i is src[i] -> dst[i].
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  // Every edge needs both endpoints.
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const aten::COOMatrix mat{num_vertices, num_vertices, src, dst};
  adj_ = mat;
}

/*! \brief Square COO with a caller-supplied multigraph flag. */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  // Every edge needs both endpoints.
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const aten::COOMatrix mat{num_vertices, num_vertices, src, dst};
  adj_ = mat;
}

bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
315
      return aten::COOHasDuplicate(adj_);
316
    });
317
318
}

319
320
/*! \brief Endpoints (src, dst) of edge `eid`, which is its position in the COO. */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const dgl_id_t u = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  const dgl_id_t v = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return {u, v};
}

326
/*! \brief Endpoints of each edge in `eids`, returned with the ids themselves. */
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  IdArray src = aten::IndexSelect(adj_.row, eids);
  IdArray dst = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{src, dst, eids};
}

333
/*!
 * \brief All edges in edge-id order. Only the empty order or "eid" is accepted.
 */
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  // Edge i sits at position i of the row/col arrays, so the ids are 0..E-1.
  IdArray eids = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, eids};
}

341
/*!
 * \brief Subgraph induced by the edges `eids`.
 * \param eids Edges to keep. In the subgraph they become edges 0..len-1, in
 *        the order given.
 * \param preserve_nodes If true, keep all NumVertices() vertices. Otherwise
 *        keep only the endpoints of the selected edges, relabeled by
 *        aten::Relabel_.
 * \return The subgraph with its induced vertex and edge ids.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
  // Both branches need the selected endpoints, so gather them once.
  IdArray new_src = aten::IndexSelect(adj_.row, eids);
  IdArray new_dst = aten::IndexSelect(adj_.col, eids);

  COOPtr subcoo;
  IdArray induced_nodes;
  if (preserve_nodes) {
    induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst, this->IsMultigraph()));
  } else {
    // Relabel_ updates new_src/new_dst in place (they are used directly
    // below) and returns the original ids of the surviving vertices.
    induced_nodes = aten::Relabel_({new_src, new_dst});
    const auto new_nnodes = induced_nodes->shape[0];
    subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst, this->IsMultigraph()));
  }

  Subgraph subg;
  subg.graph = subcoo;
  subg.induced_vertices = induced_nodes;
  subg.induced_edges = eids;
  return subg;
}

364
/*! \brief Convert to a CSR graph. */
CSRPtr COO::ToCSR() const {
  const aten::CSRMatrix mat = aten::COOToCSR(adj_);
  return std::make_shared<CSR>(mat.indptr, mat.indices, mat.data);
}

369
370
371
372
/*!
 * \brief Copy this COO to device context `ctx`. Returns *this if it is
 *        already there. Otherwise returns a copy that keeps the multigraph flag.
 */
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  IdArray row = adj_.row.CopyTo(ctx);
  IdArray col = adj_.col.CopyTo(ctx);
  COO ret(NumVertices(), row, col);
  ret.is_multigraph_ = is_multigraph_;
  return ret;
}

381
382
/*!
 * \brief Not supported: COO graphs cannot be placed in shared memory yet.
 *        Always aborts through LOG(FATAL).
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}

386
387
388
389
/*!
 * \brief Convert the id arrays to `bits`-bit integers. Returns *this if they
 *        are already that width. Otherwise returns a copy that keeps the
 *        multigraph flag.
 */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  IdArray row = aten::AsNumBits(adj_.row, bits);
  IdArray col = aten::AsNumBits(adj_.col, bits);
  COO ret(NumVertices(), row, col);
  ret.is_multigraph_ = is_multigraph_;
  return ret;
}

398
399
400
401
402
403
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

404
/*! \brief Elementwise test of vids[i] < NumVertices(). */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  BoolArray valid = aten::LT(vids, NumVertices());
  return valid;
}

/*!
 * \brief Return the in-CSR, building and caching it on first use. It is
 *        built from the out-CSR if present, otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_) {
    return in_csr_;
  }
  ImmutableGraph *self = const_cast<ImmutableGraph*>(this);
  if (out_csr_) {
    self->in_csr_ = out_csr_->Transpose();
    if (out_csr_->IsSharedMem())
      LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                   << "It may dramatically increase memory consumption.";
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
  }
  return in_csr_;
}

/*!
 * \brief Return the out-CSR, building and caching it on first use. It is
 *        built by transposing the in-CSR if present, otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_) {
    return out_csr_;
  }
  ImmutableGraph *self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->out_csr_ = in_csr_->Transpose();
    if (in_csr_->IsSharedMem())
      LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                   << "It may dramatically increase memory consumption.";
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
  }
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it on first use from the
 *        in-CSR (transposed back) or else the out-CSR.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_) {
    return coo_;
  }
  ImmutableGraph *self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    // The in-CSR is the transposed adjacency, so flip its COO back.
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

453
/*!
 * \brief All edges in the requested order.
 * \param order "" (any order), "srcdst" (via out-CSR) or "eid" (via COO).
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order.empty()) {
    // Any order will do, so reuse an existing structure.
    if (!in_csr_) {
      return AnyGraph()->Edges(order);
    }
    // The in-CSR lists edges as (dst, src); swap the endpoints back.
    const auto& rev = in_csr_->Edges(order);
    return EdgeArray{rev.dst, rev.src, rev.id};
  }
  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == std::string("eid")) {
    return GetCOO()->Edges(order);
  }
  LOG(FATAL) << "Unsupported order request: " << order;
  return {};
}

475
476
477
478
/*! \brief Vertex-induced subgraph, carved from the out-CSR. */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  Subgraph sg = GetOutCSR()->VertexSubgraph(vids);
  // Re-wrap the raw CSR as an ImmutableGraph for callers.
  CSRPtr child = std::dynamic_pointer_cast<CSR>(sg.graph);
  sg.graph = GraphPtr(new ImmutableGraph(child));
  return sg;
}

483
484
/*! \brief Edge-induced subgraph, carved from the COO. */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  Subgraph sg = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  // Re-wrap the raw COO as an ImmutableGraph for callers.
  COOPtr child = std::dynamic_pointer_cast<COO>(sg.graph);
  sg.graph = GraphPtr(new ImmutableGraph(child));
  return sg;
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
491
492
493
494
495
496
497
498
499
500
501
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
502
  } else {
503
504
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
505
506
507
  }
}

508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
/*!
 * \brief Build a graph from CSR arrays. `edge_dir` "in" treats them as the
 *        in-CSR and "out" as the out-CSR.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*! \brief As above, with an explicit multigraph flag. */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build a graph whose CSR is copied into the shared-memory segment
 *        "<shared_mem_name>_<edge_dir>".
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string segment = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, segment));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*! \brief Shared-memory variant with an explicit multigraph flag. */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string segment = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph, segment));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Attach to a CSR that already lives in the shared-memory segment
 *        "<shared_mem_name>_<edge_dir>".
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    const std::string &shared_mem_name, size_t num_vertices,
    size_t num_edges, bool multigraph,
    const std::string &edge_dir) {
  const std::string segment = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(segment, num_vertices, num_edges, multigraph));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*! \brief Build a graph from COO endpoint arrays. */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst) {
  COOPtr edges(new COO(num_vertices, src, dst));
  ImmutableGraphPtr result = std::make_shared<ImmutableGraph>(edges);
  return result;
}

/*! \brief Build a graph from COO endpoint arrays with an explicit multigraph flag. */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
  COOPtr edges(new COO(num_vertices, src, dst, multigraph));
  ImmutableGraphPtr result = std::make_shared<ImmutableGraph>(edges);
  return result;
}

/*!
 * \brief Return `graph` as an ImmutableGraph.
 *
 * An ImmutableGraph is returned as-is, without a copy. Any other graph is
 * rebuilt from the arrays of graph->GetAdj(true, "csr"), which are used as
 * an out-edge CSR.
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) {
    return ig;
  }
  // CreateFromCSR builds (and validates) its own CSR. The original code also
  // built an extra CSR here that was never used.
  const auto& adj = graph->GetAdj(true, "csr");
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

605
606
607
/*!
 * \brief Copy `g` to device context `ctx`. Returns `g` itself if it is
 *        already there. Otherwise copies both the in-CSR and the out-CSR.
 */
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_copy(new CSR(g->GetInCSR()->CopyTo(ctx)));
  CSRPtr out_copy(new CSR(g->GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

618
619
/*!
 * \brief Copy one CSR direction of `g` into shared memory.
 * \param edge_dir "in" or "out": which CSR to share. Any other value is a
 *        fatal error. Previously it silently produced a graph with no CSR
 *        and no COO.
 * \param name Base shared-memory name; the segment is "<name>_<edge_dir>".
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
    const std::string &edge_dir, const std::string &name) {
  CSRPtr new_incsr, new_outcsr;
  const std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in")) {
    new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
  } else if (edge_dir == std::string("out")) {
    new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  } else {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}

629
630
631
/*!
 * \brief Convert `g`'s ids to `bits`-bit integers. Returns `g` unchanged if
 *        it is already that width. Otherwise converts both CSRs.
 */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }
  // TODO(minjie): since we don't have int32 operations,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_conv(new CSR(g->GetInCSR()->AsNumBits(bits)));
  CSRPtr out_conv(new CSR(g->GetOutCSR()->AsNumBits(bits)));
  return ImmutableGraphPtr(new ImmutableGraph(in_conv, out_conv));
}

/*!
 * \brief Graph with every edge reversed. The in- and out-CSR swap roles; a
 *        cached COO is transposed.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_) {
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  }
  COOPtr reversed_coo = coo_->Transpose();
  return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_, reversed_coo));
}

652
653
654
655
656
657
658
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;

/*!
 * \brief Load an ImmutableGraph from a stream written by Save().
 *
 * The stream holds a magic number followed by the out-CSR. Any cached in-CSR
 * or COO is dropped so it cannot disagree with the newly loaded out-CSR;
 * both are rebuilt from it on demand.
 * \return true on success. Failures abort through CHECK.
 */
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum = 0;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
      << "Invalid ImmutableGraph Magic Number";
  CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
  in_csr_ = nullptr;
  coo_ = nullptr;
  return true;
}

/*!
 * \brief Save this ImmutableGraph: the magic number, then the out-CSR
 *        (materialized first if needed).
 */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
  const CSRPtr out_csr = GetOutCSR();
  fs->Write(out_csr);
}

671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
/*!
 * \brief Wrap this graph as a single-relation HeteroGraph. Only the formats
 *        already materialized here are handed over.
 */
HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
  const bool has_in_csr = in_csr_ != nullptr;
  const bool has_out_csr = out_csr_ != nullptr;
  const bool has_coo = coo_ != nullptr;

  aten::CSRMatrix in_csr, out_csr;
  aten::COOMatrix coo;
  if (has_in_csr)
    in_csr = GetInCSR()->ToCSRMatrix();
  if (has_out_csr)
    out_csr = GetOutCSR()->ToCSRMatrix();
  if (has_coo)
    coo = GetCOO()->ToCOOMatrix();

  auto unit = UnitGraph::CreateHomographFrom(
      in_csr, out_csr, coo, has_in_csr, has_out_csr, has_coo);
  return HeteroGraphPtr(new HeteroGraph(unit->meta_graph(), {unit}));
}

// C API: args = (graph). The graph must be an ImmutableGraph.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    auto ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
    CHECK(ig) << "graph is not readonly";
    HeteroGraphPtr hg = ig->AsHeteroGraph();
    *rv = HeteroGraphRef(hg);
  });

698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
// C API: args = (graph, device_type, device_id).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const int device_type = args[1];
    const int device_id = args[2];
    const DLContext ctx{static_cast<DLDeviceType>(device_type), device_id};
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyTo(ig, ctx);
  });

// C API: args = (graph, edge_dir, shared-memory base name).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const std::string edge_dir = args[1];
    const std::string name = args[2];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyToSharedMem(ig, edge_dir, name);
  });

// C API: args = (graph, bits).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const int bits = args[1];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::AsNumBits(ig, bits);
  });

727
}  // namespace dgl