immutable_graph.cc 24.7 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

#include <dgl/immutable_graph.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
10
11
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
12
13
14
15
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
16
17
18

#include "../c_api_common.h"

19
20
using namespace dgl::runtime;

21
namespace dgl {
22
namespace {
23
24
25
26
/*! \brief Compose the per-direction shared-memory name: "<name>_<edge_dir>". */
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string full_name(name);
  full_name += '_';
  full_name += edge_dir;
  return full_name;
}

27
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
28
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
29
#ifndef _WIN32
30
31
32
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
33
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
34
35
36
37
38
39
40
41
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
42
#else
43
44
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
45
46
#endif  // _WIN32
}
47
48
49
50
51
52
53
54
55
56
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*!
 * \brief Allocate a square CSR with room for the given numbers of vertices and edges.
 *
 * The index arrays are freshly allocated with aten::NewIdArray.
 * NOTE(review): their contents are presumably uninitialized; confirm.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  // A graph without vertices cannot carry edges.
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray row_offsets = aten::NewIdArray(num_vertices + 1);
  IdArray col_ids = aten::NewIdArray(num_edges);
  IdArray edge_ids = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, row_offsets, col_ids, edge_ids};
  adj_.sorted = false;
}

65
/*!
 * \brief Wrap existing CSR arrays (no copy). The multigraph flag is left to be
 *        computed lazily by IsMultigraph().
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // Each stored column index must have exactly one edge id.
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  // Square adjacency: both dimensions equal the vertex count.
  const int64_t vertex_count = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{vertex_count, vertex_count, indptr, indices, edge_ids};
  adj_.sorted = false;
}

/*! \brief Wrap existing CSR arrays (no copy) with a caller-supplied multigraph flag. */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // Each stored column index must have exactly one edge id.
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  // Square adjacency: both dimensions equal the vertex count.
  const int64_t vertex_count = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{vertex_count, vertex_count, indptr, indices, edge_ids};
  adj_.sorted = false;
}

/*!
 * \brief Build a CSR whose arrays live in a newly created shared-memory segment.
 *
 * The segment named \p shared_mem_name is created, and the given arrays are
 * copied into it.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t vertex_count = indptr->shape[0] - 1;
  const int64_t edge_count = indices->shape[0];

  // Create the segment (is_create = true) and point adj_ at its views.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, vertex_count, edge_count, true);
  adj_.num_rows = vertex_count;
  adj_.num_cols = vertex_count;

  // Fill the shared buffers with the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Build a CSR in a newly created shared-memory segment, with a
 *        caller-supplied multigraph flag.
 *
 * The segment named \p shared_mem_name is created, and the given arrays are
 * copied into it.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
         const std::string &shared_mem_name): is_multigraph_(is_multigraph),
         shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t vertex_count = indptr->shape[0] - 1;
  const int64_t edge_count = indices->shape[0];

  // Create the segment (is_create = true) and point adj_ at its views.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, vertex_count, edge_count, true);
  adj_.num_rows = vertex_count;
  adj_.num_cols = vertex_count;

  // Fill the shared buffers with the caller's data.
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Attach to an existing shared-memory CSR segment without copying data.
 *
 * \p num_verts and \p num_edges must describe the layout the segment was
 * created with, because they determine the views' sizes and offsets.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  // Attach (is_create = false); the arrays are used as found in the segment.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  adj_.sorted = false;
}

bool CSR::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
139
      return aten::CSRHasDuplicate(adj_);
140
141
    });
}
142

143
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  // Row `vid` holds the destinations and edge ids; the source is `vid` itself,
  // repeated once per out-edge.
  IdArray dst = aten::CSRGetRowColumnIndices(adj_, vid);
  IdArray eid = aten::CSRGetRowData(adj_, vid);
  const int64_t out_degree = dst->shape[0];
  IdArray src = aten::Full(vid, out_degree, NumBits(), dst->ctx);
  return EdgeArray{src, dst, eid};
}

151
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  // Slice out the requested rows and flatten them to (row, col, data) triplets.
  const auto rows_csr = aten::CSRSliceRows(adj_, vids);
  const auto rows_coo = aten::CSRToCOO(rows_csr, false);
  // Row ids in the slice are relabeled to positions in `vids`; map them back.
  IdArray src = aten::IndexSelect(vids, rows_coo.row);
  return EdgeArray{src, rows_coo.col, rows_coo.data};
}

161
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  // A vertex's out-degree is the number of stored entries in its row.
  DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

166
167
168
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  // The edge src->dst exists iff entry (src, dst) is stored in the adjacency.
  const bool found = aten::CSRIsNonZero(adj_, src, dst);
  return found;
}

BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  // Vectorized form of HasEdgeBetween over paired (src, dst) arrays.
  BoolArray found = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return found;
}

178
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  // Only one-hop neighborhoods are supported: the column indices of row `vid`.
  IdArray successors = aten::CSRGetRowColumnIndices(adj_, vid);
  return successors;
}

184
185
186
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  // The stored data (edge ids) at entry (src, dst), returned as an array;
  // presumably multigraphs can yield more than one id here.
  IdArray eids = aten::CSRGetData(adj_, src, dst);
  return eids;
}

190
/*!
 * \brief Look up the edges between paired source/destination vertex ids.
 *
 * The inputs are validated like in HasEdgesBetween. The three arrays returned
 * by aten::CSRGetDataAndIndices are forwarded positionally as (src, dst, id).
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
194

195
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\","
    << " but got \"" << order << "\".";
  // Expanding the row pointers yields the edges grouped by source vertex.
  const auto expanded = aten::CSRToCOO(adj_, false);
  return EdgeArray{expanded.row, expanded.col, expanded.data};
}

203
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  // Keep only rows and columns listed in `vids`.
  const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
  // The subgraph gets fresh consecutive edge ids; the parent's edge ids are
  // reported through induced_edges.
  const int64_t num_sub_edges = submat.data->shape[0];
  IdArray sub_eids = aten::Range(0, num_sub_edges, NumBits(), Context());
  CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
  subcsr->adj_.sorted = adj_.sorted;

  Subgraph subg;
  subg.induced_vertices = vids;
  subg.induced_edges = submat.data;
  subg.graph = subcsr;
  return subg;
}

CSRPtr CSR::Transpose() const {
  // Swap rows and columns; the transposed data array becomes the edge ids.
  const auto transposed = aten::CSRTranspose(adj_);
  return CSRPtr(new CSR(transposed.indptr, transposed.indices, transposed.data));
}

COOPtr CSR::ToCOO() const {
  // Unlike CSR::Edges, `true` is passed here. NOTE(review): presumably this
  // requests edge-id order, which COO relies on because it stores no explicit
  // edge ids (see COO::Edges). Confirm against aten::CSRToCOO.
  const auto coo_mat = aten::CSRToCOO(adj_, true);
  return COOPtr(new COO(NumVertices(), coo_mat.row, coo_mat.col));
}

226
227
228
229
/*!
 * \brief Copy this CSR to the device context \p ctx.
 *
 * Returns *this when the graph is already on \p ctx. The copy keeps the
 * multigraph flag and the sorted flag, because moving data between devices
 * does not reorder the column indices.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  CSR ret(adj_.indptr.CopyTo(ctx),
          adj_.indices.CopyTo(ctx),
          adj_.data.CopyTo(ctx));
  ret.is_multigraph_ = is_multigraph_;
  // The constructor resets `sorted`; restore it since the order is unchanged.
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

238
239
240
241
242
/*!
 * \brief Copy this CSR into a new shared-memory segment named \p name.
 *
 * If the CSR already lives in shared memory, it must be under the same name,
 * and *this is returned. The copy keeps the multigraph and sorted flags, which
 * the shared-memory constructor would otherwise reset or leave uncomputed.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (IsSharedMem()) {
    CHECK(name == shared_mem_name_);
    return *this;
  }
  CSR ret(adj_.indptr, adj_.indices, adj_.data, name);
  // The data are copied verbatim, so both properties carry over unchanged.
  ret.is_multigraph_ = is_multigraph_;
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

248
249
250
251
/*!
 * \brief Return this CSR with all index arrays converted to \p bits-bit integers.
 *
 * Returns *this when the width already matches. The conversion keeps the
 * multigraph flag and the sorted flag, because changing the integer width does
 * not reorder entries.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  CSR ret(aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
  ret.is_multigraph_ = is_multigraph_;
  // The constructor resets `sorted`; restore it since the order is unchanged.
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  // Reads the raw buffers directly as dgl_id_t; `vid` is not range-checked here.
  const auto* row_ptr = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const auto* col_idx = static_cast<const dgl_id_t*>(adj_.indices->data);
  const dgl_id_t* first = col_idx + row_ptr[vid];
  const dgl_id_t* last = col_idx + row_ptr[vid + 1];
  return DGLIdIters(first, last);
}

DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  // Same row range as SuccVec, but over the edge-id (data) buffer.
  const auto* row_ptr = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const auto* eid_buf = static_cast<const dgl_id_t*>(adj_.data->data);
  const dgl_id_t* first = eid_buf + row_ptr[vid];
  const dgl_id_t* last = eid_buf + row_ptr[vid + 1];
  return DGLIdIters(first, last);
}

280
281
282
283
284
285
286
287
288
/*!
 * \brief Deserialize the adjacency matrix from \p fs in place.
 * \return Whether the stream read succeeded (the counterpart of CSR::Save).
 */
bool CSR::Load(dmlc::Stream *fs) {
  // adj_ is already mutable here, so no const_cast is needed; propagate the
  // read status instead of unconditionally reporting success.
  return fs->Read(&adj_);
}

/*! \brief Serialize the adjacency matrix to \p fs (counterpart of CSR::Load). */
void CSR::Save(dmlc::Stream *fs) const {
  const aten::CSRMatrix &matrix = adj_;
  fs->Write(matrix);
}

289
290
291
292
293
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
294
/*!
 * \brief Build a square COO graph on \p num_vertices vertices from paired
 *        endpoint arrays. The multigraph flag is computed lazily.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  // Every edge needs both a source and a destination.
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t dim = num_vertices;
  adj_ = aten::COOMatrix{dim, dim, src, dst};
}

/*! \brief Build a square COO graph with a caller-supplied multigraph flag. */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  // Every edge needs both a source and a destination.
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t dim = num_vertices;
  adj_ = aten::COOMatrix{dim, dim, src, dst};
}

bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
312
      return aten::COOHasDuplicate(adj_);
313
    });
314
315
}

316
317
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  // Edge `eid` is stored at position `eid` of the row/col arrays.
  const dgl_id_t src = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  const dgl_id_t dst = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return std::make_pair(src, dst);
}

323
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  // Gather the endpoints stored at each requested edge id.
  IdArray srcs = aten::IndexSelect(adj_.row, eids);
  IdArray dsts = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{srcs, dsts, eids};
}

330
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  // Edges are stored in id order, so their ids are simply 0..NumEdges()-1.
  IdArray edge_ids = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, edge_ids};
}

338
/*!
 * \brief Build the subgraph made of the edges listed in \p eids.
 *
 * \param preserve_nodes If true, keep every parent vertex with its original id.
 *        Otherwise, keep only the endpoints of the selected edges, renumbered.
 * The subgraph inherits this graph's multigraph flag.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
  // Endpoints of the selected edges, in the order given by `eids`.
  IdArray sub_src = aten::IndexSelect(adj_.row, eids);
  IdArray sub_dst = aten::IndexSelect(adj_.col, eids);

  Subgraph subg;
  subg.induced_edges = eids;
  int64_t num_sub_vertices = NumVertices();
  if (preserve_nodes) {
    subg.induced_vertices = aten::Range(0, NumVertices(), NumBits(), Context());
  } else {
    // NOTE(review): Relabel_ presumably renumbers sub_src/sub_dst in place
    // (trailing underscore) and returns the parent ids of the kept vertices.
    subg.induced_vertices = aten::Relabel_({sub_src, sub_dst});
    num_sub_vertices = subg.induced_vertices->shape[0];
  }
  subg.graph = COOPtr(new COO(num_sub_vertices, sub_src, sub_dst, this->IsMultigraph()));
  return subg;
}

361
CSRPtr COO::ToCSR() const {
  // Convert the edge list to CSR; the converted data array is used as edge ids.
  const auto csr_mat = aten::COOToCSR(adj_);
  return CSRPtr(new CSR(csr_mat.indptr, csr_mat.indices, csr_mat.data));
}

366
367
368
369
/*! \brief Copy this COO to device context \p ctx (or return *this if already there). */
COO COO::CopyTo(const DLContext& ctx) const {
  if (!(Context() == ctx)) {
    COO copied(NumVertices(), adj_.row.CopyTo(ctx), adj_.col.CopyTo(ctx));
    copied.is_multigraph_ = is_multigraph_;
    return copied;
  }
  return *this;
}

378
379
/*!
 * \brief Not supported: COO graphs cannot be placed in shared memory yet.
 *        Always fails with a fatal error.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}

383
384
385
386
/*! \brief Return this COO with row/col converted to \p bits-bit integers. */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() != bits) {
    COO converted(NumVertices(),
                  aten::AsNumBits(adj_.row, bits),
                  aten::AsNumBits(adj_.col, bits));
    converted.is_multigraph_ = is_multigraph_;
    return converted;
  }
  return *this;
}

395
396
397
398
399
400
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

401
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  // An id is reported as present when it is less than the vertex count.
  // NOTE(review): negative ids also satisfy this test; confirm callers never
  // pass them.
  const auto num_vertices = NumVertices();
  return aten::LT(vids, num_vertices);
}

/*! \brief Return the in-CSR, building and caching it from the out-CSR or the COO on first use. */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_) {
    return in_csr_;
  }
  // The cache is filled through const_cast because this accessor is const.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (out_csr_) {
    self->in_csr_ = out_csr_->Transpose();
    if (out_csr_->IsSharedMem())
      LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                   << "It may dramatically increase memory consumption.";
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
  }
  return in_csr_;
}

/*!
 * \brief Return the out-CSR, building and caching it on first use.
 *
 * An existing in-CSR is transposed; otherwise the COO is converted.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_) {
    return out_csr_;
  }
  // The cache is filled through const_cast because this accessor is const.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->out_csr_ = in_csr_->Transpose();
    if (in_csr_->IsSharedMem())
      LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                   << "It may dramatically increase memory consumption.";
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
  }
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it from a CSR on first use.
 *
 * An existing in-CSR is preferred (converted, then transposed); otherwise the
 * out-CSR is converted directly.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_) {
    return coo_;
  }
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

450
/*!
 * \brief Return all edges in the requested order.
 * \param order "" (arbitrary), "srcdst" (via the out-CSR), or "eid" (via the COO).
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == std::string("eid")) {
    return GetCOO()->Edges(order);
  }
  if (!order.empty()) {
    LOG(FATAL) << "Unsupported order request: " << order;
    return {};
  }
  // Arbitrary order: reuse an in-CSR if it already exists (its rows are
  // destinations, so swap the endpoints back); otherwise use any representation.
  if (in_csr_) {
    const auto& edges = in_csr_->Edges(order);
    return EdgeArray{edges.dst, edges.src, edges.id};
  }
  return AnyGraph()->Edges(order);
}

472
473
474
475
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  // We prefer to generate a subgraph from out-csr.
  Subgraph subg = GetOutCSR()->VertexSubgraph(vids);
  // Wrap the CSR produced by CSR::VertexSubgraph into an ImmutableGraph.
  CSRPtr sub_out_csr = std::dynamic_pointer_cast<CSR>(subg.graph);
  subg.graph = GraphPtr(new ImmutableGraph(sub_out_csr));
  return subg;
}

480
481
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  // Edge selection is natural on the COO form (edges indexed by id).
  Subgraph subg = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  // Wrap the COO produced by COO::EdgeSubgraph into an ImmutableGraph.
  COOPtr sub_coo = std::dynamic_pointer_cast<COO>(subg.graph);
  subg.graph = GraphPtr(new ImmutableGraph(sub_coo));
  return subg;
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
488
489
490
491
492
493
494
495
496
497
498
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
499
  } else {
500
501
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
502
503
504
  }
}

505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
/*!
 * \brief Create an immutable graph from CSR arrays describing either the
 *        incoming ("in") or outgoing ("out") edges of each vertex.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  }
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Create an immutable graph from "in" or "out" CSR arrays with a
 *        caller-supplied multigraph flag.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
  if (edge_dir == "out") {
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  }
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  }
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Create an immutable graph whose CSR (for direction \p edge_dir) is
 *        copied into a new shared-memory segment.
 *
 * The segment is named GetSharedMemName(shared_mem_name, edge_dir).
 * \param edge_dir "in" or "out"; any other value is a fatal error.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &edge_dir,
    const std::string &shared_mem_name) {
  // Validate the direction first. It is part of the segment name, so an invalid
  // value would otherwise create a stray shared-memory segment before failing.
  if (edge_dir != "in" && edge_dir != "out") {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
}

/*!
 * \brief Create an immutable graph whose CSR (for direction \p edge_dir) is
 *        copied into a new shared-memory segment, with a caller-supplied
 *        multigraph flag.
 *
 * The segment is named GetSharedMemName(shared_mem_name, edge_dir).
 * \param edge_dir "in" or "out"; any other value is a fatal error.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir,
    const std::string &shared_mem_name) {
  // Validate the direction first. It is part of the segment name, so an invalid
  // value would otherwise create a stray shared-memory segment before failing.
  if (edge_dir != "in" && edge_dir != "out") {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
                     GetSharedMemName(shared_mem_name, edge_dir)));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
}

/*!
 * \brief Create an immutable graph by attaching to an existing shared-memory
 *        CSR segment named GetSharedMemName(shared_mem_name, edge_dir).
 *
 * \param num_vertices, num_edges Must match the layout the segment was created with.
 * \param edge_dir "in" or "out"; any other value is a fatal error.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    const std::string &shared_mem_name, size_t num_vertices,
    size_t num_edges, bool multigraph,
    const std::string &edge_dir) {
  // Validate the direction first. It is part of the segment name, so an invalid
  // value would otherwise be used to attach to a segment named after it before
  // the direction is rejected.
  if (edge_dir != "in" && edge_dir != "out") {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
                     multigraph));
  if (edge_dir == "in") {
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  }
  return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
}

/*! \brief Create an immutable graph from a COO edge list (src[i] -> dst[i]). */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst) {
  COOPtr edge_list(new COO(num_vertices, src, dst));
  return std::make_shared<ImmutableGraph>(edge_list);
}

/*! \brief Create an immutable graph from a COO edge list with a caller-supplied multigraph flag. */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
  COOPtr edge_list(new COO(num_vertices, src, dst, multigraph));
  return std::make_shared<ImmutableGraph>(edge_list);
}

/*!
 * \brief Return \p graph as an ImmutableGraph.
 *
 * An existing ImmutableGraph is returned as-is (shared, not copied). Any other
 * graph is rebuilt from its CSR adjacency.
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) {
    return ig;
  }
  // GetAdj(true, "csr") yields the out-edge CSR (following the convention
  // documented in ImmutableGraph::GetAdj), hence edge_dir "out".
  // CreateFromCSR builds and validates the CSR itself, so no extra CSR object
  // is constructed here.
  const auto& adj = graph->GetAdj(true, "csr");
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

602
603
604
/*! \brief Copy \p g to device context \p ctx (returns \p g itself if already there). */
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context()) {
    return g;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_copy(new CSR(g->GetInCSR()->CopyTo(ctx)));
  CSRPtr out_copy(new CSR(g->GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

615
616
/*!
 * \brief Copy the CSR of \p g for direction \p edge_dir into shared memory.
 *
 * The segment is named GetSharedMemName(name, edge_dir). Only that one CSR is
 * kept in the result.
 * \param edge_dir "in" or "out"; any other value is a fatal error.
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
    const std::string &edge_dir, const std::string &name) {
  CSRPtr new_incsr, new_outcsr;
  const std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in")) {
    new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
  } else if (edge_dir == std::string("out")) {
    new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  } else {
    // Without this check, the result would hold no CSR and no COO, and would
    // fail only later ("None of CSR, COO exist") on first use.
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}

626
627
628
/*! \brief Return \p g with \p bits-bit ids (returns \p g itself if the width already matches). */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() != bits) {
    // TODO(minjie): since we don't have int32 operations, we materialize both
    //   CSRs of this graph and convert those to the requested width. This
    //   should be fixed later.
    CSRPtr in_cast(new CSR(g->GetInCSR()->AsNumBits(bits)));
    CSRPtr out_cast(new CSR(g->GetOutCSR()->AsNumBits(bits)));
    return ImmutableGraphPtr(new ImmutableGraph(in_cast, out_cast));
  }
  return g;
}

/*!
 * \brief Return the graph with every edge reversed.
 *
 * The in- and out-CSR swap roles. A COO, if present, is transposed.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_) {
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  }
  return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
}

649
650
651
652
653
654
655
/*! \brief Magic number written ahead of a serialized ImmutableGraph. */
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;

/*!
 * \brief Load an ImmutableGraph from a stream written by ImmutableGraph::Save.
 *
 * Format: the magic number kDGLSerialize_ImGraph, followed by the out-CSR.
 * \param fs The input stream.
 * \return true on success; malformed input triggers a CHECK failure.
 */
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph)
      << "Invalid ImmutableGraph Magic Number";
  CHECK(fs->Read(&out_csr_)) << "Invalid csr matrix";
  // Only the out-CSR is serialized. Drop any in-CSR/COO derived from previous
  // contents, so they are rebuilt from the loaded out-CSR on demand.
  in_csr_ = nullptr;
  coo_ = nullptr;
  return true;
}

/*!
 * \brief Save this ImmutableGraph to a stream: the magic number
 *        kDGLSerialize_ImGraph, then the out-CSR (materialized if needed).
 */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
  const CSRPtr out_csr = GetOutCSR();
  fs->Write(out_csr);
}

668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // args: (graph, device_type, device_id) -> the graph copied to that device.
    GraphRef g = args[0];
    const int dev_type = args[1];
    const int dev_id = args[2];
    DLContext target_ctx;
    target_ctx.device_type = static_cast<DLDeviceType>(dev_type);
    target_ctx.device_id = dev_id;
    ImmutableGraphPtr imm = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyTo(imm, target_ctx);
  });

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // args: (graph, edge_dir, name); edge_dir selects which CSR is shared.
    GraphRef g = args[0];
    const std::string direction = args[1];
    const std::string segment_name = args[2];
    ImmutableGraphPtr imm = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyToSharedMem(imm, direction, segment_name);
  });

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // args: (graph, bits) -> the graph with ids converted to `bits`-bit integers.
    GraphRef g = args[0];
    int num_bits = args[1];
    ImmutableGraphPtr imm = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::AsNumBits(imm, num_bits);
  });

697
}  // namespace dgl