/*!
 *  Copyright (c) 2018 by Contributors
 * @file graph/immutable_graph.cc
 * @brief DGL immutable graph index implementation
 */

7
#include <dgl/base_heterograph.h>
8
#include <dgl/immutable_graph.h>
9
10
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
11
12
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
13
#include <string.h>
14

15
16
17
#include <bitset>
#include <numeric>
#include <tuple>
18
19

#include "../c_api_common.h"
20
21
#include "heterograph.h"
#include "unit_graph.h"
22

23
24
using namespace dgl::runtime;

25
namespace dgl {
26
namespace {
27
28
/*!
 * @brief Build the name of one shared-memory segment of a graph index.
 * @param name Base name of the shared-memory graph.
 * @param edge_dir Suffix naming the part (this file uses "in", "out", "meta").
 * @return `name` and `edge_dir` joined by a single underscore.
 */
inline std::string GetSharedMemName(
    const std::string &name, const std::string &edge_dir) {
  std::string joined;
  joined.reserve(name.size() + 1 + edge_dir.size());
  joined.append(name).append(1, '_').append(edge_dir);
  return joined;
}

32
33
34
35
36
37
38
39
40
41
42
43
/*
 * The metadata of a graph index that is needed to rebuild a shared-memory
 * graph in another process.
 *
 * SerializeMetadata / DeserializeMetadata copy this struct byte-for-byte
 * (memcpy) into and out of a shared-memory tensor, so the writer and the
 * reader must agree on its exact layout. Do not reorder or retype fields.
 */
struct GraphIndexMetadata {
  int64_t num_nodes;  // number of vertices in the graph
  int64_t num_edges;  // number of edges in the graph
  bool has_in_csr;    // whether a "<name>_in" CSR segment exists
  bool has_out_csr;   // whether a "<name>_out" CSR segment exists
  bool has_coo;       // COO presence; SerializeMetadata always writes false
};

/*
 * Serialize the metadata of a graph index and place it in a shared-memory
 * tensor. In this way, another process can reconstruct a GraphIndex from a
 * shared-memory tensor.
 *
 * @param gidx The immutable graph whose metadata is exported.
 * @param name Name of the shared-memory segment to create (is_create = true).
 * @return An int8 NDArray of sizeof(GraphIndexMetadata) bytes backed by the
 *         new segment. NOTE(review): presumably the segment lives only as
 *         long as a reference to this array is held; the caller should keep it.
 */
NDArray SerializeMetadata(
    const ImmutableGraphPtr &gidx, const std::string &name) {
#ifndef _WIN32
  GraphIndexMetadata meta;
  // Zero the whole object, padding bytes included. The struct is copied
  // byte-for-byte into a segment that other processes can read, so no
  // indeterminate stack bytes should leak into it.
  memset(&meta, 0, sizeof(meta));
  meta.num_nodes = gidx->NumVertices();
  meta.num_edges = gidx->NumEdges();
  meta.has_in_csr = gidx->HasInCSR();
  meta.has_out_csr = gidx->HasOutCSR();
  meta.has_coo = false;

  NDArray meta_arr = NDArray::EmptyShared(
      name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},
      true);
  memcpy(meta_arr->data, &meta, sizeof(meta));
  return meta_arr;
#else
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return NDArray();
#endif  // _WIN32
}

/*
 * Read back a GraphIndexMetadata that SerializeMetadata stored in the
 * shared-memory segment `name`.
 *
 * @param name Name of an existing metadata segment. It is attached to with
 *        is_create = false.
 * @return The metadata, copied out of shared memory.
 */
GraphIndexMetadata DeserializeMetadata(const std::string &name) {
  GraphIndexMetadata meta;
#ifdef _WIN32
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
#else
  // Attach to the existing segment and copy its raw bytes into `meta`.
  NDArray shared_bytes = NDArray::EmptyShared(
      name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},
      false);
  memcpy(&meta, shared_bytes->data, sizeof(meta));
#endif  // _WIN32
  return meta;
}

84
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
85
86
    const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges,
    bool is_create) {
87
#ifndef _WIN32
88
89
90
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
91
92
      shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1},
      DGLContext{kDGLCPU, 0}, is_create);
93
94
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
95
96
97
98
  IdArray indptr =
      sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1});
  IdArray indices = sm_array.CreateView(
      {num_edges}, DGLDataType{kDGLInt, 64, 1},
99
      (num_verts + 1) * sizeof(dgl_id_t));
100
101
  IdArray edge_ids = sm_array.CreateView(
      {num_edges}, DGLDataType{kDGLInt, 64, 1},
102
103
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
104
#else
105
106
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
107
108
#endif  // _WIN32
}
109
110
111
112
113
114
115
116
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

117
/*!
 * @brief Allocate a square CSR with freshly created arrays.
 *
 * The arrays come from aten::NewIdArray: indptr has num_vertices + 1
 * entries, and indices and edge ids have num_edges entries each. Their
 * contents are whatever NewIdArray provides. The CSR is marked unsorted.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges) {
  // A graph without vertices cannot carry any edges.
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray row_ptr = aten::NewIdArray(num_vertices + 1);
  IdArray col_idx = aten::NewIdArray(num_edges);
  IdArray eid_arr = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, row_ptr, col_idx, eid_arr};
  adj_.sorted = false;
}

125
/*!
 * @brief Wrap existing CSR arrays as a square adjacency matrix. Nothing is
 *        copied.
 * @param indptr Row offsets. Must have at least one element; the vertex
 *        count is indptr length - 1.
 * @param indices Column (destination) index of each stored edge.
 * @param edge_ids Edge id of each stored edge. Same length as `indices`.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // indptr needs at least its leading offset. Without it N would be -1 and
  // the matrix would have a negative size.
  CHECK_GE(indptr->shape[0], 1) << "indptr must have at least one element";
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t N = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
  adj_.sorted = false;
}

135
136
137
138
/*!
 * @brief Create a new shared-memory segment and fill it with a copy of the
 *        given CSR arrays.
 * @param indptr Row offsets. Must have at least one element.
 * @param indices Column index of each stored edge.
 * @param edge_ids Edge id of each stored edge. Same length as `indices`.
 * @param shared_mem_name Name of the segment to create.
 */
CSR::CSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &shared_mem_name)
    : shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // Without the leading offset num_verts would be -1, and the shared-memory
  // size and offset arithmetic below would be meaningless.
  CHECK_GE(indptr->shape[0], 1) << "indptr must have at least one element";
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_verts = indptr->shape[0] - 1;
  const int64_t num_edges = indices->shape[0];
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  // Create the segment (is_create = true) and get views onto its three parts.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, true);
  // Copy the given data into the shared memory arrays.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

156
157
158
/*!
 * @brief Attach to an existing shared-memory CSR segment (is_create = false).
 * @param shared_mem_name Name of the segment to attach to.
 * @param num_verts Number of vertices stored in the segment.
 * @param num_edges Number of edges stored in the segment.
 */
CSR::CSR(
    const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges)
    : shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
  adj_.num_rows = adj_.num_cols = num_verts;
  adj_.sorted = false;
}

167
/*! @brief Whether the graph has parallel edges, as reported by
 *         aten::CSRHasDuplicate on the adjacency. */
bool CSR::IsMultigraph() const {
  return aten::CSRHasDuplicate(adj_);
}
168

169
/*!
 * @brief All out-edges of a single vertex.
 * @return (src, dst, eid) arrays. src is filled with `vid`.
 */
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  IdArray dsts = aten::CSRGetRowColumnIndices(adj_, vid);
  IdArray eids = aten::CSRGetRowData(adj_, vid);
  const int64_t degree = dsts->shape[0];
  // Every edge in this row starts at `vid`.
  IdArray srcs = aten::Full(vid, degree, NumBits(), dsts->ctx);
  return EdgeArray{srcs, dsts, eids};
}

177
/*!
 * @brief All out-edges of a batch of vertices.
 * @return (src, dst, eid) arrays in the order produced by slicing `vids`.
 */
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  auto sliced = aten::CSRToCOO(aten::CSRSliceRows(adj_, vids), false);
  // Rows of the sliced matrix are renumbered 0..len(vids)-1. Map them back
  // to the caller's vertex ids.
  IdArray srcs = aten::IndexSelect(vids, sliced.row);
  return EdgeArray{srcs, sliced.col, sliced.data};
}

187
/*! @brief Out-degree of each vertex in `vids`, i.e. the number of stored
 *         entries in its row. */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

192
193
194
/*! @brief Whether at least one edge src -> dst exists. Both endpoints must
 *         be valid vertices. */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  const bool found = aten::CSRIsNonZero(adj_, src, dst);
  return found;
}

/*! @brief Element-wise existence test for edges src_ids[i] -> dst_ids[i]. */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  BoolArray mask = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return mask;
}

204
/*!
 * @brief Direct successors of `vid`, i.e. the column indices of its row.
 * @param radius Only 1 is supported. Any other value fails a CHECK.
 */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  IdArray neighbors = aten::CSRGetRowColumnIndices(adj_, vid);
  return neighbors;
}

210
211
212
/*! @brief Every edge id stored at (src, dst). A multigraph may hold more
 *         than one. */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  IdArray ids = aten::CSRGetAllData(adj_, src, dst);
  return ids;
}

216
/*!
 * @brief Look up the edges between src_ids[i] and dst_ids[i].
 *
 * The id arrays are validated like in the other batch queries
 * (HasEdgesBetween, OutEdges, ...) before they reach the aten kernel.
 *
 * @return EdgeArray built from the three arrays of CSRGetDataAndIndices, in
 *         order (src, dst, eid).
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const auto &arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
220

221
/*!
 * @brief All edges in row-major (source-major) order.
 * @param order Must be empty or "srcdst". Anything else fails a CHECK.
 */
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
      << "CSR only support Edges of order \"srcdst\","
      << " but got \"" << order << "\".";
  // Expand the compressed rows: row = source, col = destination, data = id.
  auto expanded = aten::CSRToCOO(adj_, false);
  return EdgeArray{expanded.row, expanded.col, expanded.data};
}

229
/*!
 * @brief Subgraph induced by `vids`.
 *
 * Edges of the subgraph get fresh ids 0..nnz-1. The ids they had in this
 * graph are reported in `induced_edges`. The sorted flag is inherited.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  auto sliced = aten::CSRSliceMatrix(adj_, vids, vids);
  const int64_t num_sub_edges = sliced.data->shape[0];
  IdArray fresh_eids = aten::Range(0, num_sub_edges, NumBits(), Context());
  CSRPtr sub_csr(new CSR(sliced.indptr, sliced.indices, fresh_eids));
  sub_csr->adj_.sorted = adj_.sorted;

  Subgraph result;
  result.induced_vertices = vids;
  result.induced_edges = sliced.data;
  result.graph = sub_csr;
  return result;
}

/*! @brief Transposed copy of this CSR. The edge ids move with their entries
 *         via the data array. */
CSRPtr CSR::Transpose() const {
  auto flipped = aten::CSRTranspose(adj_);
  CSRPtr result(new CSR(flipped.indptr, flipped.indices, flipped.data));
  return result;
}

/*!
 * @brief Convert to a COO that keeps only the (row, col) coordinates.
 *
 * NOTE(review): the CSR's data (edge id) array is not passed to the COO.
 * The resulting COO therefore has no explicit id array.
 */
COOPtr CSR::ToCOO() const {
  auto expanded = aten::CSRToCOO(adj_, true);
  COOPtr result(new COO(NumVertices(), expanded.row, expanded.col));
  return result;
}

253
/*!
 * @brief Copy this CSR to device context `ctx`.
 * @return *this when already on `ctx`. Otherwise a new CSR built from copies
 *         of the three arrays, which the array constructor marks unsorted.
 */
CSR CSR::CopyTo(const DGLContext &ctx) const {
  if (Context() == ctx) return *this;
  IdArray indptr_copy = adj_.indptr.CopyTo(ctx);
  IdArray indices_copy = adj_.indices.CopyTo(ctx);
  IdArray data_copy = adj_.data.CopyTo(ctx);
  return CSR(indptr_copy, indices_copy, data_copy);
}

264
265
266
267
268
/*!
 * @brief Copy this CSR into a new shared-memory segment named `name`.
 *
 * If this CSR already lives in shared memory, `name` must equal its existing
 * segment name, and *this is returned unchanged.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (!IsSharedMem()) {
    // TODO(zhengda) we need to set sorted_ properly.
    return CSR(adj_.indptr, adj_.indices, adj_.data, name);
  }
  CHECK(name == shared_mem_name_);
  return *this;
}

274
275
276
277
/*! @brief This CSR with every array converted to `bits`-wide integers.
 *         Returns *this when the width already matches. */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) return *this;
  return CSR(
      aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indices, bits),
      aten::AsNumBits(adj_.data, bits));
}

285
286
287
/*!
 * @brief Iterator range over the successors (column indices) of `vid`.
 *
 * The arrays are read directly through raw host pointers, and `vid` is not
 * bounds-checked here.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const auto *offsets = static_cast<const dgl_id_t *>(adj_.indptr->data);
  const auto *columns = static_cast<const dgl_id_t *>(adj_.indices->data);
  const dgl_id_t *first = columns + offsets[vid];
  const dgl_id_t *last = columns + offsets[vid + 1];
  return DGLIdIters(first, last);
}

/*!
 * @brief Iterator range over the ids of the out-edges of `vid`.
 *
 * The arrays are read directly through raw host pointers, and `vid` is not
 * bounds-checked here.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const auto *offsets = static_cast<const dgl_id_t *>(adj_.indptr->data);
  const auto *edge_ids = static_cast<const dgl_id_t *>(adj_.data->data);
  const dgl_id_t *first = edge_ids + offsets[vid];
  const dgl_id_t *last = edge_ids + offsets[vid + 1];
  return DGLIdIters(first, last);
}

305
/*!
 * @brief Deserialize the adjacency matrix from `fs`.
 * @return The result of the stream read: false if the matrix could not be
 *         read. The previous version returned true unconditionally and so
 *         hid read failures from the caller.
 */
bool CSR::Load(dmlc::Stream *fs) {
  // Load() is non-const, so &adj_ is already a mutable pointer; the former
  // const_cast was unnecessary.
  return fs->Read(&adj_);
}

310
/*! @brief Serialize the adjacency matrix to `fs`. */
void CSR::Save(dmlc::Stream *fs) const {
  fs->Write(adj_);
}
311

312
313
314
315
316
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
317
318
319
/*!
 * @brief Square COO adjacency from parallel source/destination arrays.
 *
 * No explicit edge-id array is stored (data = NullArray). Elsewhere in this
 * class the edge id of an entry is its position in `src`/`dst`.
 */
COO::COO(
    int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,
    bool col_sorted) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  IdArray no_ids = aten::NullArray();
  adj_ = aten::COOMatrix{num_vertices, num_vertices, src,       dst,
                         no_ids,       row_sorted,   col_sorted};
}

327
/*! @brief Whether the graph has parallel edges, as reported by
 *         aten::COOHasDuplicate. */
bool COO::IsMultigraph() const {
  return aten::COOHasDuplicate(adj_);
}
328

329
330
/*! @brief (src, dst) of edge `eid`. The edge id is used directly as a
 *         position in the row/col arrays. */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const dgl_id_t head = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  const dgl_id_t tail = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return {head, tail};
}

336
/*!
 * @brief Endpoints of a batch of edges.
 *
 * This requires that the COO has no explicit id array, so that edge ids are
 * positions in row/col.
 */
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  BUG_IF_FAIL(aten::IsNullArray(adj_.data))
      << "FindEdges requires the internal COO matrix not having EIDs.";
  IdArray heads = aten::IndexSelect(adj_.row, eids);
  IdArray tails = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{heads, tails, eids};
}

345
/*!
 * @brief All edges in id order, with ids 0..NumEdges()-1.
 * @param order Must be empty or "eid". Anything else fails a CHECK.
 */
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \"" << order
      << "\".";
  IdArray ids = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, ids};
}

353
/*!
 * @brief Subgraph formed by the edges `eids`.
 * @param preserve_nodes If true, keep all vertices with their original ids.
 *        If false, relabel the endpoints compactly (aten::Relabel_, which
 *        modifies the selected src/dst arrays in place) and keep only the
 *        touched vertices.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
  // Both branches start from the same selected endpoints.
  IdArray sub_src = aten::IndexSelect(adj_.row, eids);
  IdArray sub_dst = aten::IndexSelect(adj_.col, eids);

  IdArray kept_nodes;
  int64_t sub_num_nodes;
  if (preserve_nodes) {
    kept_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    sub_num_nodes = NumVertices();
  } else {
    kept_nodes = aten::Relabel_({sub_src, sub_dst});
    sub_num_nodes = kept_nodes->shape[0];
  }

  Subgraph result;
  result.graph = COOPtr(new COO(sub_num_nodes, sub_src, sub_dst));
  result.induced_vertices = kept_nodes;
  result.induced_edges = eids;
  return result;
}

376
/*! @brief Compress this COO into a CSR (indptr, indices, data) via
 *         aten::COOToCSR. */
CSRPtr COO::ToCSR() const {
  auto compressed = aten::COOToCSR(adj_);
  CSRPtr result(
      new CSR(compressed.indptr, compressed.indices, compressed.data));
  return result;
}

381
/*! @brief Copy this COO to device context `ctx`. Returns *this when it is
 *         already there. */
COO COO::CopyTo(const DGLContext &ctx) const {
  if (Context() == ctx) return *this;
  IdArray row_copy = adj_.row.CopyTo(ctx);
  IdArray col_copy = adj_.col.CopyTo(ctx);
  return COO(NumVertices(), row_copy, col_copy);
}

390
391
/*!
 * @brief Not supported: COO graphs cannot be placed in shared memory.
 *
 * This always fails through LOG(FATAL). The return statement only satisfies
 * the compiler.
 *
 * @param name Unused.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  (void)name;  // unused: there is no shared-memory COO implementation
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}

395
396
397
398
/*! @brief This COO with row/col converted to `bits`-wide integers. Returns
 *         *this when the width already matches. */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) return *this;
  return COO(
      NumVertices(), aten::AsNumBits(adj_.row, bits),
      aten::AsNumBits(adj_.col, bits));
}

406
407
408
409
410
411
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

412
/*!
 * @brief Element-wise test of vids[i] < NumVertices().
 *
 * NOTE(review): only the upper bound is checked, so negative ids are not
 * rejected here.
 */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  BoolArray in_range = aten::LT(vids, NumVertices());
  return in_range;
}

/*!
 * @brief Return the in-CSR, building and caching it on first use.
 *
 * It is built by transposing the out-CSR when one exists, otherwise from the
 * transposed COO. The cache is filled through a const_cast without any
 * locking. NOTE(review): presumably first access is not concurrent; confirm.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_) return in_csr_;
  auto *self = const_cast<ImmutableGraph *>(this);
  if (!out_csr_) {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
    return in_csr_;
  }
  self->in_csr_ = out_csr_->Transpose();
  if (out_csr_->IsSharedMem()) {
    LOG(WARNING)
        << "We just construct an in-CSR from a shared-memory out CSR. "
        << "It may dramatically increase memory consumption.";
  }
  return in_csr_;
}

/*!
 * @brief Return the out-CSR, building and caching it on first use.
 *
 * It is built by transposing the in-CSR when one exists, otherwise from the
 * COO. The cache is filled through a const_cast without any locking.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_) return out_csr_;
  auto *self = const_cast<ImmutableGraph *>(this);
  if (!in_csr_) {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
    return out_csr_;
  }
  self->out_csr_ = in_csr_->Transpose();
  if (in_csr_->IsSharedMem()) {
    LOG(WARNING)
        << "We just construct an out-CSR from a shared-memory in CSR. "
        << "It may dramatically increase memory consumption.";
  }
  return out_csr_;
}

/*!
 * @brief Return the COO, building and caching it from a CSR on first use.
 *
 * The in-CSR is preferred. Its COO is transposed because its rows are
 * destinations. Otherwise the out-CSR is used directly.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_) return coo_;
  auto *self = const_cast<ImmutableGraph *>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

463
/*!
 * @brief All edges in the requested order.
 * @param order "" means any order, "srcdst" means sorted by source through
 *        the out-CSR, and "eid" means id order through the COO. Any other
 *        value is fatal.
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == std::string("eid")) return GetCOO()->Edges(order);
  if (!order.empty()) {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }
  // Any order is fine: use an existing in-CSR (swapping its row/col back to
  // src/dst) before falling back to whichever format AnyGraph() returns.
  if (!in_csr_) return AnyGraph()->Edges(order);
  const auto &reversed = in_csr_->Edges(order);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

485
486
487
488
/*! @brief Vertex-induced subgraph. It is computed on the out-CSR and wrapped
 *         back into an ImmutableGraph. */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  Subgraph result = GetOutCSR()->VertexSubgraph(vids);
  CSRPtr csr_part = std::dynamic_pointer_cast<CSR>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(csr_part));
  return result;
}

493
494
/*! @brief Edge-induced subgraph. It is computed on the COO and wrapped back
 *         into an ImmutableGraph. */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  Subgraph result = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  COOPtr coo_part = std::dynamic_pointer_cast<COO>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(coo_part));
  return result;
}

500
501
502
503
504
505
506
507
508
509
std::vector<IdArray> ImmutableGraph::GetAdj(
    bool transpose, const std::string &fmt) const {
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst
  // nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example,
  //   transpose=False is equal to in edge CSR. We have this behavior because
  //   previously we use framework's SPMM and we don't cache reverse adj. This
  //   is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
  //   change the behavior and make row for src and col for dst.
510
  if (fmt == std::string("csr")) {
511
512
    return transpose ? GetOutCSR()->GetAdj(false, "csr")
                     : GetInCSR()->GetAdj(false, "csr");
513
514
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
515
  } else {
516
517
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
518
519
520
  }
}

521
/*!
 * @brief Build an ImmutableGraph from CSR arrays.
 * @param edge_dir "in" if the arrays form the in-CSR, "out" if they form the
 *        out-CSR. Any other value is fatal. The arrays are validated by the
 *        CSR constructor before the direction is checked.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  }
  if (edge_dir != "in") {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
}

535
/*!
 * @brief Attach to a shared-memory graph previously published under `name`.
 * @return nullptr when the "<name>_meta" segment does not exist (this check
 *         is compiled only on non-Windows builds). Otherwise a graph whose
 *         in/out CSRs are mapped from "<name>_in" / "<name>_out" as recorded
 *         in the metadata.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) {
  const std::string meta_name = GetSharedMemName(name, "meta");
#ifndef _WIN32
  if (!SharedMemory::Exist(meta_name)) return nullptr;
#endif  // _WIN32
  const GraphIndexMetadata meta = DeserializeMetadata(meta_name);
  CSRPtr in_part;
  CSRPtr out_part;
  if (meta.has_in_csr) {
    in_part.reset(
        new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges));
  }
  if (meta.has_out_csr) {
    out_part.reset(
        new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges));
  }
  return ImmutableGraphPtr(new ImmutableGraph(in_part, out_part, name));
}

/*! @brief Build an ImmutableGraph backed only by a COO of the given
 *         edges. */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,
    bool col_sorted) {
  COOPtr edges(new COO(num_vertices, src, dst, row_sorted, col_sorted));
  return std::make_shared<ImmutableGraph>(edges);
}

/*!
 * @brief Return `graph` as an ImmutableGraph.
 *
 * An ImmutableGraph is returned as-is. Any other graph is converted through
 * its CSR adjacency, GetAdj(true, "csr"), which is used as the out-CSR.
 * CreateFromCSR performs the one CSR construction and validation. The old
 * code also built an extra, unused CSR here, which cost an allocation and a
 * second validation.
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) return ig;
  const auto &adj = graph->GetAdj(true, "csr");
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

573
574
/*!
 * @brief Copy graph `g` to device context `ctx`.
 * @return `g` itself if it is already on `ctx`. Otherwise a new graph holding
 *         copies of both CSRs. Any COO is not carried over.
 */
ImmutableGraphPtr ImmutableGraph::CopyTo(
    ImmutableGraphPtr g, const DGLContext &ctx) {
  if (ctx == g->Context()) return g;
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   both CSRs are materialized on the current device first and then
  //   copied to `ctx` (usually GPU). This should be fixed later.
  CSRPtr in_copy = std::make_shared<CSR>(g->GetInCSR()->CopyTo(ctx));
  CSRPtr out_copy = std::make_shared<CSR>(g->GetOutCSR()->CopyTo(ctx));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

587
588
/*!
 * @brief Publish graph `g` in shared memory under `name`.
 *
 * Both CSRs are placed in "<name>_in" / "<name>_out". Metadata describing
 * them is written to "<name>_meta" and kept alive in serialized_shared_meta_.
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(
    ImmutableGraphPtr g, const std::string &name) {
  CSRPtr in_shared(
      new CSR(g->GetInCSR()->CopyToSharedMem(GetSharedMemName(name, "in"))));
  CSRPtr out_shared(
      new CSR(g->GetOutCSR()->CopyToSharedMem(GetSharedMemName(name, "out"))));
  ImmutableGraphPtr published(new ImmutableGraph(in_shared, out_shared, name));
  published->serialized_shared_meta_ =
      SerializeMetadata(published, GetSharedMemName(name, "meta"));
  return published;
}

604
605
606
/*!
 * @brief Graph `g` with its id arrays converted to `bits`-wide integers.
 * @return `g` itself when the width already matches. Otherwise a new graph
 *         built from converted in- and out-CSRs.
 */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) return g;
  // TODO(minjie): since we don't have int32 operations, both CSRs are
  //   materialized (GetInCSR/GetOutCSR build any missing one) and then
  //   converted. Any COO is not carried over. This should be fixed later.
  CSRPtr in_conv = std::make_shared<CSR>(g->GetInCSR()->AsNumBits(bits));
  CSRPtr out_conv = std::make_shared<CSR>(g->GetOutCSR()->AsNumBits(bits));
  return ImmutableGraphPtr(new ImmutableGraph(in_conv, out_conv));
}

/*!
 * @brief Graph with every edge reversed.
 *
 * The in- and out-CSR swap roles, so no array work is needed for them. A
 * cached COO is transposed.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_) {
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  }
  return ImmutableGraphPtr(
      new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
}

627
628
// Magic number that ImmutableGraph::Save writes first and
// ImmutableGraph::Load expects to read back, guarding against loading a
// stream that was not produced by Save.
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;

629
/*! @return Load HeteroGraph from stream, using OutCSR Matrix*/
630
631
632
633
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum;
  aten::CSRMatrix out_csr_matrix;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
634
635
636
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
      << "Invalid ImmutableGraph Magic Number";
  CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
637
638
639
  return true;
}

640
/*! @return Save HeteroGraph to stream, using OutCSR Matrix */
641
642
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
643
  fs->Write(GetOutCSR());
644
645
}

646
647
648
649
/*!
 * @brief Wrap this graph as a single-relation HeteroGraph.
 *
 * Only formats that are already materialized are passed on, together with
 * flags saying which ones are present. Nothing new is built.
 */
HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
  const bool has_in = in_csr_ != nullptr;
  const bool has_out = out_csr_ != nullptr;
  const bool has_coo = coo_ != nullptr;

  aten::CSRMatrix in_mat, out_mat;
  aten::COOMatrix coo_mat;
  if (has_in) in_mat = in_csr_->ToCSRMatrix();
  if (has_out) out_mat = out_csr_->ToCSRMatrix();
  if (has_coo) coo_mat = coo_->ToCOOMatrix();

  auto unit = UnitGraph::CreateUnitGraphFrom(
      1, in_mat, out_mat, coo_mat, has_in, has_out, has_coo);
  return HeteroGraphPtr(new HeteroGraph(unit->meta_graph(), {unit}));
}

// C API: convert a read-only (immutable) graph index into a HeteroGraph.
// args[0] is the graph. A non-immutable graph fails the CHECK.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef graph_ref = args[0];
      ImmutableGraphPtr ig =
          std::dynamic_pointer_cast<ImmutableGraph>(graph_ref.sptr());
      CHECK(ig) << "graph is not readonly";
      HeteroGraphPtr hetero = ig->AsHeteroGraph();
      *rv = HeteroGraphRef(hetero);
    });
668

669
// C API: copy an immutable graph to another device.
// args = (graph, device_type, device_id).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef g = args[0];
      const int device_type = args[1];
      const int device_id = args[2];
      const DGLContext target{
          static_cast<DGLDeviceType>(device_type), device_id};
      ImmutableGraphPtr ig =
          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
      *rv = ImmutableGraph::CopyTo(ig, target);
    });
681
682

// C API: publish an immutable graph in shared memory.
// args = (graph, shared-memory name).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef g = args[0];
      const std::string shm_name = args[1];
      ImmutableGraphPtr ig =
          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
      *rv = ImmutableGraph::CopyToSharedMem(ig, shm_name);
    });
690
691

// C API: convert an immutable graph's id width.
// args = (graph, bits). `bits` is narrowed to uint8_t exactly as the implicit
// conversion in the call would do.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef g = args[0];
      const int width = args[1];
      ImmutableGraphPtr ig =
          CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
      *rv = ImmutableGraph::AsNumBits(ig, static_cast<uint8_t>(width));
    });
699

700
}  // namespace dgl