immutable_graph.cc 22 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

#include <dgl/immutable_graph.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
10
#include <dgl/base_heterograph.h>
11
12
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
13
14
15
16
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
17
18

#include "../c_api_common.h"
19
20
#include "heterograph.h"
#include "unit_graph.h"
21

22
23
using namespace dgl::runtime;

24
namespace dgl {
25
namespace {
26
27
28
29
/*!
 * \brief Compose the shared-memory name used for one edge direction.
 * \param name Base name of the shared-memory graph.
 * \param edge_dir Edge direction tag appended after an underscore.
 * \return The string "<name>_<edge_dir>".
 */
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string full_name(name);
  full_name += "_";
  full_name += edge_dir;
  return full_name;
}

30
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
31
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
32
#ifndef _WIN32
33
34
35
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
36
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
37
38
39
40
41
42
43
44
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
45
#else
46
47
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
48
49
#endif  // _WIN32
}
50
51
52
53
54
55
56
57
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

58
/*!
 * \brief Construct a CSR with freshly allocated arrays.
 * \param num_vertices Number of vertices (rows and columns of the matrix).
 * \param num_edges Number of edges; must be 0 when there are no vertices.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges) {
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray indptr = aten::NewIdArray(num_vertices + 1);
  IdArray indices = aten::NewIdArray(num_edges);
  IdArray edge_ids = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, indptr, indices, edge_ids};
  adj_.sorted = false;
}

67
/*!
 * \brief Construct a CSR that wraps the given arrays.
 * \param indptr Row pointer array; the vertex count is its length minus one.
 * \param indices Column index array.
 * \param edge_ids Edge id array; must be as long as \a indices.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_nodes = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_nodes, num_nodes, indptr, indices, edge_ids};
  // The input is not assumed to be sorted.
  adj_.sorted = false;
}

/*!
 * \brief Construct a CSR in shared memory and copy the given arrays into it.
 * \param indptr Row pointer array; the vertex count is its length minus one.
 * \param indices Column index array.
 * \param edge_ids Edge id array; must be as long as \a indices.
 * \param shared_mem_name Name of the shared-memory buffer to create.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_verts = indptr->shape[0] - 1;
  const int64_t num_edges = indices->shape[0];
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;

  // Map a new shared-memory buffer (is_create = true) and adopt its views.
  const auto shared_arrays = MapFromSharedMemory(
      shared_mem_name, num_verts, num_edges, true);
  adj_.indptr = std::get<0>(shared_arrays);
  adj_.indices = std::get<1>(shared_arrays);
  adj_.data = std::get<2>(shared_arrays);

  // Fill the shared-memory arrays with the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Construct a CSR whose arrays are views into a shared-memory buffer
 *        mapped with is_create = false.
 * \param shared_mem_name Name of the shared-memory buffer.
 * \param num_verts Number of vertices.
 * \param num_edges Number of edges; must be 0 when there are no vertices.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges): shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  const auto shared_arrays = MapFromSharedMemory(
      shared_mem_name, num_verts, num_edges, false);
  adj_.indptr = std::get<0>(shared_arrays);
  adj_.indices = std::get<1>(shared_arrays);
  adj_.data = std::get<2>(shared_arrays);
  adj_.sorted = false;
}

bool CSR::IsMultigraph() const {
107
  return aten::CSRHasDuplicate(adj_);
108
}
109

110
/*!
 * \brief Edges stored in row \a vid of this CSR.
 * \param vid The row (vertex) id; must be a valid vertex.
 * \return EdgeArray whose source array is filled with \a vid.
 */
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  IdArray dst = aten::CSRGetRowColumnIndices(adj_, vid);
  IdArray eid = aten::CSRGetRowData(adj_, vid);
  // One source entry per destination, all equal to vid.
  const int64_t num_out = dst->shape[0];
  IdArray src = aten::Full(vid, num_out, NumBits(), dst->ctx);
  return EdgeArray{src, dst, eid};
}

118
/*!
 * \brief Edges stored in the rows listed in \a vids.
 * \param vids Array of row (vertex) ids.
 * \return EdgeArray of (src, dst, edge id) with src expressed in original ids.
 */
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto sliced_csr = aten::CSRSliceRows(adj_, vids);
  const auto sliced_coo = aten::CSRToCOO(sliced_csr, false);
  // Rows of the sliced matrix are relabeled; map them back to the original
  // vertex ids by indexing into vids.
  auto src = aten::IndexSelect(vids, sliced_coo.row);
  return EdgeArray{src, sliced_coo.col, sliced_coo.data};
}

128
/*!
 * \brief Number of stored entries in each of the given rows.
 * \param vids Array of row (vertex) ids.
 * \return Result of aten::CSRGetRowNNZ for those rows.
 */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

133
134
135
/*!
 * \brief Whether entry (src, dst) is present in the adjacency matrix.
 * \param src Row vertex id; must be a valid vertex.
 * \param dst Column vertex id; must be a valid vertex.
 */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  const bool connected = aten::CSRIsNonZero(adj_, src, dst);
  return connected;
}

/*!
 * \brief Batched version of HasEdgeBetween.
 * \param src_ids Array of row vertex ids.
 * \param dst_ids Array of column vertex ids.
 * \return Result of aten::CSRIsNonZero on the two arrays.
 */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  BoolArray connected = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return connected;
}

145
/*!
 * \brief Column indices stored in row \a vid.
 * \param vid The row (vertex) id; must be a valid vertex.
 * \param radius Only a radius of 1 is accepted.
 */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  IdArray neighbors = aten::CSRGetRowColumnIndices(adj_, vid);
  return neighbors;
}

151
152
153
/*!
 * \brief Data (edge ids) stored at entry (src, dst).
 * \param src Row vertex id; must be a valid vertex.
 * \param dst Column vertex id; must be a valid vertex.
 * \return Result of aten::CSRGetData for that entry.
 */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  IdArray eids = aten::CSRGetData(adj_, src, dst);
  return eids;
}

157
/*!
 * \brief Batched lookup of entries between \a src_ids and \a dst_ids.
 * \return EdgeArray built from the three arrays returned by
 *         aten::CSRGetDataAndIndices, in the order they are returned.
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  const auto& found = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  const auto& first = found[0];
  const auto& second = found[1];
  const auto& third = found[2];
  return EdgeArray{first, second, third};
}
161

162
/*!
 * \brief All edges of this CSR.
 * \param order Must be empty or "srcdst"; anything else fails the check.
 * \return EdgeArray from the COO conversion of the adjacency matrix.
 */
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\","
    << " but got \"" << order << "\".";
  const auto& as_coo = aten::CSRToCOO(adj_, false);
  EdgeArray edges{as_coo.row, as_coo.col, as_coo.data};
  return edges;
}

170
/*!
 * \brief Subgraph induced by the vertices in \a vids.
 * \param vids Array of vertex ids to keep.
 * \return Subgraph whose graph is a new CSR with edge ids 0..E'-1, whose
 *         induced_vertices is \a vids and whose induced_edges is the data
 *         array of the sliced matrix.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto& sliced = aten::CSRSliceMatrix(adj_, vids, vids);

  // The subgraph gets consecutive edge ids of its own.
  const int64_t num_sub_edges = sliced.data->shape[0];
  IdArray new_eids = aten::Range(0, num_sub_edges, NumBits(), Context());
  CSRPtr sub_csr(new CSR(sliced.indptr, sliced.indices, new_eids));
  // The constructor marks the CSR unsorted; carry over this CSR's flag instead.
  sub_csr->adj_.sorted = this->adj_.sorted;

  Subgraph result;
  result.graph = sub_csr;
  result.induced_vertices = vids;
  result.induced_edges = sliced.data;
  return result;
}

/*! \brief New CSR built from aten::CSRTranspose of this adjacency matrix. */
CSRPtr CSR::Transpose() const {
  const auto& transposed = aten::CSRTranspose(adj_);
  CSRPtr result(new CSR(transposed.indptr, transposed.indices, transposed.data));
  return result;
}

/*!
 * \brief Convert this CSR into a COO graph with the same vertex count.
 * \note The second argument to aten::CSRToCOO is `true` here (the other call
 *       sites in this class pass `false`); its exact meaning is defined by aten.
 */
COOPtr CSR::ToCOO() const {
  const auto& as_coo = aten::CSRToCOO(adj_, true);
  COOPtr result(new COO(NumVertices(), as_coo.row, as_coo.col));
  return result;
}

193
194
195
196
/*!
 * \brief Copy this CSR to the given device context.
 * \param ctx Target context.
 * \return A copy of *this when it already lives on \a ctx; otherwise a new CSR
 *         built from the three arrays copied to \a ctx.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  CSR copied(adj_.indptr.CopyTo(ctx),
             adj_.indices.CopyTo(ctx),
             adj_.data.CopyTo(ctx));
  return copied;
}

204
205
206
207
208
/*!
 * \brief Copy this CSR into shared memory under \a name.
 * \param name Shared-memory name; when this CSR is already in shared memory
 *        it must equal the name it was created with.
 * \return A copy of *this if already shared, otherwise a new shared-memory CSR.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (!IsSharedMem()) {
    // TODO(zhengda) we need to set sorted_ properly.
    return CSR(adj_.indptr, adj_.indices, adj_.data, name);
  }
  CHECK(name == shared_mem_name_);
  return *this;
}

214
215
216
217
/*!
 * \brief Convert the id arrays of this CSR to the given bit width.
 * \param bits Requested number of bits per id.
 * \return A copy of *this when the width already matches; otherwise a new CSR
 *         built from the converted arrays.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  CSR converted(aten::AsNumBits(adj_.indptr, bits),
                aten::AsNumBits(adj_.indices, bits),
                aten::AsNumBits(adj_.data, bits));
  return converted;
}

225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
/*!
 * \brief Iterator range over the column indices of row \a vid.
 * \note No bounds check is performed on \a vid.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_offsets = static_cast<dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* columns = static_cast<dgl_id_t*>(adj_.indices->data);
  const dgl_id_t* first = columns + row_offsets[vid];
  const dgl_id_t* last = columns + row_offsets[vid + 1];
  return DGLIdIters(first, last);
}

/*!
 * \brief Iterator range over the edge ids stored in row \a vid.
 * \note No bounds check is performed on \a vid.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_offsets = static_cast<dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(adj_.data->data);
  const dgl_id_t* first = edge_ids + row_offsets[vid];
  const dgl_id_t* last = edge_ids + row_offsets[vid + 1];
  return DGLIdIters(first, last);
}

245
246
247
248
249
250
251
252
253
/*!
 * \brief Read the adjacency matrix of this CSR from a stream.
 * \param fs Input stream.
 * \return The result of the stream read, so a failed or truncated read is
 *         reported to the caller instead of being masked as success.
 */
bool CSR::Load(dmlc::Stream *fs) {
  // adj_ is a non-const member in this non-const method, so no cast is needed.
  return fs->Read(&adj_);
}

/*!
 * \brief Write the adjacency matrix of this CSR to a stream.
 * \param fs Output stream.
 */
void CSR::Save(dmlc::Stream *fs) const {
  const aten::CSRMatrix& matrix = adj_;
  fs->Write(matrix);
}

254
255
256
257
258
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
259
/*!
 * \brief Construct a COO graph from source and destination arrays.
 * \param num_vertices Number of vertices (rows and columns of the matrix).
 * \param src Source ids (rows); must be as long as \a dst.
 * \param dst Destination ids (columns).
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t num_rows = num_vertices;
  const int64_t num_cols = num_vertices;
  adj_ = aten::COOMatrix{num_rows, num_cols, src, dst};
}

bool COO::IsMultigraph() const {
267
  return aten::COOHasDuplicate(adj_);
268
269
}

270
271
/*!
 * \brief Endpoints of one edge.
 * \param eid Edge id; must be smaller than the number of edges.
 * \return (source, destination) read from the row and column arrays.
 */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  std::pair<dgl_id_t, dgl_id_t> endpoints;
  endpoints.first = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  endpoints.second = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return endpoints;
}

277
/*!
 * \brief Endpoints of several edges.
 * \param eids Array of edge ids.
 * \return EdgeArray of (sources, destinations, eids).
 */
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  IdArray src = aten::IndexSelect(adj_.row, eids);
  IdArray dst = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{src, dst, eids};
}

284
/*!
 * \brief All edges of this COO.
 * \param order Must be empty or "eid"; anything else fails the check.
 * \return EdgeArray of (row, col, 0..NumEdges()-1).
 */
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  // Edge ids are the positions in the row/col arrays.
  IdArray all_eids = aten::Range(0, NumEdges(), NumBits(), Context());
  EdgeArray edges{adj_.row, adj_.col, all_eids};
  return edges;
}

292
/*!
 * \brief Subgraph induced by the edges in \a eids.
 * \param eids Array of edge ids to keep.
 * \param preserve_nodes If true, the subgraph keeps all NumVertices() vertices
 *        and induced_vertices is 0..NumVertices()-1; otherwise the endpoints
 *        are passed through aten::Relabel_ and its result is used as
 *        induced_vertices.
 * \return Subgraph whose induced_edges is \a eids.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";

  // Endpoints of the selected edges; needed in both modes.
  IdArray sub_src = aten::IndexSelect(adj_.row, eids);
  IdArray sub_dst = aten::IndexSelect(adj_.col, eids);

  COOPtr sub_coo;
  IdArray kept_nodes;
  if (preserve_nodes) {
    kept_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    sub_coo = COOPtr(new COO(NumVertices(), sub_src, sub_dst));
  } else {
    kept_nodes = aten::Relabel_({sub_src, sub_dst});
    const auto num_kept = kept_nodes->shape[0];
    sub_coo = COOPtr(new COO(num_kept, sub_src, sub_dst));
  }

  Subgraph result;
  result.graph = sub_coo;
  result.induced_vertices = kept_nodes;
  result.induced_edges = eids;
  return result;
}

315
/*! \brief New CSR built from aten::COOToCSR of this adjacency matrix. */
CSRPtr COO::ToCSR() const {
  const auto& as_csr = aten::COOToCSR(adj_);
  CSRPtr result(new CSR(as_csr.indptr, as_csr.indices, as_csr.data));
  return result;
}

320
321
322
323
/*!
 * \brief Copy this COO to the given device context.
 * \param ctx Target context.
 * \return A copy of *this when it already lives on \a ctx; otherwise a new COO
 *         built from the row/col arrays copied to \a ctx.
 */
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  COO copied(NumVertices(),
             adj_.row.CopyTo(ctx),
             adj_.col.CopyTo(ctx));
  return copied;
}

331
332
/*!
 * \brief Not supported: COO has no shared-memory representation yet.
 * \param name Unused.
 * \return Never returns normally; logs a fatal error. The return statement
 *         only satisfies the signature.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  // Fixed the misspelling ("supprt") in the user-facing error message.
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}

336
337
338
339
/*!
 * \brief Convert the id arrays of this COO to the given bit width.
 * \param bits Requested number of bits per id.
 * \return A copy of *this when the width already matches; otherwise a new COO
 *         built from the converted arrays.
 */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  COO converted(NumVertices(),
                aten::AsNumBits(adj_.row, bits),
                aten::AsNumBits(adj_.col, bits));
  return converted;
}

347
348
349
350
351
352
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

353
/*!
 * \brief Element-wise test of whether each id is smaller than NumVertices().
 * \param vids Array of vertex ids.
 * \return Result of aten::LT(vids, NumVertices()).
 */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  BoolArray in_range = aten::LT(vids, NumVertices());
  return in_range;
}

/*!
 * \brief Return the in-CSR, building and caching it on first use.
 *
 * It is built by transposing the out-CSR when that exists, otherwise from the
 * transposed COO. The cache is written through a const_cast of `this`.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_) {
    return in_csr_;
  }
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (out_csr_) {
    self->in_csr_ = out_csr_->Transpose();
    if (out_csr_->IsSharedMem()) {
      LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                   << "It may dramatically increase memory consumption.";
    }
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
  }
  return in_csr_;
}

/*!
 * \brief Return the out-CSR, building and caching it on first use.
 *
 * It is built by transposing the in-CSR when that exists, otherwise from the
 * COO. The cache is written through a const_cast of `this`.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_) {
    return out_csr_;
  }
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->out_csr_ = in_csr_->Transpose();
    if (in_csr_->IsSharedMem()) {
      LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                   << "It may dramatically increase memory consumption.";
    }
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
  }
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it on first use.
 *
 * It is built from the transposed COO of the in-CSR when that exists,
 * otherwise from the out-CSR. The cache is written through a const_cast of `this`.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_) {
    return coo_;
  }
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

402
/*!
 * \brief All edges of the graph in the requested order.
 * \param order "" (arbitrary order), "srcdst" or "eid"; any other value logs a
 *        fatal error.
 * \return EdgeArray of (src, dst, edge id).
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == std::string("eid")) {
    return GetCOO()->Edges(order);
  }
  if (!order.empty()) {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }
  // Arbitrary order: use whichever format is already available.
  if (!in_csr_) {
    return AnyGraph()->Edges(order);
  }
  // The in-CSR stores the transposed graph, so swap the endpoints back.
  const auto& reversed = in_csr_->Edges(order);
  return EdgeArray{reversed.dst, reversed.src, reversed.id};
}

424
425
426
427
/*!
 * \brief Subgraph induced by the vertices in \a vids.
 * \param vids Array of vertex ids to keep.
 * \return The out-CSR's vertex subgraph, with its graph wrapped in a new
 *         ImmutableGraph.
 */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  // The out-CSR is the preferred format for generating the subgraph.
  Subgraph result = GetOutCSR()->VertexSubgraph(vids);
  CSRPtr sub_csr = std::dynamic_pointer_cast<CSR>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(sub_csr));
  return result;
}

432
433
/*!
 * \brief Subgraph induced by the edges in \a eids.
 * \param eids Array of edge ids to keep.
 * \param preserve_nodes Forwarded to COO::EdgeSubgraph.
 * \return The COO's edge subgraph, with its graph wrapped in a new
 *         ImmutableGraph.
 */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  Subgraph result = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  COOPtr sub_coo = std::dynamic_pointer_cast<COO>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(sub_coo));
  return result;
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
440
441
442
443
444
445
446
447
448
449
450
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
451
  } else {
452
453
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
454
455
456
  }
}

457
458
/*!
 * \brief Create an immutable graph from CSR arrays.
 * \param indptr Row pointer array.
 * \param indices Column index array.
 * \param edge_ids Edge id array.
 * \param edge_dir "in" or "out": which CSR slot the arrays fill. Any other
 *        value logs a fatal error.
 * \return The new graph.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Create an immutable graph whose CSR is copied into shared memory.
 * \param indptr Row pointer array.
 * \param indices Column index array.
 * \param edge_ids Edge id array.
 * \param edge_dir "in" or "out": which CSR slot the arrays fill. Any other
 *        value logs a fatal error.
 * \param shared_mem_name Base shared-memory name; the CSR buffer is named
 *        "<shared_mem_name>_<edge_dir>".
 * \return The new graph.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, csr_mem_name));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Create an immutable graph whose CSR is mapped from shared memory.
 * \param shared_mem_name Base shared-memory name; the CSR buffer is named
 *        "<shared_mem_name>_<edge_dir>".
 * \param num_vertices Number of vertices.
 * \param num_edges Number of edges.
 * \param edge_dir "in" or "out": which CSR slot is filled. Any other value
 *        logs a fatal error.
 * \return The new graph.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    const std::string &shared_mem_name, size_t num_vertices,
    size_t num_edges, const std::string &edge_dir) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(csr_mem_name, num_vertices, num_edges));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Create an immutable graph from COO arrays.
 * \param num_vertices Number of vertices.
 * \param src Source ids.
 * \param dst Destination ids.
 * \return The new graph.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst) {
  COOPtr coo_graph(new COO(num_vertices, src, dst));
  ImmutableGraphPtr result = std::make_shared<ImmutableGraph>(coo_graph);
  return result;
}

/*!
 * \brief Convert any graph into an immutable graph.
 * \param graph The input graph.
 * \return \a graph itself when it already is an ImmutableGraph; otherwise a
 *         new ImmutableGraph built from the graph's "csr" adjacency
 *         (requested with transpose = true) and used as the out-CSR.
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) {
    return ig;
  }
  // CreateFromCSR constructs the CSR itself; building a second, unused CSR
  // here would only repeat the validation and allocation work.
  const auto& adj = graph->GetAdj(true, "csr");
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

517
518
519
/*!
 * \brief Copy a graph to the given device context.
 * \param g The graph to copy.
 * \param ctx Target context.
 * \return \a g itself when it already lives on \a ctx; otherwise a new graph
 *         holding copies of both the in-CSR and the out-CSR.
 */
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_copy(new CSR(g->GetInCSR()->CopyTo(ctx)));
  CSRPtr out_copy(new CSR(g->GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

530
531
/*!
 * \brief Copy one CSR of a graph into shared memory.
 * \param g The graph to copy.
 * \param edge_dir "in" or "out": which CSR is copied. Any other value logs a
 *        fatal error, matching the CreateFromCSR overloads, instead of
 *        silently returning a graph that holds no CSR at all.
 * \param name Base shared-memory name; the CSR buffer is named
 *        "<name>_<edge_dir>".
 * \return A new graph holding only the shared-memory CSR.
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
    const std::string &edge_dir, const std::string &name) {
  CSRPtr new_incsr, new_outcsr;
  std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in")) {
    new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
  } else if (edge_dir == std::string("out")) {
    new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  } else {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}

541
542
543
/*!
 * \brief Convert a graph to the given id bit width.
 * \param g The graph to convert.
 * \param bits Requested number of bits per id.
 * \return \a g itself when the width already matches; otherwise a new graph
 *         holding converted copies of both the in-CSR and the out-CSR.
 */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }
  // TODO(minjie): since we don't have int32 operations, both CSRs are
  //   materialized first and then converted to the requested bit width.
  //   This should be fixed later.
  CSRPtr in_converted(new CSR(g->GetInCSR()->AsNumBits(bits)));
  CSRPtr out_converted(new CSR(g->GetOutCSR()->AsNumBits(bits)));
  return ImmutableGraphPtr(new ImmutableGraph(in_converted, out_converted));
}

/*!
 * \brief Graph with every edge reversed.
 *
 * The in- and out-CSR are swapped (whichever of them exist), and the COO, when
 * present, is transposed.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_) {
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  }
  return ImmutableGraphPtr(new ImmutableGraph(
        out_csr_, in_csr_, coo_->Transpose()));
}

564
565
566
567
568
569
570
// Magic number written first by ImmutableGraph::Save and verified by
// ImmutableGraph::Load to recognize a serialized ImmutableGraph.
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;

/*!
 * \brief Load an ImmutableGraph from a stream: a magic number followed by the
 *        out-CSR.
 * \param fs Input stream.
 * \return Always true; a failed read or a wrong magic number fails a CHECK.
 */
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
      << "Invalid ImmutableGraph Magic Number";
  // Only the out-CSR is stored; the other formats can be derived from it.
  CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
  return true;
}

/*!
 * \brief Save this ImmutableGraph to a stream: a magic number followed by the
 *        out-CSR (materialized first if it does not exist yet).
 * \param fs Output stream.
 */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
  const CSRPtr out_csr = GetOutCSR();
  fs->Write(out_csr);
}

583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
/*!
 * \brief Wrap this graph as a HeteroGraph.
 *
 * Only the formats that already exist are passed on; the three boolean flags
 * tell UnitGraph::CreateHomographFrom which of the matrices are populated.
 */
HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
  const bool has_in_csr = (in_csr_ != nullptr);
  const bool has_out_csr = (out_csr_ != nullptr);
  const bool has_coo = (coo_ != nullptr);

  aten::CSRMatrix in_mat, out_mat;
  aten::COOMatrix coo_mat;
  if (has_in_csr) {
    in_mat = GetInCSR()->ToCSRMatrix();
  }
  if (has_out_csr) {
    out_mat = GetOutCSR()->ToCSRMatrix();
  }
  if (has_coo) {
    coo_mat = GetCOO()->ToCOOMatrix();
  }

  auto unit_graph = UnitGraph::CreateHomographFrom(
      in_mat, out_mat, coo_mat, has_in_csr, has_out_csr, has_coo);
  return HeteroGraphPtr(new HeteroGraph(unit_graph->meta_graph(), {unit_graph}));
}

// C API: convert an immutable graph (args[0]) into a heterograph.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    // Only ImmutableGraph can be converted; anything else fails the check.
    auto ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
    CHECK(ig) << "graph is not readonly";
    HeteroGraphPtr hetero = ig->AsHeteroGraph();
    *rv = HeteroGraphRef(hetero);
  });

610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
// C API: copy an immutable graph (args[0]) to the device given by
// device type (args[1]) and device id (args[2]).
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const int device_type = args[1];
    const int device_id = args[2];
    const DLContext ctx{static_cast<DLDeviceType>(device_type), device_id};
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyTo(ig, ctx);
  });

// C API: copy one CSR of an immutable graph (args[0]) into shared memory;
// args[1] is the edge direction and args[2] the shared-memory name.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    std::string edge_dir = args[1];
    std::string name = args[2];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    ImmutableGraphPtr shared = ImmutableGraph::CopyToSharedMem(ig, edge_dir, name);
    *rv = shared;
  });

// C API: convert an immutable graph (args[0]) to the id bit width in args[1].
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    int bits = args[1];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    ImmutableGraphPtr converted = ImmutableGraph::AsNumBits(ig, bits);
    *rv = converted;
  });

639
}  // namespace dgl