immutable_graph.cc 23.4 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

7
#include <dgl/packed_func_ext.h>
8
#include <dgl/immutable_graph.h>
9
10
11
12
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
13
14
15

#include "../c_api_common.h"

16
17
using namespace dgl::runtime;

18
namespace dgl {
19
namespace {
20
21
22
23
/*!
 * \brief Build the per-direction shared-memory segment name.
 * \param name Base graph name.
 * \param edge_dir Edge direction tag (e.g. "in" / "out").
 * \return `name + "_" + edge_dir`, e.g. ("graph", "in") -> "graph_in".
 */
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string full_name(name);
  full_name += '_';
  full_name += edge_dir;
  return full_name;
}

24
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
25
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
26
#ifndef _WIN32
27
28
29
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
30
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
31
32
33
34
35
36
37
38
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
39
#else
40
41
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
42
43
#endif  // _WIN32
}
44
45
46
47
48
49
50
51
52
53
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*!
 * \brief Allocate a square CSR with room for the given vertices and edges.
 *
 * The array contents are whatever aten::NewIdArray produces. Callers are
 * expected to fill them.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  // Edges cannot exist without at least one vertex.
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray row_offsets = aten::NewIdArray(num_vertices + 1);
  IdArray column_ids = aten::NewIdArray(num_edges);
  IdArray edge_id_arr = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices,
                         row_offsets, column_ids, edge_id_arr};
}

61
/*!
 * \brief Wrap existing CSR arrays (no copy). The multigraph flag is left to be
 *        computed lazily by IsMultigraph().
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  // Square matrix: the vertex count is implied by the offset array length.
  const int64_t num_rows = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_rows, num_rows, indptr, indices, edge_ids};
}

/*!
 * \brief Wrap existing CSR arrays (no copy) with a caller-supplied multigraph flag.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  // Square matrix: the vertex count is implied by the offset array length.
  const int64_t num_rows = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_rows, num_rows, indptr, indices, edge_ids};
}

/*!
 * \brief Build a CSR backed by a newly created shared-memory segment and copy
 *        the given arrays into it.
 * \param shared_mem_name Name of the segment to create.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t n_rows = indptr->shape[0] - 1;
  const int64_t n_edges = indices->shape[0];

  adj_.num_rows = n_rows;
  adj_.num_cols = n_rows;
  // Create (is_create = true) the segment and bind adj_'s arrays to views of it.
  const bool create_segment = true;
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, n_rows, n_edges, create_segment);

  // Populate the shared arrays from the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
}

/*!
 * \brief Same as the shared-memory constructor above, but with an explicit
 *        multigraph flag.
 * \param shared_mem_name Name of the segment to create.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
         const std::string &shared_mem_name): is_multigraph_(is_multigraph),
         shared_mem_name_(shared_mem_name) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t n_rows = indptr->shape[0] - 1;
  const int64_t n_edges = indices->shape[0];

  adj_.num_rows = n_rows;
  adj_.num_cols = n_rows;
  // Create (is_create = true) the segment and bind adj_'s arrays to views of it.
  const bool create_segment = true;
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, n_rows, n_edges, create_segment);

  // Populate the shared arrays from the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
}

/*!
 * \brief Build a CSR over an already-populated shared-memory segment.
 *
 * MapFromSharedMemory is called with is_create = false and no data is copied.
 * Presumably the segment was created elsewhere (e.g. by another process).
 * TODO confirm.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  const bool create_segment = false;
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, create_segment);
}

bool CSR::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
130
      return aten::CSRHasDuplicate(adj_);
131
132
    });
}
133

134
/*!
 * \brief All edges leaving `vid`.
 * \return EdgeArray{src, dst, id}, where src is `vid` repeated.
 */
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  // Row `vid` stores the destinations (column indices) and edge ids (data).
  IdArray dsts = aten::CSRGetRowColumnIndices(adj_, vid);
  IdArray eids = aten::CSRGetRowData(adj_, vid);
  // All of these edges share the same source vertex.
  const int64_t out_degree = dsts->shape[0];
  IdArray srcs = aten::Full(vid, out_degree, NumBits(), dsts->ctx);
  return EdgeArray{srcs, dsts, eids};
}

142
/*!
 * \brief All edges leaving any vertex in `vids`.
 * \return EdgeArray{src, dst, id} with src expressed as original vertex ids.
 */
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  // Keep only the requested rows, then expand them to coordinate form.
  auto sliced_rows = aten::CSRSliceRows(adj_, vids);
  auto sliced_coo = aten::CSRToCOO(sliced_rows, false);
  // The slice renumbers its rows 0..len(vids)-1, so map each local row back
  // to its vertex id with vids[row].
  IdArray srcs = aten::IndexSelect(vids, sliced_coo.row);
  return EdgeArray{srcs, sliced_coo.col, sliced_coo.data};
}

152
/*!
 * \brief Out-degree of each vertex in `vids`, i.e. the non-zero count of its row.
 */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  DegreeArray row_nnz = aten::CSRGetRowNNZ(adj_, vids);
  return row_nnz;
}

157
158
159
/*!
 * \brief Whether at least one edge src -> dst exists.
 */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  // The edge exists iff adjacency entry (src, dst) is stored.
  const bool found = aten::CSRIsNonZero(adj_, src, dst);
  return found;
}

/*!
 * \brief Batched HasEdgeBetween: one boolean per (src_ids[i], dst_ids[i]) pair.
 */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  BoolArray present = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return present;
}

169
/*!
 * \brief Direct successors of `vid`. Only radius 1 is supported.
 */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  // The column indices of row `vid` are exactly its out-neighbors.
  IdArray neighbors = aten::CSRGetRowColumnIndices(adj_, vid);
  return neighbors;
}

175
176
177
/*!
 * \brief Edge id(s) of src -> dst.
 *
 * The result is an array, presumably so that parallel edges in a multigraph
 * can all be reported.
 */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  IdArray matching_eids = aten::CSRGetData(adj_, src, dst);
  return matching_eids;
}

181
/*!
 * \brief Batched edge-id lookup for (src_ids[i], dst_ids[i]) pairs.
 * \return EdgeArray{src, dst, id} built, in order, from the three arrays
 *         returned by aten::CSRGetDataAndIndices.
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  // Validate inputs the same way as the other batched queries (HasEdgesBetween).
  CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
185

186
/*!
 * \brief All edges, expanded row by row.
 * \param order Must be empty or "srcdst". Row-wise expansion orders edges by source.
 */
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\","
    << " but got \"" << order << "\".";
  auto expanded = aten::CSRToCOO(adj_, false);
  return EdgeArray{expanded.row, expanded.col, expanded.data};
}

194
195
/*!
 * \brief Subgraph induced by `vids`.
 *
 * Inside the subgraph, edges are renumbered 0..k-1. induced_edges keeps the
 * parent graph's id of each of them.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  // Keep rows and columns in vids. submat.data still holds parent edge ids.
  const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
  const int64_t num_sub_edges = submat.data->shape[0];
  IdArray local_eids = aten::Range(0, num_sub_edges, NumBits(), Context());

  Subgraph result;
  result.graph = CSRPtr(new CSR(submat.indptr, submat.indices, local_eids));
  result.induced_vertices = vids;
  result.induced_edges = submat.data;
  return result;
}

/*!
 * \brief New CSR for the transposed adjacency (every edge reversed).
 */
CSRPtr CSR::Transpose() const {
  auto flipped = aten::CSRTranspose(adj_);
  return CSRPtr(new CSR(flipped.indptr, flipped.indices, flipped.data));
}

/*!
 * \brief Convert to COO, keeping only the row/col arrays.
 *
 * COO treats an entry's position as its edge id (see COO::Edges), so the
 * entries have to come out in edge-id order. The `true` flag passed to
 * CSRToCOO is presumably what requests that ordering. Confirm against
 * aten::CSRToCOO.
 */
COOPtr CSR::ToCOO() const {
  auto coord = aten::CSRToCOO(adj_, true);
  return COOPtr(new COO(NumVertices(), coord.row, coord.col));
}

216
217
218
219
/*!
 * \brief Copy this CSR to another device context.
 *
 * Returns a copy of *this unchanged if it already lives on `ctx`. The
 * multigraph flag (including any cached value) is carried over.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx)
    return *this;
  CSR moved(adj_.indptr.CopyTo(ctx),
            adj_.indices.CopyTo(ctx),
            adj_.data.CopyTo(ctx));
  moved.is_multigraph_ = is_multigraph_;
  return moved;
}

228
229
230
231
232
/*!
 * \brief Copy this CSR into a shared-memory segment called `name`.
 *
 * If the CSR is already in shared memory, `name` must match its current
 * segment name and *this is returned as is.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (!IsSharedMem()) {
    // Creates the segment and copies all three arrays into it.
    return CSR(adj_.indptr, adj_.indices, adj_.data, name);
  }
  CHECK(name == shared_mem_name_);
  return *this;
}

237
238
239
240
/*!
 * \brief Return a CSR whose id arrays use `bits`-wide integers.
 *
 * This is a no-op if the width already matches. The multigraph flag is carried over.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits)
    return *this;
  CSR converted(aten::AsNumBits(adj_.indptr, bits),
                aten::AsNumBits(adj_.indices, bits),
                aten::AsNumBits(adj_.data, bits));
  converted.is_multigraph_ = is_multigraph_;
  return converted;
}

249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
/*!
 * \brief Iterator range over the successors (column indices) of `vid`.
 *
 * This reads the raw buffers directly, with no bounds check on `vid`.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_offsets = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* columns = static_cast<const dgl_id_t*>(adj_.indices->data);
  return DGLIdIters(columns + row_offsets[vid], columns + row_offsets[vid + 1]);
}

/*!
 * \brief Iterator range over the ids of edges leaving `vid`.
 *
 * This reads the raw buffers directly, with no bounds check on `vid`.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_offsets = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* edge_ids = static_cast<const dgl_id_t*>(adj_.data->data);
  return DGLIdIters(edge_ids + row_offsets[vid], edge_ids + row_offsets[vid + 1]);
}

269
270
271
272
273
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
274
/*!
 * \brief Wrap endpoint arrays as a square COO. Edge i is (src[i], dst[i]).
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(IsValidIdArray(src));
  CHECK(IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t n = num_vertices;
  adj_ = aten::COOMatrix{n, n, src, dst};
}

/*!
 * \brief Wrap endpoint arrays as a square COO with an explicit multigraph flag.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(IsValidIdArray(src));
  CHECK(IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t n = num_vertices;
  adj_ = aten::COOMatrix{n, n, src, dst};
}

bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
292
      return aten::COOHasDuplicate(adj_);
293
    });
294
295
}

296
297
298
299
300
301
302
/*!
 * \brief Endpoints (src, dst) of edge `eid`. In a COO, edge i lives at position i.
 */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const auto source = aten::IndexSelect(adj_.row, eid);
  const auto target = aten::IndexSelect(adj_.col, eid);
  std::pair<dgl_id_t, dgl_id_t> endpoints(source, target);
  return endpoints;
}

303
/*!
 * \brief Endpoints of each edge in `eids`, returned as EdgeArray{src, dst, eids}.
 */
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
  IdArray sources = aten::IndexSelect(adj_.row, eids);
  IdArray targets = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{sources, targets, eids};
}

310
/*!
 * \brief All edges in edge-id order.
 * \param order Must be empty or "eid".
 */
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  // Position i of row/col is edge i, so the ids are simply 0..E-1.
  IdArray all_eids = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, all_eids};
}

318
319
/*!
 * \brief Subgraph formed by the edges in `eids`.
 *
 * Subgraph edge i corresponds to parent edge eids[i] (recorded in
 * induced_edges).
 *
 * \param preserve_nodes If true, keep all parent vertices with unchanged ids.
 *        Otherwise keep only the endpoints of the selected edges, relabeled
 *        compactly.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  // Both branches need the endpoints of the selected edges. Gather them once.
  IdArray new_src = aten::IndexSelect(adj_.row, eids);
  IdArray new_dst = aten::IndexSelect(adj_.col, eids);

  COOPtr subcoo;
  IdArray induced_nodes;
  if (!preserve_nodes) {
    // Relabel_ (trailing underscore: presumably in-place) is expected to
    // rewrite new_src/new_dst to compact ids. It returns the parent ids of the
    // kept vertices, and their count sizes the subgraph.
    induced_nodes = aten::Relabel_({new_src, new_dst});
    const auto new_nnodes = induced_nodes->shape[0];
    subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
  } else {
    // Identity vertex mapping over the whole parent vertex set.
    induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));
  }

  Subgraph subg;
  subg.graph = subcoo;
  subg.induced_vertices = induced_nodes;
  subg.induced_edges = eids;
  return subg;
}

341
/*!
 * \brief Convert to an out-edge CSR (rows = sources).
 */
CSRPtr COO::ToCSR() const {
  auto compressed = aten::COOToCSR(adj_);
  return CSRPtr(new CSR(compressed.indptr, compressed.indices, compressed.data));
}

346
347
348
349
/*!
 * \brief Copy this COO to another device context.
 *
 * Returns a copy of *this unchanged if it already lives on `ctx`. The
 * multigraph flag is carried over.
 */
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx)
    return *this;
  COO moved(NumVertices(),
            adj_.row.CopyTo(ctx),
            adj_.col.CopyTo(ctx));
  moved.is_multigraph_ = is_multigraph_;
  return moved;
}

358
359
/*!
 * \brief Not implemented: shared-memory storage exists only for CSR.
 *
 * The call always fails via LOG(FATAL). The trailing return only satisfies
 * the compiler. `name` is unused.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}

363
364
365
366
/*!
 * \brief Return a COO whose id arrays use `bits`-wide integers.
 *
 * This is a no-op if the width already matches. The multigraph flag is carried over.
 */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits)
    return *this;
  COO converted(NumVertices(),
                aten::AsNumBits(adj_.row, bits),
                aten::AsNumBits(adj_.col, bits));
  converted.is_multigraph_ = is_multigraph_;
  return converted;
}

375
376
377
378
379
380
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
/*!
 * \brief Element-wise test vids[i] < NumVertices().
 */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid id array input";
  const auto num_nodes = NumVertices();
  return aten::LT(vids, num_nodes);
}

/*!
 * \brief Return the in-edge CSR, building and caching it on first use.
 *
 * It is built by transposing the out-CSR if present, otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_)
    return in_csr_;
  // The cache is filled through const_cast because this accessor is logically const.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (out_csr_) {
    self->in_csr_ = out_csr_->Transpose();
    if (out_csr_->IsSharedMem())
      LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                   << "It may dramatically increase memory consumption.";
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
  }
  return in_csr_;
}

/*!
 * \brief Return the out-edge CSR, building and caching it on first use.
 *
 * It is built by transposing the in-CSR if present, otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_)
    return out_csr_;
  // The cache is filled through const_cast because this accessor is logically const.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->out_csr_ = in_csr_->Transpose();
    if (in_csr_->IsSharedMem())
      LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                   << "It may dramatically increase memory consumption.";
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
  }
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it on first use.
 *
 * The in-CSR is preferred as the source (its COO is transposed back to
 * src->dst). The out-CSR is used otherwise.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_)
    return coo_;
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

430
/*!
 * \brief All edges in the requested order.
 * \param order "" (any order), "srcdst" (via out-CSR), or "eid" (via COO).
 *        Anything else is fatal.
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order.empty()) {
    // Any order will do. An existing in-CSR reports edges as (dst, src), so swap
    // them back. Otherwise defer to whatever AnyGraph() returns.
    if (!in_csr_)
      return AnyGraph()->Edges(order);
    const auto& reversed = in_csr_->Edges(order);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }
  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == std::string("eid"))
    return GetCOO()->Edges(order);
  LOG(FATAL) << "Unsupported order request: " << order;
  return {};
}

452
453
454
455
/*!
 * \brief Vertex-induced subgraph, wrapped as an ImmutableGraph.
 */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  // We prefer to generate a subgraph from out-csr.
  Subgraph result = GetOutCSR()->VertexSubgraph(vids);
  CSRPtr sub_csr = std::dynamic_pointer_cast<CSR>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(sub_csr));
  return result;
}

460
461
/*!
 * \brief Edge-induced subgraph (built from the COO), wrapped as an ImmutableGraph.
 */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  Subgraph result = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  COOPtr sub_coo = std::dynamic_pointer_cast<COO>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(sub_coo));
  return result;
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
468
469
470
471
472
473
474
475
476
477
478
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
479
  } else {
480
481
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
482
483
484
  }
}

485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
/*!
 * \brief Build an ImmutableGraph from CSR arrays.
 * \param edge_dir "in" stores the CSR as the in-CSR, "out" as the out-CSR.
 *        Anything else is fatal.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph from CSR arrays with an explicit multigraph flag.
 * \param edge_dir "in" or "out". Anything else is fatal.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph whose CSR is copied into shared memory.
 *
 * The CSR segment is named GetSharedMemName(shared_mem_name, edge_dir). The
 * graph itself is tagged with the base shared_mem_name.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string segment = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, segment));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Shared-memory CreateFromCSR with an explicit multigraph flag.
 *
 * The CSR segment is named GetSharedMemName(shared_mem_name, edge_dir).
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string segment = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph, segment));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Open a graph whose CSR already lives in the shared-memory segment
 *        GetSharedMemName(shared_mem_name, edge_dir). No data is copied.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    const std::string &shared_mem_name, size_t num_vertices,
    size_t num_edges, bool multigraph,
    const std::string &edge_dir) {
  const std::string segment = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(segment, num_vertices, num_edges, multigraph));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph from edge endpoint arrays (COO form).
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst) {
  COOPtr edges(new COO(num_vertices, src, dst));
  ImmutableGraphPtr graph = std::make_shared<ImmutableGraph>(edges);
  return graph;
}

/*!
 * \brief Build an ImmutableGraph from COO arrays with an explicit multigraph flag.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
  COOPtr edges(new COO(num_vertices, src, dst, multigraph));
  ImmutableGraphPtr graph = std::make_shared<ImmutableGraph>(edges);
  return graph;
}

/*!
 * \brief Return `graph` as an ImmutableGraph.
 *
 * The input is returned as is if it already is one. Otherwise an out-CSR
 * graph is built from its adjacency.
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) {
    return ig;
  }
  // adj holds the {indptr, indices, edge_ids} triple. CreateFromCSR builds and
  // validates the CSR itself; the previous standalone CSR built here was never used.
  const auto& adj = graph->GetAdj(true, "csr");
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

582
583
584
/*!
 * \brief Copy a graph to another device context.
 *
 * Both CSRs are materialized on the source device first, then each is copied.
 * The input graph is returned unchanged if it is already on `ctx`.
 */
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context())
    return g;
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_copy(new CSR(g->GetInCSR()->CopyTo(ctx)));
  CSRPtr out_copy(new CSR(g->GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

595
596
/*!
 * \brief Copy one direction's CSR of `g` into shared memory.
 *
 * \param edge_dir "in" or "out". Selects which CSR is copied. Anything else is
 *        fatal, matching CreateFromCSR. Previously an unknown direction
 *        silently produced a graph with no CSR and no COO, which only failed
 *        later inside GetInCSR/GetOutCSR.
 * \param name Base name. The segment is GetSharedMemName(name, edge_dir).
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
    const std::string &edge_dir, const std::string &name) {
  CSRPtr new_incsr, new_outcsr;
  const std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in")) {
    new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
  } else if (edge_dir == std::string("out")) {
    new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  } else {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}

606
607
608
/*!
 * \brief Return a graph whose id arrays use `bits`-wide integers.
 *
 * The input graph is returned unchanged if the width already matches.
 */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits)
    return g;
  // TODO(minjie): since we don't have int32 operations, both CSRs are
  //   materialized (transposing/converting as needed) and each is converted
  //   to the requested width. Any COO of `g` is not carried over. This
  //   should be fixed later.
  CSRPtr in_conv(new CSR(g->GetInCSR()->AsNumBits(bits)));
  CSRPtr out_conv(new CSR(g->GetOutCSR()->AsNumBits(bits)));
  return ImmutableGraphPtr(new ImmutableGraph(in_conv, out_conv));
}

/*!
 * \brief Graph with every edge reversed.
 *
 * The in- and out-CSRs trade places. A cached COO, if any, is transposed.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_)
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
}

629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
// C API: args = (graph, device_type, device_id). Returns the graph copied to that device.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef graph_ref = args[0];
    const int dev_type = args[1];
    const int dev_id = args[2];
    DLContext target{static_cast<DLDeviceType>(dev_type), dev_id};
    ImmutableGraphPtr immutable =
        CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(graph_ref.sptr()));
    *rv = ImmutableGraph::CopyTo(immutable, target);
  });

// C API: args = (graph, edge_dir, name). Returns the graph with that direction's
// CSR copied into shared memory.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef graph_ref = args[0];
    std::string direction = args[1];
    std::string base_name = args[2];
    ImmutableGraphPtr immutable =
        CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(graph_ref.sptr()));
    *rv = ImmutableGraph::CopyToSharedMem(immutable, direction, base_name);
  });

// C API: args = (graph, bits). Returns the graph with ids converted to `bits`-wide integers.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef graph_ref = args[0];
    int width = args[1];
    ImmutableGraphPtr immutable =
        CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(graph_ref.sptr()));
    *rv = ImmutableGraph::AsNumBits(immutable, width);
  });

658
}  // namespace dgl