immutable_graph.cc 18.5 KB
Newer Older
1
2
3
4
5
6
7
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/immutable_graph.cc
 * \brief DGL immutable graph index implementation
 */

#include <dgl/immutable_graph.h>
8
9
10
11
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
12
13
14
15

#include "../c_api_common.h"

namespace dgl {
16
17
namespace {
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
18
  const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
19
#ifndef _WIN32
20
21
22
  const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);

  IdArray sm_array = IdArray::EmptyShared(
23
      shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
24
25
26
27
28
29
30
31
  // Create views from the shared memory array. Note that we don't need to save
  //   the sm_array because the refcount is maintained by the view arrays.
  IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
  IdArray indices = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1) * sizeof(dgl_id_t));
  IdArray edge_ids = sm_array.CreateView({num_edges}, DLDataType{kDLInt, 64, 1},
      (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
  return std::make_tuple(indptr, indices, edge_ids);
32
#else
33
34
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return {};
35
36
#endif  // _WIN32
}
37
38
39
40
41
42
43
44
45
46
}  // namespace

//////////////////////////////////////////////////////////
//
// CSR graph implementation
//
//////////////////////////////////////////////////////////

/*!
 * \brief Allocate a CSR with room for the given number of vertices and edges.
 *
 * The three arrays are freshly allocated with aten::NewIdArray; their contents
 * are not filled in here.
 */
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  // A graph without vertices cannot carry edges.
  CHECK(!(num_vertices == 0 && num_edges != 0));
  IdArray indptr = aten::NewIdArray(num_vertices + 1);
  IdArray indices = aten::NewIdArray(num_edges);
  IdArray edge_ids = aten::NewIdArray(num_edges);
  adj_ = aten::CSRMatrix{num_vertices, num_vertices, indptr, indices, edge_ids};
}

54
/*!
 * \brief Wrap existing index arrays as a square CSR adjacency matrix.
 *
 * The vertex count is derived from the length of indptr (length - 1).
 * The multigraph flag is left unset.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  // Every column index must have a matching edge id.
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);

  const int64_t num_verts = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_verts, num_verts, indptr, indices, edge_ids};
}

/*!
 * \brief Wrap existing index arrays as a square CSR adjacency matrix with a
 *        caller-supplied multigraph flag.
 *
 * The vertex count is derived from the length of indptr (length - 1).
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  // Every column index must have a matching edge id.
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);

  const int64_t num_verts = indptr->shape[0] - 1;
  adj_ = aten::CSRMatrix{num_verts, num_verts, indptr, indices, edge_ids};
}

/*!
 * \brief Build a CSR whose arrays are views into the shared-memory segment
 *        `shared_mem_name`, filled with a copy of the given arrays.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
         const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);

  const int64_t n_rows = indptr->shape[0] - 1;
  const int64_t nnz = indices->shape[0];
  adj_.num_rows = n_rows;
  adj_.num_cols = n_rows;

  // Map the segment with is_create = true, then copy the input into it.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, n_rows, nnz, true);
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
}

/*!
 * \brief Build a CSR whose arrays are views into the shared-memory segment
 *        `shared_mem_name`, filled with a copy of the given arrays, using a
 *        caller-supplied multigraph flag.
 */
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
         const std::string &shared_mem_name): is_multigraph_(is_multigraph),
         shared_mem_name_(shared_mem_name) {
  CHECK(IsValidIdArray(indptr));
  CHECK(IsValidIdArray(indices));
  CHECK(IsValidIdArray(edge_ids));
  CHECK_EQ(indices->shape[0], edge_ids->shape[0]);

  const int64_t n_rows = indptr->shape[0] - 1;
  const int64_t nnz = indices->shape[0];
  adj_.num_rows = n_rows;
  adj_.num_cols = n_rows;

  // Map the segment with is_create = true, then copy the input into it.
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, n_rows, nnz, true);
  adj_.indptr.CopyFrom(indptr);
  adj_.indices.CopyFrom(indices);
  adj_.data.CopyFrom(edge_ids);
}

/*!
 * \brief Attach to a shared-memory segment as a CSR of the given size.
 *
 * The segment is mapped with is_create = false; nothing is copied into it.
 */
CSR::CSR(const std::string &shared_mem_name,
         int64_t num_verts, int64_t num_edges, bool is_multigraph)
  : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
  // A graph without vertices cannot carry edges.
  CHECK(!(num_verts == 0 && num_edges != 0));
  adj_.num_rows = num_verts;
  adj_.num_cols = num_verts;
  std::tie(adj_.indptr, adj_.indices, adj_.data) =
      MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
}

bool CSR::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
123
      return aten::CSRHasDuplicate(adj_);
124
125
    });
}
126

127
/*!
 * \brief Return the edges stored in row `vid` as (src, dst, eid) arrays.
 *
 * The source array is `vid` repeated once per entry of the row.
 */
CSR::EdgeArray CSR::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  const IdArray dst = aten::CSRGetRowColumnIndices(adj_, vid);
  const IdArray eid = aten::CSRGetRowData(adj_, vid);
  // Same length and device context as the destination array.
  const IdArray src = aten::Full(vid, dst->shape[0], NumBits(), dst->ctx);
  return CSR::EdgeArray{src, dst, eid};
}

135
/*! \brief Return the edges stored in the rows listed in `vids` as (src, dst, eid) arrays. */
CSR::EdgeArray CSR::OutEdges(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto sliced_coo = aten::CSRToCOO(aten::CSRSliceRows(adj_, vids), false);
  // Row ids of the sliced matrix are relabeled, so map them back to the
  // original vertex ids through `vids`.
  const auto src = aten::IndexSelect(vids, sliced_coo.row);
  return CSR::EdgeArray{src, sliced_coo.col, sliced_coo.data};
}

145
/*! \brief Return the number of stored entries in each row listed in `vids`. */
DegreeArray CSR::OutDegrees(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  DegreeArray degrees = aten::CSRGetRowNNZ(adj_, vids);
  return degrees;
}

150
151
152
/*! \brief Whether entry (src, dst) of the adjacency matrix is non-zero. */
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
  CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
  const bool connected = aten::CSRIsNonZero(adj_, src, dst);
  return connected;
}

/*! \brief Array version of HasEdgeBetween: query the (src_ids, dst_ids) entries. */
BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  BoolArray connected = aten::CSRIsNonZero(adj_, src_ids, dst_ids);
  return connected;
}

162
/*!
 * \brief Return the column indices stored in row `vid`.
 *
 * Only radius == 1 is accepted.
 */
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius == 1) << "invalid radius: " << radius;
  IdArray neighbors = aten::CSRGetRowColumnIndices(adj_, vid);
  return neighbors;
}

168
169
170
/*! \brief Return the data (edge ids) stored at entry (src, dst). */
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src)) << "invalid vertex: " << src;
  CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
  IdArray eids = aten::CSRGetData(adj_, src, dst);
  return eids;
}

174
/*!
 * \brief Array version of EdgeId: look up the edges between the given
 *        endpoint arrays.
 *
 * Both inputs are validated up front, consistent with the other
 * array-taking queries of this class (e.g. HasEdgesBetween).
 *
 * \param src_ids Source vertex ids.
 * \param dst_ids Destination vertex ids.
 * \return The three arrays produced by aten::CSRGetDataAndIndices, in order.
 */
CSR::EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
  const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
  return CSR::EdgeArray{arrs[0], arrs[1], arrs[2]};
}
178

179
180
/*!
 * \brief Return all edges as (src, dst, eid) arrays.
 *
 * \param order Must be empty or "srcdst"; anything else fails the check.
 */
CSR::EdgeArray CSR::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("srcdst"))
    << "CSR only support Edges of order \"srcdst\","
    << " but got \"" << order << "\".";
  const auto as_coo = aten::CSRToCOO(adj_, false);
  return CSR::EdgeArray{as_coo.row, as_coo.col, as_coo.data};
}

187
188
/*!
 * \brief Extract the subgraph induced by the vertices in `vids`.
 *
 * Edges of the subgraph are numbered 0..nnz-1; the data of the sliced matrix
 * is returned as the induced edge array.
 */
Subgraph CSR::VertexSubgraph(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto sliced = aten::CSRSliceMatrix(adj_, vids, vids);
  const int64_t num_sub_edges = sliced.data->shape[0];
  IdArray new_eids = aten::Range(0, num_sub_edges, NumBits(), Context());
  CSRPtr sub_graph(new CSR(sliced.indptr, sliced.indices, new_eids));
  return Subgraph{sub_graph, vids, sliced.data};
}

/*! \brief Return a new CSR built from the transposed adjacency matrix. */
CSRPtr CSR::Transpose() const {
  const auto transposed = aten::CSRTranspose(adj_);
  CSRPtr result(new CSR(transposed.indptr, transposed.indices, transposed.data));
  return result;
}

/*! \brief Convert this CSR into a COO graph over the same vertex set. */
COOPtr CSR::ToCOO() const {
  const auto converted = aten::CSRToCOO(adj_, true);
  COOPtr result(new COO(NumVertices(), converted.row, converted.col));
  return result;
}

205
206
207
208
/*!
 * \brief Return this CSR on device context `ctx`.
 *
 * When the graph already lives on `ctx`, a copy of *this is returned without
 * copying the arrays to another context.
 */
CSR CSR::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  CSR copied(adj_.indptr.CopyTo(ctx),
             adj_.indices.CopyTo(ctx),
             adj_.data.CopyTo(ctx));
  // Carry the multigraph flag over to the copy.
  copied.is_multigraph_ = is_multigraph_;
  return copied;
}

217
218
219
220
221
/*!
 * \brief Return a CSR whose arrays live in the shared-memory segment `name`.
 *
 * If this CSR is already in shared memory, `name` must equal its segment
 * name and a copy of *this is returned. Otherwise the arrays are copied into
 * a newly mapped segment.
 *
 * The multigraph flag is carried over to the result, as CopyTo and AsNumBits
 * do, since the copy holds exactly the same data.
 *
 * \param name Shared-memory segment name.
 */
CSR CSR::CopyToSharedMem(const std::string &name) const {
  if (IsSharedMem()) {
    CHECK(name == shared_mem_name_)
      << "The CSR is already in shared memory \"" << shared_mem_name_
      << "\" but was asked to be copied to \"" << name << "\".";
    return *this;
  }
  CSR ret(adj_.indptr, adj_.indices, adj_.data, name);
  ret.is_multigraph_ = is_multigraph_;
  return ret;
}

226
227
228
229
/*!
 * \brief Return this CSR with its index arrays converted to `bits`-bit ids.
 *
 * When the graph already uses that width, a copy of *this is returned.
 */
CSR CSR::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  CSR converted(aten::AsNumBits(adj_.indptr, bits),
                aten::AsNumBits(adj_.indices, bits),
                aten::AsNumBits(adj_.data, bits));
  // Carry the multigraph flag over to the converted graph.
  converted.is_multigraph_ = is_multigraph_;
  return converted;
}

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
/*!
 * \brief Return an iterator range over the column indices of row `vid`.
 *
 * `vid` is not range-checked here.
 */
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const auto* row_offsets = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const auto* col_ids = static_cast<const dgl_id_t*>(adj_.indices->data);
  const dgl_id_t row_begin = row_offsets[vid];
  const dgl_id_t row_end = row_offsets[vid + 1];
  return DGLIdIters(col_ids + row_begin, col_ids + row_end);
}

/*!
 * \brief Return an iterator range over the edge ids stored in row `vid`.
 *
 * `vid` is not range-checked here.
 */
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
  // TODO(minjie): This still assumes the data type and device context
  //   of this graph. Should fix later.
  const auto* row_offsets = static_cast<const dgl_id_t*>(adj_.indptr->data);
  const auto* edge_ids = static_cast<const dgl_id_t*>(adj_.data->data);
  const dgl_id_t row_begin = row_offsets[vid];
  const dgl_id_t row_end = row_offsets[vid + 1];
  return DGLIdIters(edge_ids + row_begin, edge_ids + row_end);
}

258
259
260
261
262
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
263
/*!
 * \brief Build a COO graph over `num_vertices` vertices from endpoint arrays.
 *
 * The multigraph flag is left unset.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(IsValidIdArray(src));
  CHECK(IsValidIdArray(dst));
  // One destination per source.
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t n = num_vertices;
  adj_ = aten::COOMatrix{n, n, src, dst};
}

/*!
 * \brief Build a COO graph over `num_vertices` vertices from endpoint arrays,
 *        with a caller-supplied multigraph flag.
 */
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(IsValidIdArray(src));
  CHECK(IsValidIdArray(dst));
  // One destination per source.
  CHECK_EQ(src->shape[0], dst->shape[0]);
  const int64_t n = num_vertices;
  adj_ = aten::COOMatrix{n, n, src, dst};
}

bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
281
      return aten::COOHasDuplicate(adj_);
282
    });
283
284
}

285
286
287
288
289
290
291
/*! \brief Return the (src, dst) endpoints of edge `eid`. */
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
  CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
  const dgl_id_t src = aten::IndexSelect(adj_.row, eid);
  const dgl_id_t dst = aten::IndexSelect(adj_.col, eid);
  return std::make_pair(src, dst);
}

292
/*! \brief Return the endpoints of the edges in `eids`, together with `eids` itself. */
COO::EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
  IdArray src = aten::IndexSelect(adj_.row, eids);
  IdArray dst = aten::IndexSelect(adj_.col, eids);
  return EdgeArray{src, dst, eids};
}

299
300
301
302
/*!
 * \brief Return all edges as (src, dst, eid) arrays, with eid = 0..NumEdges()-1.
 *
 * \param order Must be empty or "eid"; anything else fails the check.
 */
COO::EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only support Edges of order \"eid\", but got \""
    << order << "\".";
  IdArray all_eids = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, all_eids};
}

307
308
309
/*!
 * \brief Extract the subgraph made of the edges listed in `eids`.
 *
 * \param eids Edge ids to keep.
 * \param preserve_nodes If true, the subgraph keeps all NumVertices() vertices
 *        and the induced vertices are 0..NumVertices()-1. Otherwise the
 *        vertex set is what aten::Relabel_ returns for the selected endpoints.
 */
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  // Endpoints of the selected edges; needed by both modes.
  IdArray sub_src = aten::IndexSelect(adj_.row, eids);
  IdArray sub_dst = aten::IndexSelect(adj_.col, eids);

  if (preserve_nodes) {
    IdArray all_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    COOPtr sub_graph(new COO(NumVertices(), sub_src, sub_dst));
    return Subgraph{sub_graph, all_nodes, eids};
  }

  // NOTE(review): Relabel_ presumably rewrites sub_src/sub_dst in place to
  //   compact ids (trailing-underscore naming) -- confirm against aten.
  IdArray kept_nodes = aten::Relabel_({sub_src, sub_dst});
  const auto num_kept = kept_nodes->shape[0];
  COOPtr sub_graph(new COO(num_kept, sub_src, sub_dst));
  return Subgraph{sub_graph, kept_nodes, eids};
}

325
/*! \brief Convert this COO into a CSR graph. */
CSRPtr COO::ToCSR() const {
  const auto converted = aten::COOToCSR(adj_);
  CSRPtr result(new CSR(converted.indptr, converted.indices, converted.data));
  return result;
}

330
331
332
333
/*!
 * \brief Return this COO on device context `ctx`.
 *
 * When the graph already lives on `ctx`, a copy of *this is returned without
 * copying the arrays to another context.
 */
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  }
  COO copied(NumVertices(),
             adj_.row.CopyTo(ctx),
             adj_.col.CopyTo(ctx));
  // Carry the multigraph flag over to the copy.
  copied.is_multigraph_ = is_multigraph_;
  return copied;
}

342
343
/*!
 * \brief Not supported for COO: always logs a fatal error.
 *
 * \param name Unused.
 */
COO COO::CopyToSharedMem(const std::string &name) const {
  LOG(FATAL) << "COO doesn't supprt shared memory yet";
  // Follows the fatal log only so that every path returns a value.
  COO placeholder = COO();
  return placeholder;
}

347
348
349
350
/*!
 * \brief Return this COO with its index arrays converted to `bits`-bit ids.
 *
 * When the graph already uses that width, a copy of *this is returned.
 */
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  COO converted(NumVertices(),
                aten::AsNumBits(adj_.row, bits),
                aten::AsNumBits(adj_.col, bits));
  // Carry the multigraph flag over to the converted graph.
  converted.is_multigraph_ = is_multigraph_;
  return converted;
}

359
360
361
362
363
364
//////////////////////////////////////////////////////////
//
// immutable graph implementation
//
//////////////////////////////////////////////////////////

365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
/*! \brief For each id in `vids`, tell whether it is less than NumVertices(). */
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid id array input";
  BoolArray in_range = aten::LT(vids, NumVertices());
  return in_range;
}

/*!
 * \brief Return the in-CSR, building and caching it on first use.
 *
 * It is derived from the out-CSR when that exists, otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetInCSR() const {
  if (in_csr_) {
    return in_csr_;
  }
  // The cached member is filled in lazily from a const method.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (out_csr_) {
    self->in_csr_ = out_csr_->Transpose();
    if (out_csr_->IsSharedMem()) {
      LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
                   << "It may dramatically increase memory consumption.";
    }
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->in_csr_ = coo_->Transpose()->ToCSR();
  }
  return in_csr_;
}

/*!
 * \brief Return the out-CSR, building and caching it on first use.
 *
 * It is derived from the in-CSR when that exists, otherwise from the COO.
 */
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (out_csr_) {
    return out_csr_;
  }
  // The cached member is filled in lazily from a const method.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->out_csr_ = in_csr_->Transpose();
    if (in_csr_->IsSharedMem()) {
      LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
                   << "It may dramatically increase memory consumption.";
    }
  } else {
    CHECK(coo_) << "None of CSR, COO exist";
    self->out_csr_ = coo_->ToCSR();
  }
  return out_csr_;
}

/*!
 * \brief Return the COO, building and caching it on first use.
 *
 * It is derived from the in-CSR when that exists, otherwise from the out-CSR.
 */
COOPtr ImmutableGraph::GetCOO() const {
  if (coo_) {
    return coo_;
  }
  // The cached member is filled in lazily from a const method.
  ImmutableGraph* self = const_cast<ImmutableGraph*>(this);
  if (in_csr_) {
    self->coo_ = in_csr_->ToCOO()->Transpose();
  } else {
    CHECK(out_csr_) << "Both CSR are missing.";
    self->coo_ = out_csr_->ToCOO();
  }
  return coo_;
}

414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
/*!
 * \brief Return all edges as (src, dst, eid) arrays in the requested order.
 *
 * \param order "" (arbitrary order), "srcdst" (served by the out-CSR) or
 *        "eid" (served by the COO). Any other value is a fatal error.
 */
ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order.empty()) {
    // Arbitrary order: use whichever representation already exists.
    if (!in_csr_) {
      return AnyGraph()->Edges(order);
    }
    // The in-CSR stores reversed edges, so swap the endpoints back.
    const auto& reversed = in_csr_->Edges(order);
    return EdgeArray{reversed.dst, reversed.src, reversed.id};
  }
  if (order == "srcdst") {
    // TODO(minjie): CSR only guarantees "src" to be sorted.
    //   Maybe we should relax this requirement?
    return GetOutCSR()->Edges(order);
  }
  if (order == "eid") {
    return GetCOO()->Edges(order);
  }
  LOG(FATAL) << "Unsupported order request: " << order;
  return {};
}

436
437
438
439
440
441
/*! \brief Extract the subgraph induced by `vids`, wrapped as an ImmutableGraph. */
Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
  // The out-CSR is the preferred representation for vertex subgraphs.
  const auto csr_sub = GetOutCSR()->VertexSubgraph(vids);
  CSRPtr sub_csr = std::dynamic_pointer_cast<CSR>(csr_sub.graph);
  GraphPtr wrapped(new ImmutableGraph(sub_csr));
  return Subgraph{wrapped, csr_sub.induced_vertices, csr_sub.induced_edges};
}

444
/*! \brief Extract the subgraph made of the edges in `eids`, wrapped as an ImmutableGraph. */
Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  // Edge subgraphs are generated from the COO representation.
  const auto coo_sub = GetCOO()->EdgeSubgraph(eids, preserve_nodes);
  COOPtr sub_coo = std::dynamic_pointer_cast<COO>(coo_sub.graph);
  GraphPtr wrapped(new ImmutableGraph(sub_coo));
  return Subgraph{wrapped, coo_sub.induced_vertices, coo_sub.induced_edges};
}

std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
453
454
455
456
457
458
459
460
461
462
463
  // TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
  //   src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
  //   is equal to in edge CSR.
  //   We have this behavior because previously we use framework's SPMM and we don't cache
  //   reverse adj. This is not intuitive and also not consistent with networkx's
  //   to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
  //   behavior and make row for src and col for dst.
  if (fmt == std::string("csr")) {
    return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
  } else if (fmt == std::string("coo")) {
    return GetCOO()->GetAdj(!transpose, fmt);
464
  } else {
465
466
    LOG(FATAL) << "unsupported adjacency matrix format: " << fmt;
    return {};
467
468
469
  }
}

470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
/*!
 * \brief Convert any graph into an ImmutableGraph.
 *
 * An ImmutableGraph input is returned as a copy. Any other graph is rebuilt
 * from the CSR arrays returned by graph->GetAdj(true, "csr"), which become
 * the out-CSR of the result.
 */
ImmutableGraph ImmutableGraph::ToImmutable(const GraphInterface* graph) {
  if (const ImmutableGraph* already = dynamic_cast<const ImmutableGraph*>(graph)) {
    return *already;
  }
  const auto csr_arrays = graph->GetAdj(true, "csr");
  CSRPtr out_csr(new CSR(csr_arrays[0], csr_arrays[1], csr_arrays[2]));
  return ImmutableGraph(nullptr, out_csr);
}

/*!
 * \brief Return this graph on device context `ctx`.
 *
 * When the graph already lives on `ctx`, a copy of *this is returned.
 * Otherwise both CSRs are materialized and copied.
 */
ImmutableGraph ImmutableGraph::CopyTo(const DLContext& ctx) const {
  if (ctx == Context()) {
    return *this;
  }
  // TODO(minjie): since we don't have GPU implementation of COO<->CSR,
  //   we make sure that this graph (on CPU) has materialized CSR,
  //   and then copy them to other context (usually GPU). This should
  //   be fixed later.
  CSRPtr in_on_ctx(new CSR(GetInCSR()->CopyTo(ctx)));
  CSRPtr out_on_ctx(new CSR(GetOutCSR()->CopyTo(ctx)));
  return ImmutableGraph(in_on_ctx, out_on_ctx);
}

/*!
 * \brief Copy one CSR of this graph into shared memory.
 *
 * \param edge_dir Which CSR to share: "in" or "out". Any other value is
 *        reported as a fatal error instead of silently producing a graph
 *        with no CSR at all.
 * \param name Base name of the graph; the segment name is derived from it
 *        and `edge_dir` via GetSharedMemName.
 * \return A graph holding only the shared CSR of the requested direction.
 */
ImmutableGraph ImmutableGraph::CopyToSharedMem(const std::string &edge_dir,
                                               const std::string &name) const {
  CSRPtr new_incsr, new_outcsr;
  std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in")) {
    new_incsr = CSRPtr(new CSR(GetInCSR()->CopyToSharedMem(shared_mem_name)));
  } else if (edge_dir == std::string("out")) {
    new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  } else {
    LOG(FATAL) << "Unsupported edge direction: \"" << edge_dir
               << "\". Expect \"in\" or \"out\".";
  }
  return ImmutableGraph(new_incsr, new_outcsr, name);
}

/*!
 * \brief Return this graph with `bits`-bit index arrays.
 *
 * When the graph already uses that width, a copy of *this is returned.
 * Otherwise both CSRs are materialized and converted.
 */
ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  }
  // TODO(minjie): since we don't have int32 operations, we make sure that
  //   this graph has materialized CSR, and then convert both of them.
  //   This should be fixed later.
  CSRPtr in_converted(new CSR(GetInCSR()->AsNumBits(bits)));
  CSRPtr out_converted(new CSR(GetOutCSR()->AsNumBits(bits)));
  return ImmutableGraph(in_converted, out_converted);
}

519
}  // namespace dgl