immutable_graph.cc 23.3 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

#include <dgl/immutable_graph.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
10
#include <dgl/base_heterograph.h>
11
12
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
13
14
15
16
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
17
18

#include "../c_api_common.h"
19
20
#include "heterograph.h"
#include "unit_graph.h"
21

22
23
using namespace dgl::runtime;

24
namespace dgl {
25
namespace {
26
27
28
29
/*!
 * \brief Build a shared-memory segment name from a base name and a suffix.
 * \param name The base name of the shared-memory graph.
 * \param edge_dir The suffix to append; callers in this file pass "in", "out" or "meta".
 * \return The concatenation name + "_" + edge_dir.
 */
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string full_name(name);
  full_name += "_";
  full_name += edge_dir;
  return full_name;
}

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/*
 * The metadata of a graph index that is needed for a shared-memory graph.
 *
 * The struct is copied byte-for-byte (memcpy) into and out of a shared-memory
 * tensor by SerializeMetadata / DeserializeMetadata, so its field order and
 * types must stay in sync between the writing and the reading process.
 */
struct GraphIndexMetadata {
  // Number of vertices (filled from ImmutableGraph::NumVertices()).
  int64_t num_nodes;
  // Number of edges (filled from ImmutableGraph::NumEdges()).
  int64_t num_edges;
  // Whether the graph had an in-CSR when the metadata was serialized.
  bool has_in_csr;
  // Whether the graph had an out-CSR when the metadata was serialized.
  bool has_out_csr;
  // Whether a COO is available; SerializeMetadata always writes false here.
  bool has_coo;
};

/*!
 * \brief Serialize the metadata of a graph index into a shared-memory tensor.
 *
 * Another process can later read the tensor back (see DeserializeMetadata) to
 * learn how to reconstruct the graph index from shared memory.
 *
 * \param gidx The graph whose metadata is recorded.
 * \param name The name of the shared-memory tensor to create.
 * \return The shared-memory tensor holding the raw bytes of a GraphIndexMetadata.
 *         On Windows this logs a fatal error instead.
 */
NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
#ifndef _WIN32
  GraphIndexMetadata header;
  header.num_nodes = gidx->NumVertices();
  header.num_edges = gidx->NumEdges();
  header.has_in_csr = gidx->HasInCSR();
  header.has_out_csr = gidx->HasOutCSR();
  // COO is never stored in shared memory by this file.
  header.has_coo = false;

  // The tensor is a flat byte buffer (int8) exactly as large as the struct.
  const int64_t nbytes = sizeof(header);
  const DLDataType byte_type{kDLInt, 8, 1};
  const DLContext cpu_ctx{kDLCPU, 0};
  NDArray shared_buf = NDArray::EmptyShared(name, {nbytes}, byte_type, cpu_ctx, true);
  memcpy(shared_buf->data, &header, sizeof(header));
  return shared_buf;
#else
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return NDArray();
#endif  // _WIN32
}

/*!
 * \brief Read back the graph index metadata stored in a shared-memory tensor.
 * \param name The name of the shared-memory tensor written by SerializeMetadata.
 * \return The metadata copied out of the tensor. On Windows this logs a fatal
 *         error instead.
 */
GraphIndexMetadata DeserializeMetadata(const std::string &name) {
  GraphIndexMetadata header;
#ifndef _WIN32
  // Attach to the byte buffer (last argument false, unlike SerializeMetadata).
  const int64_t nbytes = sizeof(header);
  const DLDataType byte_type{kDLInt, 8, 1};
  const DLContext cpu_ctx{kDLCPU, 0};
  NDArray shared_buf = NDArray::EmptyShared(name, {nbytes}, byte_type, cpu_ctx, false);
  memcpy(&header, shared_buf->data, sizeof(header));
#else
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
#endif  // _WIN32
  return header;
}

79
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
80
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
81
#ifndef _WIN32
82
83
84
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
85
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
86
87
88
89
90
91
92
93
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
94
#else
95
96
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
97
98
#endif  // _WIN32
}
99
100
101
102
103
104
105
106
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

107
/*!
 * \brief Construct a CSR with freshly allocated arrays.
 * \param num_vertices The number of vertices (rows and columns of the matrix).
 * \param num_edges The number of edges (length of the indices/data arrays).
 * \note A graph with zero vertices must have zero edges.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges) {
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray indptr = aten::NewIdArray(num_vertices + 1);
  IdArray indices = aten::NewIdArray(num_edges);
  IdArray edge_ids = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, indptr, indices, edge_ids};
  // Freshly allocated arrays carry no ordering guarantee.
  adj_.sorted = false;
}

116
/*!
 * \brief Construct a CSR that wraps the given arrays.
 * \param indptr The row pointer array; the vertex count is its length minus one.
 * \param indices The column index array.
 * \param edge_ids The edge id array; must be as long as \a indices.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);

  const int64_t num_verts = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_verts, num_verts, indptr, indices, edge_ids};
  // The caller-provided arrays are not assumed to be sorted.
  adj_.sorted = false;
}

/*!
 * \brief Construct a CSR in shared memory and copy the given arrays into it.
 * \param indptr The row pointer array; the vertex count is its length minus one.
 * \param indices The column index array.
 * \param edge_ids The edge id array; must be as long as \a indices.
 * \param shared_mem_name The name of the shared-memory buffer to create.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);

  const int64_t n_rows = indptr->shape[0] - 1;
  const int64_t n_nonzero = indices->shape[0];
  adj_.num_rows = n_rows;
  adj_.num_cols = n_rows;

  // Create the shared buffer (is_create = true) and take views into it.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, n_rows, n_nonzero, true);

  // copy the given data into the shared memory arrays
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Construct a CSR by attaching to an existing shared-memory buffer.
 * \param shared_mem_name The name of the shared-memory buffer.
 * \param num_verts The number of vertices stored in the buffer.
 * \param num_edges The number of edges stored in the buffer.
 * \note A graph with zero vertices must have zero edges.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges): shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  adj_.num_cols = num_verts;
  adj_.num_rows = num_verts;
  // Attach only (is_create = false); no data is copied here.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
  adj_.sorted = false;
}

bool CSR::IsMultigraph() const {
156
  return aten::CSRHasDuplicate(adj_);
157
}
158

159
/*!
 * \brief Get the out-edges of a single vertex.
 * \param vid The source vertex; must be a valid vertex of this graph.
 * \return An EdgeArray of (src, dst, eid), where src is \a vid repeated once
 *         per edge.
 */
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  IdArray dst_ids = aten::CSRGetRowColumnIndices(adj_, vid);
  IdArray edge_ids = aten::CSRGetRowData(adj_, vid);
  // The source column is constant: one copy of vid for every destination.
  const auto num_out = dst_ids->shape[0];
  IdArray src_ids = aten::Full(vid, num_out, NumBits(), dst_ids->ctx);
  return EdgeArray{src_ids, dst_ids, edge_ids};
}

167
/*!
 * \brief Get the out-edges of multiple vertices.
 * \param vids The source vertices.
 * \return An EdgeArray of (src, dst, eid) for the rows selected by \a vids.
 */
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto sliced_coo = aten::CSRToCOO(aten::CSRSliceRows(adj_, vids), false);
  // Note that the row id in the csr submat is relabeled, so
  // we need to recover it using an index select.
  IdArray src_ids = aten::IndexSelect(vids, sliced_coo.row);
  return EdgeArray{src_ids, sliced_coo.col, sliced_coo.data};
}

177
/*!
 * \brief Get the out-degrees of the given vertices.
 * \param vids The vertices to query.
 * \return The per-row non-zero counts returned by aten::CSRGetRowNNZ.
 */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

182
183
184
/*!
 * \brief Whether there is an edge from \a src to \a dst.
 * \param src The source vertex; must be a valid vertex.
 * \param dst The destination vertex; must be a valid vertex.
 * \return The result of aten::CSRIsNonZero for entry (src, dst).
 */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  const bool connected = aten::CSRIsNonZero(adj_, src, dst);
  return connected;
}

/*!
 * \brief Batched version of HasEdgeBetween.
 * \param src_ids The source vertices.
 * \param dst_ids The destination vertices.
 * \return The array returned by aten::CSRIsNonZero for the given id arrays.
 */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  BoolArray connected = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return connected;
}

194
/*!
 * \brief Get the successors of a vertex.
 * \param vid The vertex to query; must be a valid vertex.
 * \param radius The neighborhood radius; only 1 is accepted.
 * \return The column indices stored in row \a vid.
 */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  IdArray succ = aten::CSRGetRowColumnIndices(adj_, vid);
  return succ;
}

200
201
202
/*!
 * \brief Get the ids of all edges from \a src to \a dst.
 * \param src The source vertex; must be a valid vertex.
 * \param dst The destination vertex; must be a valid vertex.
 * \return The array returned by aten::CSRGetAllData for entry (src, dst).
 */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  IdArray eids = aten::CSRGetAllData(adj_, src, dst);
  return eids;
}

206
/*!
 * \brief Batched edge id lookup.
 * \param src_ids The source vertices.
 * \param dst_ids The destination vertices.
 * \return An EdgeArray built, in order, from the three arrays returned by
 *         aten::CSRGetDataAndIndices.
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  const auto& found = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  EdgeArray result{found[0], found[1], found[2]};
  return result;
}
210

211
/*!
 * \brief Get all edges of the graph.
 * \param order The requested order; must be empty or "srcdst".
 * \return An EdgeArray of (row, col, data) obtained by converting the CSR to COO.
 */
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\","
    << " but got \"" << order << "\".";
  const auto& as_coo = aten::CSRToCOO(adj_, false);
  EdgeArray all_edges{as_coo.row, as_coo.col, as_coo.data};
  return all_edges;
}

219
/*!
 * \brief Build the subgraph induced by the given vertices.
 * \param vids The vertices to keep.
 * \return A Subgraph whose graph is a new CSR with edge ids renumbered as
 *         0..E'-1, together with the induced vertices (\a vids) and the
 *         original edge ids of the kept edges.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto& sliced = aten::CSRSliceMatrix(adj_, vids, vids);

  // The sub-CSR gets consecutive edge ids; the original ids are reported
  // separately through induced_edges.
  IdArray new_eids = aten::Range(0, sliced.data->shape[0], NumBits(), Context());
  CSRPtr sub_csr(new CSR(sliced.indptr, sliced.indices, new_eids));
  sub_csr->adj_.sorted = this->adj_.sorted;

  Subgraph result;
  result.graph = sub_csr;
  result.induced_vertices = vids;
  result.induced_edges = sliced.data;
  return result;
}

/*!
 * \brief Create the transposed CSR.
 * \return A new CSR wrapping the arrays produced by aten::CSRTranspose.
 */
CSRPtr CSR::Transpose() const {
  const auto& flipped = aten::CSRTranspose(adj_);
  CSRPtr result(new CSR(flipped.indptr, flipped.indices, flipped.data));
  return result;
}

/*!
 * \brief Convert this CSR to a COO graph.
 * \return A new COO with the same number of vertices, built from the row and
 *         column arrays of aten::CSRToCOO(adj_, true).
 */
COOPtr CSR::ToCOO() const {
  const auto& converted = aten::CSRToCOO(adj_, true);
  COOPtr result(new COO(NumVertices(), converted.row, converted.col));
  return result;
}

242
243
244
245
/*!
 * \brief Copy this CSR to the given device context.
 * \param ctx The target context.
 * \return A copy of this object if it already lives on \a ctx; otherwise a new
 *         CSR whose three arrays have been copied to \a ctx.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  // Nothing to move when the graph is already on the requested context.
  if (Context() == ctx)
    return *this;

  CSR moved(adj_.indptr.CopyTo(ctx),
            adj_.indices.CopyTo(ctx),
            adj_.data.CopyTo(ctx));
  return moved;
}

253
254
255
256
257
/*!
 * \brief Copy this CSR into shared memory.
 * \param name The name of the shared-memory buffer.
 * \return A copy of this object if it is already in shared memory (in which
 *         case \a name must equal the existing name); otherwise a new CSR
 *         created in shared memory under \a name.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (!IsSharedMem()) {
    // TODO(zhengda) we need to set sorted_ properly.
    return CSR(adj_.indptr, adj_.indices, adj_.data, name);
  }
  // Already shared: re-sharing under a different name is not supported.
  CHECK(name == shared_mem_name_);
  return *this;
}

263
264
265
266
/*!
 * \brief Convert the id arrays of this CSR to the given bit width.
 * \param bits The requested number of bits per id.
 * \return A copy of this object if it already uses \a bits; otherwise a new
 *         CSR whose three arrays went through aten::AsNumBits.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits)
    return *this;

  CSR converted(aten::AsNumBits(adj_.indptr, bits),
                aten::AsNumBits(adj_.indices, bits),
                aten::AsNumBits(adj_.data, bits));
  return converted;
}

274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
/*!
 * \brief Get an iterator range over the successors of a vertex.
 * \param vid The vertex to query. No bounds check is performed here.
 * \return The range of column indices stored in row \a vid.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_offsets = static_cast<dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* col_ids = static_cast<dgl_id_t*>(adj_.indices->data);
  const dgl_id_t first = row_offsets[vid];
  const dgl_id_t last = row_offsets[vid + 1];
  return DGLIdIters(col_ids + first, col_ids + last);
}

/*!
 * \brief Get an iterator range over the out-edge ids of a vertex.
 * \param vid The vertex to query. No bounds check is performed here.
 * \return The range of edge ids stored in row \a vid.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_offsets = static_cast<dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(adj_.data->data);
  const dgl_id_t first = row_offsets[vid];
  const dgl_id_t last = row_offsets[vid + 1];
  return DGLIdIters(edge_ids + first, edge_ids + last);
}

294
295
296
297
298
299
300
301
302
/*!
 * \brief Load the adjacency matrix of this CSR from a stream.
 * \param fs The stream to read from.
 * \return Whether the read succeeded. The result of the stream read is
 *         propagated so callers that check it (e.g. CHECK(fs->Read(...)))
 *         can detect a truncated or corrupted stream.
 */
bool CSR::Load(dmlc::Stream *fs) {
  // This is a non-const member, so adj_ can be addressed directly.
  return fs->Read(&adj_);
}

/*!
 * \brief Write the adjacency matrix of this CSR to a stream.
 * \param fs The stream to write to.
 */
void CSR::Save(dmlc::Stream *fs) const {
  const auto& matrix = adj_;
  fs->Write(matrix);
}

303
304
305
306
307
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
308
309
/*!
 * \brief Construct a COO graph from source and destination arrays.
 * \param num_vertices The number of vertices (rows and columns of the matrix).
 * \param src The source vertex of each edge.
 * \param dst The destination vertex of each edge; must be as long as \a src.
 * \param row_sorted Stored in the COO matrix as its row-sorted flag.
 * \param col_sorted Stored in the COO matrix as its column-sorted flag.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst,
        bool row_sorted, bool col_sorted) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  // No explicit edge id array is stored: the data slot is a null array.
  IdArray no_eids = aten::NullArray();
  adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst,
                         no_eids, row_sorted, col_sorted};
}

bool COO::IsMultigraph() const {
318
  return aten::COOHasDuplicate(adj_);
319
320
}

321
322
/*!
 * \brief Find the endpoints of an edge.
 * \param eid The edge id; must be smaller than the number of edges.
 * \return The (source, destination) pair stored at position \a eid.
 */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const dgl_id_t from = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  const dgl_id_t to = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return {from, to};
}

328
/*!
 * \brief Find the endpoints of multiple edges.
 * \param eids The edge ids to look up.
 * \return An EdgeArray of (src, dst, eids), where src/dst are selected from
 *         the row/col arrays at the positions given by \a eids.
 */
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  // Edge ids are used directly as positions, which is only meaningful when
  // the matrix stores no explicit edge id array.
  BUG_ON(aten::IsNullArray(adj_.data)) <<
    "FindEdges requires the internal COO matrix not having EIDs.";
  IdArray src_ids = aten::IndexSelect(adj_.row, eids);
  IdArray dst_ids = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{src_ids, dst_ids, eids};
}

337
/*!
 * \brief Get all edges of the graph.
 * \param order The requested order; must be empty or "eid".
 * \return An EdgeArray of (row, col, 0..NumEdges()-1).
 */
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  // Edge ids are implicit in a COO: the i-th entry is edge i.
  IdArray all_eids = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, all_eids};
}

345
/*!
 * \brief Build the subgraph induced by the given edges.
 * \param eids The edges to keep.
 * \param preserve_nodes If true, the subgraph keeps all vertices of this graph
 *        and induced_vertices is the full range 0..NumVertices()-1. If false,
 *        induced_vertices is whatever aten::Relabel_ returns for the selected
 *        endpoints and the subgraph has that many vertices.
 * \return The subgraph together with its induced vertices and edges (\a eids).
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";

  // Both modes start from the endpoints of the selected edges.
  IdArray sub_src = aten::IndexSelect(adj_.row, eids);
  IdArray sub_dst = aten::IndexSelect(adj_.col, eids);

  Subgraph result;
  if (preserve_nodes) {
    result.induced_vertices = aten::Range(0, NumVertices(), NumBits(), Context());
    result.graph = COOPtr(new COO(NumVertices(), sub_src, sub_dst));
  } else {
    // NOTE(review): Relabel_ presumably rewrites sub_src/sub_dst in place to
    //   compact ids (trailing underscore) - confirm against aten.
    result.induced_vertices = aten::Relabel_({sub_src, sub_dst});
    const auto new_nnodes = result.induced_vertices->shape[0];
    result.graph = COOPtr(new COO(new_nnodes, sub_src, sub_dst));
  }
  result.induced_edges = eids;
  return result;
}

368
/*!
 * \brief Convert this COO to a CSR graph.
 * \return A new CSR wrapping the arrays produced by aten::COOToCSR.
 */
CSRPtr COO::ToCSR() const {
  const auto& converted = aten::COOToCSR(adj_);
  CSRPtr result(new CSR(converted.indptr, converted.indices, converted.data));
  return result;
}

373
374
375
376
/*!
 * \brief Copy this COO to the given device context.
 * \param ctx The target context.
 * \return A copy of this object if it already lives on \a ctx; otherwise a new
 *         COO whose row/col arrays have been copied to \a ctx.
 */
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx)
    return *this;

  COO moved(NumVertices(),
            adj_.row.CopyTo(ctx),
            adj_.col.CopyTo(ctx));
  return moved;
}

384
385
/*!
 * \brief Copy this COO into shared memory (not supported).
 * \param name The requested shared-memory name; unused.
 * \return Never returns normally in practice: a fatal error is logged first.
 *         The default-constructed COO only satisfies the return type.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  LOG(FATAL) << "COO doesn't supprt shared memory yet";
  COO placeholder = COO();
  return placeholder;
}

389
390
391
392
/*!
 * \brief Convert the id arrays of this COO to the given bit width.
 * \param bits The requested number of bits per id.
 * \return A copy of this object if it already uses \a bits; otherwise a new
 *         COO whose row/col arrays went through aten::AsNumBits.
 */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits)
    return *this;

  COO converted(NumVertices(),
                aten::AsNumBits(adj_.row, bits),
                aten::AsNumBits(adj_.col, bits));
  return converted;
}

400
401
402
403
404
405
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

406
/*!
 * \brief Check whether each of the given ids is a vertex of this graph.
 * \param vids The vertex ids to test.
 * \return The element-wise result of vids < NumVertices().
 */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  BoolArray in_range = aten::LT(vids, NumVertices());
  return in_range;
}

/*!
 * \brief Return the in-CSR, building and caching it on first use.
 *
 * The in-CSR is derived from the out-CSR by transposition when that exists,
 * otherwise from the COO. The cached member is written through a const_cast
 * because this accessor is const.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_)
    return in_csr_;

  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (out_csr_) {
    self->in_csr_ = out_csr_->Transpose();
    if (out_csr_->IsSharedMem()) {
      LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                   << "It may dramatically increase memory consumption.";
    }
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
  }
  return in_csr_;
}

/*!
 * \brief Return the out-CSR, building and caching it on first use.
 *
 * The out-CSR is derived from the in-CSR by transposition when that exists,
 * otherwise from the COO. The cached member is written through a const_cast
 * because this accessor is const.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_)
    return out_csr_;

  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->out_csr_ = in_csr_->Transpose();
    if (in_csr_->IsSharedMem()) {
      LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                   << "It may dramatically increase memory consumption.";
    }
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
  }
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it on first use.
 *
 * The COO is derived from the in-CSR (converted, then transposed) when that
 * exists, otherwise from the out-CSR. The cached member is written through a
 * const_cast because this accessor is const.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_)
    return coo_;

  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

455
/*!
 * \brief Get all edges of the graph in the requested order.
 * \param order One of "" (arbitrary order), "srcdst" or "eid"; any other value
 *        is a fatal error.
 * \return The edges as an EdgeArray of (src, dst, eid).
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order == std::string("eid"))
    return GetCOO()->Edges(order);

  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }

  if (order.empty()) {
    // arbitrary order
    if (!in_csr_)
      return AnyGraph()->Edges(order);
    // The in-CSR stores reversed edges, so swap src/dst back (transpose).
    const auto& reversed = in_csr_->Edges(order);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }

  LOG(FATAL) << "Unsupported order request: " << order;
  return {};
}

477
478
479
480
/*!
 * \brief Build the subgraph induced by the given vertices.
 * \param vids The vertices to keep.
 * \return The subgraph computed on the out-CSR, with its graph re-wrapped as
 *         an ImmutableGraph.
 */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  // We prefer to generate a subgraph from out-csr.
  Subgraph result = GetOutCSR()->VertexSubgraph(vids);
  CSRPtr sub_csr = std::dynamic_pointer_cast<CSR>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(sub_csr));
  return result;
}

485
486
/*!
 * \brief Build the subgraph induced by the given edges.
 * \param eids The edges to keep.
 * \param preserve_nodes Forwarded to COO::EdgeSubgraph.
 * \return The subgraph computed on the COO, with its graph re-wrapped as an
 *         ImmutableGraph.
 */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  Subgraph result = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  COOPtr sub_coo = std::dynamic_pointer_cast<COO>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(sub_coo));
  return result;
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
493
494
495
496
497
498
499
500
501
502
503
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
504
  } else {
505
506
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
507
508
509
  }
}

510
511
/*!
 * \brief Create an immutable graph from CSR arrays.
 * \param indptr The row pointer array.
 * \param indices The column index array.
 * \param edge_ids The edge id array.
 * \param edge_dir "in" to use the arrays as the in-CSR, "out" to use them as
 *        the out-CSR; any other value is a fatal error.
 * \return The new graph.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
  // Build (and validate) the CSR before looking at the direction.
  CSRPtr csr(new CSR(indptr, indices, edge_ids));

  const bool is_in = (edge_dir == "in");
  const bool is_out = !is_in && (edge_dir == "out");
  if (is_in)
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  if (is_out)
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));

  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

523
/*!
 * \brief Attach to an immutable graph that lives in shared memory.
 * \param name The base name of the shared-memory graph.
 * \return The graph reconstructed from the "<name>_meta", "<name>_in" and
 *         "<name>_out" segments, or nullptr if the metadata segment does not
 *         exist (non-Windows only).
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) {
  const std::string meta_name = GetSharedMemName(name, "meta");
  // If the shared memory graph index doesn't exist, we return null directly.
#ifndef _WIN32
  if (!SharedMemory::Exist(meta_name)) {
    return nullptr;
  }
#endif  // _WIN32
  const GraphIndexMetadata meta = DeserializeMetadata(meta_name);

  // Only the CSRs recorded in the metadata are attached.
  CSRPtr in_csr;
  CSRPtr out_csr;
  if (meta.has_in_csr) {
    in_csr = CSRPtr(new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges));
  }
  if (meta.has_out_csr) {
    out_csr = CSRPtr(new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges));
  }
  return ImmutableGraphPtr(new ImmutableGraph(in_csr, out_csr, name));
}

/*!
 * \brief Create an immutable graph from COO arrays.
 * \param num_vertices The number of vertices.
 * \param src The source vertex of each edge.
 * \param dst The destination vertex of each edge.
 * \param row_sorted Forwarded to the COO constructor.
 * \param col_sorted Forwarded to the COO constructor.
 * \return The new graph backed by the COO.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst,
    bool row_sorted, bool col_sorted) {
  COOPtr edge_list(new COO(num_vertices, src, dst, row_sorted, col_sorted));
  ImmutableGraphPtr graph = std::make_shared<ImmutableGraph>(edge_list);
  return graph;
}

/*!
 * \brief Convert an arbitrary graph to an immutable graph.
 * \param graph The graph to convert.
 * \return \a graph itself if it already is an ImmutableGraph; otherwise a new
 *         immutable graph whose out-CSR is built from
 *         graph->GetAdj(true, "csr").
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) {
    return ig;
  }
  // CreateFromCSR constructs (and validates) the CSR itself, so no separate
  // CSR object needs to be built here.
  const auto& adj = graph->GetAdj(true, "csr");
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

559
560
561
/*!
 * \brief Copy an immutable graph to the given device context.
 * \param g The graph to copy.
 * \param ctx The target context.
 * \return \a g itself if it already lives on \a ctx; otherwise a new graph
 *         holding copies of both the in-CSR and the out-CSR on \a ctx.
 */
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
  const bool same_ctx = (ctx == g->Context());
  if (same_ctx) {
    return g;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_copy(new CSR(g->GetInCSR()->CopyTo(ctx)));
  CSRPtr out_copy(new CSR(g->GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

572
/*!
 * \brief Copy an immutable graph into shared memory.
 *
 * The in-CSR is stored under "<name>_in", the out-CSR under "<name>_out" and
 * the serialized metadata under "<name>_meta".
 *
 * \param g The graph to copy.
 * \param name The base name of the shared-memory graph.
 * \return The new graph backed by the shared-memory CSRs. It keeps the
 *         metadata tensor in serialized_shared_meta_.
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g, const std::string &name) {
  const std::string in_name = GetSharedMemName(name, "in");
  CSRPtr shared_in(new CSR(g->GetInCSR()->CopyToSharedMem(in_name)));

  const std::string out_name = GetSharedMemName(name, "out");
  CSRPtr shared_out(new CSR(g->GetOutCSR()->CopyToSharedMem(out_name)));

  ImmutableGraphPtr shared_graph(new ImmutableGraph(shared_in, shared_out, name));
  shared_graph->serialized_shared_meta_ =
      SerializeMetadata(shared_graph, GetSharedMemName(name, "meta"));
  return shared_graph;
}

585
586
587
/*!
 * \brief Convert an immutable graph to the given id bit width.
 * \param g The graph to convert.
 * \param bits The requested number of bits per id.
 * \return \a g itself if it already uses \a bits; otherwise a new graph
 *         holding converted copies of both the in-CSR and the out-CSR.
 */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }
  // TODO(minjie): since we don't have int32 operations,
  //   we make sure that this graph has materialized both CSRs,
  //   and then convert them to the requested bit width. This should
  //   be fixed later.
  CSRPtr in_converted(new CSR(g->GetInCSR()->AsNumBits(bits)));
  CSRPtr out_converted(new CSR(g->GetOutCSR()->AsNumBits(bits)));
  return ImmutableGraphPtr(new ImmutableGraph(in_converted, out_converted));
}

/*!
 * \brief Create the reversed graph.
 * \return A new graph whose in-CSR and out-CSR are swapped. If a COO has
 *         been materialized, its transpose is passed along as well.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_) {
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  }
  return ImmutableGraphPtr(new ImmutableGraph(
        out_csr_, in_csr_, coo_->Transpose()));
}

608
609
610
611
612
613
614
// Magic number that ImmutableGraph::Save writes first and ImmutableGraph::Load
// verifies before reading the rest of a serialized ImmutableGraph.
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;

/*!
 * \brief Load an ImmutableGraph from a stream.
 *
 * The stream must start with the kDGLSerialize_ImGraph magic number, followed
 * by the out-CSR; this is the layout written by ImmutableGraph::Save. Only
 * out_csr_ is assigned here.
 *
 * \param fs The stream to read from.
 * \return Always true; a failed read or a wrong magic number is a fatal
 *         CHECK failure.
 */
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
      << "Invalid ImmutableGraph Magic Number";
  CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
  return true;
}

/*!
 * \brief Save this ImmutableGraph to a stream.
 *
 * Writes the kDGLSerialize_ImGraph magic number followed by the out-CSR
 * (materialized on demand through GetOutCSR).
 *
 * \param fs The stream to write to.
 */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
  const CSRPtr out_csr = GetOutCSR();
  fs->Write(out_csr);
}

627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
/*!
 * \brief Convert this graph to a heterograph.
 *
 * Only the formats that are already materialized (in-CSR, out-CSR, COO) are
 * handed over; no new format is created here.
 *
 * \return A HeteroGraph built from the unit graph returned by
 *         UnitGraph::CreateHomographFrom.
 */
HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
  const bool has_in = (in_csr_ != nullptr);
  const bool has_out = (out_csr_ != nullptr);
  const bool has_coo = (coo_ != nullptr);

  aten::CSRMatrix in_mat, out_mat;
  aten::COOMatrix coo_mat;
  if (in_csr_)
    in_mat = GetInCSR()->ToCSRMatrix();
  if (out_csr_)
    out_mat = GetOutCSR()->ToCSRMatrix();
  if (coo_)
    coo_mat = GetCOO()->ToCOOMatrix();

  auto unit_graph = UnitGraph::CreateHomographFrom(
      in_mat, out_mat, coo_mat, has_in, has_out, has_coo);
  return HeteroGraphPtr(new HeteroGraph(unit_graph->meta_graph(), {unit_graph}));
}

// C API: convert a read-only (immutable) graph to a heterograph.
//   args[0]: the graph; must be an ImmutableGraph.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
    CHECK(ig) << "graph is not readonly";
    HeteroGraphPtr hetero = ig->AsHeteroGraph();
    *rv = HeteroGraphRef(hetero);
  });

654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
// C API: copy an immutable graph to another device context.
//   args[0]: the graph; must be an ImmutableGraph.
//   args[1]: the device type, cast to DLDeviceType.
//   args[2]: the device id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const int device_type = args[1];
    const int device_id = args[2];
    const DLContext ctx{static_cast<DLDeviceType>(device_type), device_id};
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyTo(ig, ctx);
  });

// C API: copy an immutable graph into shared memory.
//   args[0]: the graph; must be an ImmutableGraph.
//   args[1]: the base name of the shared-memory graph.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    std::string name = args[1];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    ImmutableGraphPtr shared = ImmutableGraph::CopyToSharedMem(ig, name);
    *rv = shared;
  });

// C API: convert an immutable graph to the given id bit width.
//   args[0]: the graph; must be an ImmutableGraph.
//   args[1]: the requested number of bits per id.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    int bits = args[1];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    ImmutableGraphPtr converted = ImmutableGraph::AsNumBits(ig, bits);
    *rv = converted;
  });

682
}  // namespace dgl