immutable_graph.cc 24.7 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

7
#include <dgl/packed_func_ext.h>
8
#include <dgl/immutable_graph.h>
9
10
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
11
12
13
14
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
15
16
17

#include "../c_api_common.h"

18
19
using namespace dgl::runtime;

20
namespace dgl {
21
namespace {
22
23
24
25
/*!
 * \brief Build the per-direction shared-memory segment name.
 * \param name Base shared-memory name.
 * \param edge_dir Edge direction tag appended to the base name.
 * \return `name` followed by '_' and `edge_dir`.
 */
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
  std::string joined(name);
  joined += '_';
  joined += edge_dir;
  return joined;
}

std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
27
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
28
#ifndef _WIN32
29
30
31
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
32
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
33
34
35
36
37
38
39
40
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
41
#else
42
43
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
44
45
#endif  // _WIN32
}
46
47
48
49
50
51
52
53
54
55
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*!
 * \brief Allocate a square CSR of the given size. The array contents are not set here.
 * \param num_vertices Number of vertices (rows and columns).
 * \param num_edges Number of stored edges.
 * \param is_multigraph Initial value of the multigraph flag.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(!(num_vertices == 0 && num_edges != 0));
  // indptr needs one slot per row plus the trailing end offset.
  IdArray row_ptr = aten::NewIdArray(num_vertices + 1);
  IdArray col_idx = aten::NewIdArray(num_edges);
  IdArray eid = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, row_ptr, col_idx, eid};
  adj_.sorted = false;
}

/*!
 * \brief Wrap existing CSR arrays as a square adjacency matrix.
 * \param indptr Row pointer array; its length is (number of vertices + 1).
 * \param indices Column index of every stored edge.
 * \param edge_ids Edge id of every stored edge; must be as long as `indices`.
 *
 * is_multigraph_ is not initialized here; presumably IsMultigraph() computes it on
 * first use.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // An empty indptr would produce a vertex count of -1 and a corrupt matrix.
  CHECK_GE(indptr->shape[0], 1) << "indptr must contain at least one element";
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t N = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
  adj_.sorted = false;
}

/*!
 * \brief Wrap existing CSR arrays as a square adjacency matrix with a known multigraph flag.
 * \param indptr Row pointer array; its length is (number of vertices + 1).
 * \param indices Column index of every stored edge.
 * \param edge_ids Edge id of every stored edge; must be as long as `indices`.
 * \param is_multigraph Value used for the multigraph flag.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // An empty indptr would produce a vertex count of -1 and a corrupt matrix.
  CHECK_GE(indptr->shape[0], 1) << "indptr must contain at least one element";
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t N = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
  adj_.sorted = false;
}

/*!
 * \brief Copy CSR arrays into a newly created shared-memory segment.
 * \param indptr Row pointer array; its length is (number of vertices + 1).
 * \param indices Column index of every stored edge.
 * \param edge_ids Edge id of every stored edge; must be as long as `indices`.
 * \param shared_mem_name Name of the shared-memory segment to create.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // An empty indptr would give num_verts == -1 and a bogus shared-memory size.
  CHECK_GE(indptr->shape[0], 1) << "indptr must contain at least one element";
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_verts = indptr->shape[0] - 1;
  const int64_t num_edges = indices->shape[0];
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  // is_create = true: this constructor owns the creation of the segment.
  std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
      shared_mem_name, num_verts, num_edges, true);
  // copy the given data into the shared memory arrays
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Copy CSR arrays into a newly created shared-memory segment, with a known
 *        multigraph flag.
 * \param indptr Row pointer array; its length is (number of vertices + 1).
 * \param indices Column index of every stored edge.
 * \param edge_ids Edge id of every stored edge; must be as long as `indices`.
 * \param is_multigraph Value used for the multigraph flag.
 * \param shared_mem_name Name of the shared-memory segment to create.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
         const std::string &shared_mem_name): is_multigraph_(is_multigraph),
         shared_mem_name_(shared_mem_name) {
  CHECK(aten::IsValidIdArray(indptr));
  CHECK(aten::IsValidIdArray(indices));
  CHECK(aten::IsValidIdArray(edge_ids));
  // An empty indptr would give num_verts == -1 and a bogus shared-memory size.
  CHECK_GE(indptr->shape[0], 1) << "indptr must contain at least one element";
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
  const int64_t num_verts = indptr->shape[0] - 1;
  const int64_t num_edges = indices->shape[0];
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  // is_create = true: this constructor owns the creation of the segment.
  std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
      shared_mem_name, num_verts, num_edges, true);
  // copy the given data into the shared memory arrays
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
  adj_.sorted = false;
}

/*!
 * \brief Map a CSR whose arrays presumably already live in the named shared-memory
 *        segment (e.g. written by one of the copying constructors).
 * \param shared_mem_name Name of the shared-memory segment.
 * \param num_verts Number of vertices.
 * \param num_edges Number of edges.
 * \param is_multigraph Value used for the multigraph flag.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
  CHECK(!(num_verts == 0 && num_edges != 0));
  // is_create = false: reuse the segment instead of creating it.
  auto arrays = MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  adj_.indptr = std::get<0>(arrays);
  adj_.indices = std::get<1>(arrays);
  adj_.data = std::get<2>(arrays);
  adj_.sorted = false;
}

bool CSR::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
138
      return aten::CSRHasDuplicate(adj_);
139
140
    });
}
141

142
/*!
 * \brief Out-edges of a single vertex.
 * \param vid Source vertex; must exist.
 * \return (src, dst, eid) where src is `vid` repeated once per out-edge.
 */
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  // Row `vid` holds the destinations (column indices) and edge ids (data).
  const IdArray dst = aten::CSRGetRowColumnIndices(adj_, vid);
  const IdArray eid = aten::CSRGetRowData(adj_, vid);
  // Every out-edge shares the same source.
  const int64_t degree = dst->shape[0];
  const IdArray src = aten::Full(vid, degree, NumBits(), dst->ctx);
  return EdgeArray{src, dst, eid};
}

/*!
 * \brief Out-edges of a batch of vertices.
 * \param vids Source vertex ids.
 * \return (src, dst, eid) for every out-edge of every vertex in `vids`.
 */
EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  // Take the requested rows, then expand them to coordinate form.
  const auto coo = aten::CSRToCOO(aten::CSRSliceRows(adj_, vids), false);
  // The sliced rows are renumbered by position in vids; map them back to real ids.
  IdArray src = aten::IndexSelect(vids, coo.row);
  return EdgeArray{src, coo.col, coo.data};
}

/*!
 * \brief Out-degree of each vertex in `vids`.
 *
 * The out-degree is the number of stored entries in the vertex's row.
 */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  const DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

/*!
 * \brief Whether an edge src -> dst exists.
 *
 * This is true iff entry (src, dst) is stored in the adjacency matrix. Both endpoints
 * must be valid vertices.
 */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  const bool stored = aten::CSRIsNonZero(adj_, src, dst);
  return stored;
}

/*!
 * \brief Batched HasEdgeBetween over paired (src_ids, dst_ids).
 *
 * Presumably the result is one boolean per pair — TODO confirm against aten::CSRIsNonZero.
 */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const BoolArray present = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return present;
}

/*!
 * \brief Direct successors of `vid`, i.e. the column indices of its row.
 *
 * Only radius 1 is supported.
 */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  const IdArray successors = aten::CSRGetRowColumnIndices(adj_, vid);
  return successors;
}

/*!
 * \brief Edge id(s) of the edge(s) src -> dst.
 *
 * The data array holds edge ids, so this reads entry (src, dst). The result is an
 * array because a multigraph can have parallel edges.
 */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  const IdArray ids = aten::CSRGetData(adj_, src, dst);
  return ids;
}

/*!
 * \brief Look up edges for paired source/destination ids.
 * \param src_ids Source vertex ids.
 * \param dst_ids Destination vertex ids.
 * \return EdgeArray built from the (src, dst, id) arrays produced by
 *         aten::CSRGetDataAndIndices.
 */
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  // Validate inputs the same way HasEdgesBetween does before delegating to aten.
  CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  return EdgeArray{arrs[0], arrs[1], arrs[2]};
}

/*!
 * \brief All edges of the graph.
 * \param order Either empty (any order) or "srcdst"; any other value is fatal.
 * \return (src, dst, eid) obtained by expanding the CSR rows, which groups edges by source.
 */
EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\", but got \"" << order << "\".";
  const aten::COOMatrix coo = aten::CSRToCOO(adj_, false);
  return EdgeArray{coo.row, coo.col, coo.data};
}

/*!
 * \brief Subgraph induced by the vertex set `vids`.
 *
 * The subgraph's edges are renumbered from 0. induced_edges keeps the parent edge ids,
 * and the sortedness flag is inherited from this CSR.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
  // Keep only the rows and columns in vids; submat.data then holds parent edge ids.
  const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
  const int64_t num_sub_edges = submat.data->shape[0];
  IdArray local_eids = aten::Range(0, num_sub_edges, NumBits(), Context());

  CSRPtr subcsr(new CSR(submat.indptr, submat.indices, local_eids));
  subcsr->adj_.sorted = this->adj_.sorted;

  Subgraph result;
  result.graph = subcsr;
  result.induced_vertices = vids;
  result.induced_edges = submat.data;
  return result;
}

/*!
 * \brief Return a new CSR for the transposed adjacency (sources and destinations swapped).
 */
CSRPtr CSR::Transpose() const {
  const aten::CSRMatrix flipped = aten::CSRTranspose(adj_);
  CSRPtr result(new CSR(flipped.indptr, flipped.indices, flipped.data));
  return result;
}

/*!
 * \brief Convert to a COO representation with the same vertex count.
 *
 * The resulting COO is built without an explicit edge-id array. NOTE(review): the `true`
 * flag passed to aten::CSRToCOO presumably orders entries by edge id so that position
 * equals edge id — confirm against aten docs.
 */
COOPtr CSR::ToCOO() const {
  const aten::COOMatrix coo = aten::CSRToCOO(adj_, true);
  COOPtr result(new COO(NumVertices(), coo.row, coo.col));
  return result;
}

/*!
 * \brief Copy this CSR to device context `ctx`.
 * \param ctx Target device context.
 * \return A copy of this CSR object if it is already on `ctx`; otherwise a CSR whose
 *         arrays were copied to `ctx`.
 *
 * A device copy does not change any values. The cached multigraph flag and the
 * sortedness flag therefore stay valid and are carried over.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  CSR ret(adj_.indptr.CopyTo(ctx),
          adj_.indices.CopyTo(ctx),
          adj_.data.CopyTo(ctx));
  ret.is_multigraph_ = is_multigraph_;
  // Previously dropped (the constructor resets it to false).
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

/*!
 * \brief Copy this CSR into the shared-memory segment `name`.
 * \param name Shared-memory segment name. If this CSR is already in shared memory,
 *        `name` must equal its current segment name.
 * \return The shared-memory CSR.
 *
 * The data is copied verbatim, so the multigraph flag and the sortedness flag are
 * carried over.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (IsSharedMem()) {
    CHECK(name == shared_mem_name_);
    return *this;
  }
  CSR ret(adj_.indptr, adj_.indices, adj_.data, name);
  ret.is_multigraph_ = is_multigraph_;
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

/*!
 * \brief Convert the id arrays of this CSR to `bits`-bit integers.
 * \param bits Target integer width.
 * \return A copy of this CSR object if it already has that width; otherwise a converted CSR.
 *
 * Widening or narrowing the integer type keeps the values (assuming they fit, which
 * aten::AsNumBits is presumed to enforce). The cached multigraph and sortedness flags
 * therefore carry over.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  CSR ret(aten::AsNumBits(adj_.indptr, bits),
          aten::AsNumBits(adj_.indices, bits),
          aten::AsNumBits(adj_.data, bits));
  ret.is_multigraph_ = is_multigraph_;
  ret.adj_.sorted = adj_.sorted;
  return ret;
}

/*!
 * \brief Iterator range over the successors of `vid`, read directly from host memory.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_ptr = static_cast<dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* col_idx = static_cast<dgl_id_t*>(adj_.indices->data);
  const dgl_id_t* first = col_idx + row_ptr[vid];
  const dgl_id_t* last = col_idx + row_ptr[vid + 1];
  return DGLIdIters(first, last);
}

/*!
 * \brief Iterator range over the ids of the out-edges of `vid`, read directly from host memory.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const dgl_id_t* row_ptr = static_cast<dgl_id_t*>(adj_.indptr->data);
  const dgl_id_t* eids = static_cast<dgl_id_t*>(adj_.data->data);
  const dgl_id_t* first = eids + row_ptr[vid];
  const dgl_id_t* last = eids + row_ptr[vid + 1];
  return DGLIdIters(first, last);
}

//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
284
/*!
 * \brief Build a square COO from endpoint arrays; the multigraph flag is left unset.
 * \param num_vertices Number of vertices.
 * \param src Source of each edge; position i is edge i.
 * \param dst Destination of each edge; must be as long as `src`.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t n = num_vertices;
  adj_ = aten::COOMatrix{n, n, src, dst};
}

/*!
 * \brief Build a square COO from endpoint arrays with a known multigraph flag.
 * \param num_vertices Number of vertices.
 * \param src Source of each edge; position i is edge i.
 * \param dst Destination of each edge; must be as long as `src`.
 * \param is_multigraph Value used for the multigraph flag.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(aten::IsValidIdArray(src));
  CHECK(aten::IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t n = num_vertices;
  adj_ = aten::COOMatrix{n, n, src, dst};
}

bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
302
      return aten::COOHasDuplicate(adj_);
303
    });
304
305
}

306
307
/*!
 * \brief Endpoints (src, dst) of edge `eid`.
 *
 * Edge `eid` sits at position `eid` in both endpoint arrays.
 */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const dgl_id_t from = aten::IndexSelect<dgl_id_t>(adj_.row, eid);
  const dgl_id_t to = aten::IndexSelect<dgl_id_t>(adj_.col, eid);
  return std::make_pair(from, to);
}

/*!
 * \brief Endpoints of a batch of edges.
 *
 * Edge ids index directly into the row and column arrays.
 */
EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  IdArray src = aten::IndexSelect(adj_.row, eids);
  IdArray dst = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{src, dst, eids};
}

/*!
 * \brief All edges in edge-id order.
 * \param order Either empty or "eid"; any other value is fatal.
 *
 * Position i of the COO arrays holds edge i, so the ids are simply 0..NumEdges()-1.
 */
EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  const IdArray eid = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, eid};
}

/*!
 * \brief Subgraph induced by the edge set `eids`.
 * \param eids Parent edge ids to keep; they become subgraph edges 0..len-1 in that order.
 * \param preserve_nodes If true, keep every parent vertex. If false, keep only the endpoints
 *        of the selected edges, relabelled compactly.
 * \return Subgraph whose induced_vertices maps subgraph vertices to parent ids and whose
 *         induced_edges is `eids`.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array.";
  // Endpoints of the selected edges (shared by both branches below).
  IdArray new_src = aten::IndexSelect(adj_.row, eids);
  IdArray new_dst = aten::IndexSelect(adj_.col, eids);

  IdArray induced_nodes;
  int64_t new_nnodes = 0;
  if (preserve_nodes) {
    induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    new_nnodes = NumVertices();
  } else {
    // Relabel_ presumably rewrites new_src/new_dst in place to compact ids and returns
    // the original id of each new vertex.
    induced_nodes = aten::Relabel_({new_src, new_dst});
    new_nnodes = induced_nodes->shape[0];
  }

  COOPtr subcoo;
  if (this->IsMultigraph()) {
    // A subset of a multigraph's edges need not contain parallel edges. Let the
    // subgraph compute its own flag lazily instead of claiming it is a multigraph.
    subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
  } else {
    // Dropping edges (and relabelling vertices) cannot create parallel edges.
    subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst, false));
  }

  Subgraph subg;
  subg.graph = subcoo;
  subg.induced_vertices = induced_nodes;
  subg.induced_edges = eids;
  return subg;
}

/*!
 * \brief Convert to compressed-row form; the resulting data array is used as edge ids.
 */
CSRPtr COO::ToCSR() const {
  const aten::CSRMatrix compressed = aten::COOToCSR(adj_);
  return CSRPtr(new CSR(compressed.indptr, compressed.indices, compressed.data));
}

/*!
 * \brief Copy this COO to device context `ctx`, keeping the multigraph flag.
 * \return A copy of this COO object if it is already on `ctx`.
 */
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx)
    return *this;
  IdArray row_on_ctx = adj_.row.CopyTo(ctx);
  IdArray col_on_ctx = adj_.col.CopyTo(ctx);
  COO copied(NumVertices(), row_on_ctx, col_on_ctx);
  copied.is_multigraph_ = is_multigraph_;
  return copied;
}

/*!
 * \brief Not supported: shared-memory storage is implemented only for CSR.
 * \param name Unused.
 * \return Never returns normally (LOG(FATAL)); a default COO satisfies the signature.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  (void)name;
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}

/*!
 * \brief Convert the endpoint arrays to `bits`-bit integers, keeping the multigraph flag.
 * \return A copy of this COO object if it already has that width.
 */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits)
    return *this;
  IdArray row_bits = aten::AsNumBits(adj_.row, bits);
  IdArray col_bits = aten::AsNumBits(adj_.col, bits);
  COO converted(NumVertices(), row_bits, col_bits);
  converted.is_multigraph_ = is_multigraph_;
  return converted;
}

//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

391
/*!
 * \brief For each id in `vids`, whether it is less than NumVertices().
 */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input";
  const BoolArray in_range = aten::LT(vids, NumVertices());
  return in_range;
}

/*!
 * \brief Return the in-edge CSR, building and caching it on first use.
 *
 * If there is no in-CSR, it is built by transposing the out-CSR when one exists, or
 * otherwise from the transposed COO.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_)
    return in_csr_;
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (!out_csr_) {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
    return in_csr_;
  }
  self->in_csr_ = out_csr_->Transpose();
  if (out_csr_->IsSharedMem())
    LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                 << "It may dramatically increase memory consumption.";
  return in_csr_;
}

/*!
 * \brief Return the out-edge CSR, building and caching it on first use.
 *
 * If there is no out-CSR, it is built by transposing the in-CSR when one exists, or
 * otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_)
    return out_csr_;
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (!in_csr_) {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
    return out_csr_;
  }
  self->out_csr_ = in_csr_->Transpose();
  if (in_csr_->IsSharedMem())
    LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                 << "It may dramatically increase memory consumption.";
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it on first use.
 *
 * The in-CSR is preferred as the source (converted and then transposed). Otherwise the
 * out-CSR is used.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_)
    return coo_;
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

/*!
 * \brief All edges of the graph in the requested order.
 * \param order "" (any order), "srcdst" (via the out-CSR) or "eid" (via the COO).
 *        Any other value is fatal.
 */
EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order.empty()) {
    // Any order will do: reuse an existing in-CSR by swapping src and dst, rather
    // than materialising another format.
    if (!in_csr_)
      return AnyGraph()->Edges(order);
    const EdgeArray reversed = in_csr_->Edges(order);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }
  if (order == std::string("srcdst")) {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == std::string("eid"))
    return GetCOO()->Edges(order);
  LOG(FATAL) << "Unsupported order request: " << order;
  return {};
}

/*!
 * \brief Vertex-induced subgraph, extracted from the out-CSR and wrapped as an ImmutableGraph.
 */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  Subgraph result = GetOutCSR()->VertexSubgraph(vids);
  CSRPtr raw_csr = std::dynamic_pointer_cast<CSR>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(raw_csr));
  return result;
}

/*!
 * \brief Edge-induced subgraph, extracted from the COO and wrapped as an ImmutableGraph.
 */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  Subgraph result = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  COOPtr raw_coo = std::dynamic_pointer_cast<COO>(result.graph);
  result.graph = GraphPtr(new ImmutableGraph(raw_coo));
  return result;
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
478
479
480
481
482
483
484
485
486
487
488
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
489
  } else {
490
491
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
492
493
494
  }
}

495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
/*!
 * \brief Build an ImmutableGraph from CSR arrays interpreted as in- or out-edges.
 * \param edge_dir "in" or "out"; any other value is fatal.
 *
 * The ImmutableGraph constructor takes (in_csr, out_csr), so the CSR goes in the slot
 * that matches edge_dir.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph from CSR arrays with a known multigraph flag.
 * \param edge_dir "in" or "out"; any other value is fatal.
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir) {
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph whose CSR is copied into shared memory.
 *
 * The CSR segment is named `<shared_mem_name>_<edge_dir>`. The graph itself records
 * the base name. edge_dir must be "in" or "out".
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, csr_mem_name));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph whose CSR is copied into shared memory, with a known
 *        multigraph flag.
 *
 * The CSR segment is named `<shared_mem_name>_<edge_dir>`. edge_dir must be "in" or "out".
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    IdArray indptr, IdArray indices, IdArray edge_ids,
    bool multigraph, const std::string &edge_dir,
    const std::string &shared_mem_name) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph, csr_mem_name));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*!
 * \brief Build an ImmutableGraph over a CSR that already lives in shared memory
 *        (segment `<shared_mem_name>_<edge_dir>`). edge_dir must be "in" or "out".
 */
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
    const std::string &shared_mem_name, size_t num_vertices,
    size_t num_edges, bool multigraph,
    const std::string &edge_dir) {
  const std::string csr_mem_name = GetSharedMemName(shared_mem_name, edge_dir);
  CSRPtr csr(new CSR(csr_mem_name, num_vertices, num_edges, multigraph));
  if (edge_dir == "out")
    return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
  if (edge_dir == "in")
    return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
  LOG(FATAL) << "Unknown edge direction: " << edge_dir;
  return ImmutableGraphPtr();
}

/*! \brief Build an ImmutableGraph backed by a COO whose multigraph flag is computed lazily. */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst) {
  COOPtr edges(new COO(num_vertices, src, dst));
  return std::make_shared<ImmutableGraph>(edges);
}

/*! \brief Build an ImmutableGraph backed by a COO with a known multigraph flag. */
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
    int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
  COOPtr edges(new COO(num_vertices, src, dst, multigraph));
  return std::make_shared<ImmutableGraph>(edges);
}

/*!
 * \brief Return `graph` as an ImmutableGraph.
 *
 * If `graph` is already immutable, it is returned as-is. Otherwise its out-edge CSR
 * (GetAdj(transpose=true, "csr")) is used to build a new immutable graph.
 */
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
  if (ig) {
    return ig;
  }
  const auto& adj = graph->GetAdj(true, "csr");
  // CreateFromCSR builds (and validates) the CSR itself. The previous extra CSR built
  // here was never used.
  return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}

/*!
 * \brief Copy `g` to device context `ctx` (returned unchanged if already there).
 */
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& ctx) {
  if (ctx == g->Context())
    return g;
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_copy(new CSR(g->GetInCSR()->CopyTo(ctx)));
  CSRPtr out_copy(new CSR(g->GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraphPtr(new ImmutableGraph(in_copy, out_copy));
}

/*!
 * \brief Copy one CSR of `g` into shared memory and wrap it as a new ImmutableGraph.
 * \param g Source graph.
 * \param edge_dir "in" or "out": which CSR to share. Any other value is fatal,
 *        matching CreateFromCSR (previously this produced a graph with no CSR at all).
 * \param name Base shared-memory name; the CSR segment is `<name>_<edge_dir>`.
 */
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g,
    const std::string &edge_dir, const std::string &name) {
  CSRPtr new_incsr, new_outcsr;
  const std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in")) {
    new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
  } else if (edge_dir == std::string("out")) {
    new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  } else {
    LOG(FATAL) << "Unknown edge direction: " << edge_dir;
    return ImmutableGraphPtr();
  }
  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
}

/*!
 * \brief Return `g` with ids stored as `bits`-bit integers (returned unchanged if already so).
 */
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
  if (g->NumBits() == bits) {
    return g;
  }
  // TODO(minjie): since we don't have int32 operations, we make sure that both
  //   the in-CSR and the out-CSR are materialized at the current width and then
  //   convert each of them. Presumably this keeps any later format conversion from
  //   having to run on the converted arrays. This should be fixed later.
  CSRPtr new_incsr = CSRPtr(new CSR(g->GetInCSR()->AsNumBits(bits)));
  CSRPtr new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->AsNumBits(bits)));
  return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
}

/*!
 * \brief Graph with every edge reversed.
 *
 * The in-CSR and out-CSR swap roles. A cached COO, if present, is transposed.
 */
ImmutableGraphPtr ImmutableGraph::Reverse() const {
  if (!coo_)
    return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
  return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
}

// Magic number written first by ImmutableGraph::Save and verified by ImmutableGraph::Load.
constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3C5FFE20046ABFULL;

/*!
 * \brief Load an ImmutableGraph from a stream written by Save(): the magic number
 *        followed by the out-CSR matrix.
 * \param fs Input stream.
 * \return true on success. Malformed input aborts through CHECK.
 */
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum;
  aten::CSRMatrix out_csr_matrix;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph) << "Invalid ImmutableGraph Data";
  CHECK(fs->Read(&out_csr_matrix)) << "Invalid csr matrix";
  CSRPtr csr(new CSR(out_csr_matrix.indptr, out_csr_matrix.indices,
                     out_csr_matrix.data));
  // Assign from a temporary. The previous version heap-allocated an ImmutableGraph
  // with `new`, copied it and never deleted it, leaking the object.
  *this = ImmutableGraph(nullptr, csr);
  return true;
}

/*!
 * \brief Save this ImmutableGraph to a stream: the magic number followed by the
 *        out-CSR matrix (the layout Load() expects).
 * \param fs Output stream.
 */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
  fs->Write(GetOutCSR()->ToCSRMatrix());
}

// FFI entry: args = (graph, device_type, device_id); returns the graph copied to that device.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const int dev_type = args[1];
    const int dev_id = args[2];
    const DLContext target{static_cast<DLDeviceType>(dev_type), dev_id};
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyTo(ig, target);
  });

// FFI entry: args = (graph, edge_dir, name); returns the graph with one CSR in shared memory.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const std::string dir = args[1];
    const std::string mem_name = args[2];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    *rv = ImmutableGraph::CopyToSharedMem(ig, dir, mem_name);
  });

// FFI entry: args = (graph, bits); returns the graph with ids stored as `bits`-bit integers.
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef g = args[0];
    const int width = args[1];
    ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
    // `width` is narrowed to uint8_t by ImmutableGraph::AsNumBits' parameter type.
    *rv = ImmutableGraph::AsNumBits(ig, width);
  });

}  // namespace dgl