array.cc 36.8 KB
Newer Older
1
/**
 *  Copyright (c) 2019-2021 by Contributors
 * @file array/array.cc
 * @brief DGL array utilities implementation
 */
#include <dgl/array.h>
7
#include <dgl/graph_traversal.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
10
#include <dgl/runtime/device_api.h>
11
12
#include <dgl/runtime/shared_mem.h>

13
#include <sstream>
14

15
16
#include "../c_api_common.h"
#include "./arith.h"
17
#include "./array_op.h"
18

19
using namespace dgl::runtime;
20

21
namespace dgl {
22
23
namespace aten {

24
25
/**
 * @brief Create a 1-D integer ID array through IdArray::Empty.
 * @param length Number of elements (the single shape entry).
 * @param ctx Context forwarded to IdArray::Empty.
 * @param nbits Bit width stored in the integer data type descriptor.
 * @return The array produced by IdArray::Empty.
 */
IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
  const DGLDataType id_dtype{kDGLInt, nbits, 1};
  return IdArray::Empty({length}, id_dtype, ctx);
}

/**
 * @brief Copy an ID array into a newly created array.
 *
 * The new array takes its length from arr->shape[0] and reuses the context
 * and bit width of the input.
 *
 * @param arr Array to copy from.
 * @return A new array filled via CopyFrom(arr).
 */
IdArray Clone(IdArray arr) {
  IdArray copy = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits);
  copy.CopyFrom(arr);
  return copy;
}

34
/**
 * @brief Dispatch to impl::Range for the device of @p ctx and the requested
 *        integer width.
 * @param low Forwarded to impl::Range as the first bound.
 * @param high Forwarded to impl::Range as the second bound.
 * @param nbits Integer width; only 32 and 64 are accepted, anything else is
 *        a fatal error.
 * @param ctx Context used for device dispatch and forwarded to impl::Range.
 * @return The array produced by impl::Range.
 */
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
  IdArray out;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
    switch (nbits) {
      case 32:
        out = impl::Range<XPU, int32_t>(low, high, ctx);
        break;
      case 64:
        out = impl::Range<XPU, int64_t>(low, high, ctx);
        break;
      default:
        LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return out;
}

48
/**
 * @brief Dispatch to impl::Full for the device of @p ctx and the requested
 *        integer width.
 * @param val Value forwarded to impl::Full.
 * @param length Length forwarded to impl::Full.
 * @param nbits Integer width; only 32 and 64 are accepted, anything else is
 *        a fatal error.
 * @param ctx Context used for device dispatch and forwarded to impl::Full.
 * @return The array produced by impl::Full.
 */
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
  IdArray out;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
    switch (nbits) {
      case 32:
        out = impl::Full<XPU, int32_t>(val, length, ctx);
        break;
      case 64:
        out = impl::Full<XPU, int64_t>(val, length, ctx);
        break;
      default:
        LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return out;
}

62
template <typename DType>
63
NDArray Full(DType val, int64_t length, DGLContext ctx) {
64
65
66
67
68
69
70
  NDArray ret;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
    ret = impl::Full<XPU, DType>(val, length, ctx);
  });
  return ret;
}

71
72
73
74
template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
75

76
/**
 * @brief Return @p arr with the requested integer bit width.
 *
 * The input is returned unchanged when its width already matches. An array
 * with zero elements is replaced by a new array from NewIdArray. Otherwise
 * the call is dispatched to impl::AsNumBits.
 *
 * @param arr Input ID array.
 * @param bits Target width; must be 32 or 64 (checked).
 * @return An ID array whose dtype has @p bits bits.
 */
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  CHECK(bits == 32 || bits == 64)
      << "Invalid ID type. Must be int32 or int64, but got int"
      << static_cast<int>(bits) << ".";
  if (arr->dtype.bits == bits) {
    return arr;
  }
  if (arr.NumElements() == 0) {
    return NewIdArray(arr->shape[0], arr->ctx, bits);
  }
  IdArray converted;
  ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
    ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
      converted = impl::AsNumBits<XPU, IdType>(arr, bits);
    });
  });
  return converted;
}

/**
 * @brief Place @p lhs and @p rhs back to back in one new ID array.
 *
 * Both inputs must share context and dtype and have equal shape[0] (all
 * checked). The result has 2 * shape[0] elements: the first half is copied
 * from @p lhs, the second half from @p rhs, using the device API of the
 * inputs' context.
 *
 * @return The newly created array holding both copies.
 */
IdArray HStack(IdArray lhs, IdArray rhs) {
  IdArray stacked;
  CHECK_SAME_CONTEXT(lhs, rhs);
  CHECK_SAME_DTYPE(lhs, rhs);
  CHECK_EQ(lhs->shape[0], rhs->shape[0]);
  auto device = runtime::DeviceAPI::Get(lhs->ctx);
  const auto& ctx = lhs->ctx;
  ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
    const int64_t len = lhs->shape[0];
    // Size in bytes of one input; also the byte offset of the second half.
    const size_t half_bytes = len * sizeof(IdType);
    stacked = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);
    device->CopyDataFromTo(
        lhs.Ptr<IdType>(), 0, stacked.Ptr<IdType>(), 0, half_bytes, ctx, ctx,
        lhs->dtype);
    device->CopyDataFromTo(
        rhs.Ptr<IdType>(), 0, stacked.Ptr<IdType>(), half_bytes, half_bytes,
        ctx, ctx, lhs->dtype);
  });
  return stacked;
}

110
111
/**
 * @brief Dispatch to impl::IndexSelect for an index array.
 *
 * @p array must have at least one dimension and @p index must be 1-D (both
 * checked). Device dispatch follows the context of @p index.
 *
 * @return The array produced by impl::IndexSelect.
 */
NDArray IndexSelect(NDArray array, IdArray index) {
  NDArray selected;
  CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension";
  CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
  // If array is not pinned, index has the same context as array.
  // If array is pinned, op dispatching depends on the context of index.
  CHECK_VALID_CONTEXT(array, index);
  ATEN_XPU_SWITCH_CUDA(index->ctx.device_type, XPU, "IndexSelect", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
        selected = impl::IndexSelect<XPU, DType, IdType>(array, index);
      });
    });
  });
  return selected;
}

127
template <typename ValueType>
128
ValueType IndexSelect(NDArray array, int64_t index) {
129
  CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
130
  CHECK(index >= 0 && index < array.NumElements())
131
      << "Index " << index << " is out of bound.";
132
  ValueType ret = 0;
133
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
134
135
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ret = impl::IndexSelect<XPU, DType>(array, index);
136
137
138
139
    });
  });
  return ret;
}
140
141
142
143
144
145
146
147
148
149
template int32_t IndexSelect<int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<uint64_t>(NDArray array, int64_t index);
template float IndexSelect<float>(NDArray array, int64_t index);
template double IndexSelect<double>(NDArray array, int64_t index);

/**
 * @brief Copy the elements [start, end) of a 1-D array into a new array.
 *
 * Checks that the array is 1-D, that 0 <= start < NumElements(), that
 * 0 <= end <= NumElements(), and that start <= end.
 *
 * @return A new array of length end - start with the dtype and context of
 *         @p array.
 */
NDArray IndexSelect(NDArray array, int64_t start, int64_t end) {
  CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
  CHECK(start >= 0 && start < array.NumElements())
      << "Index " << start << " is out of bound.";
  CHECK(end >= 0 && end <= array.NumElements())
      << "Index " << end << " is out of bound.";
  CHECK_LE(start, end);
  auto device = runtime::DeviceAPI::Get(array->ctx);
  const int64_t num_selected = end - start;
  NDArray slice = NDArray::Empty({num_selected}, array->dtype, array->ctx);
  ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
    // Offsets and sizes handed to the device API are expressed in bytes.
    device->CopyDataFromTo(
        array->data, start * sizeof(DType), slice->data, 0,
        num_selected * sizeof(DType), array->ctx, slice->ctx, array->dtype);
  });
  return slice;
}
164

165
166
/**
 * @brief Dispatch to impl::Scatter by the device and dtype of @p array and
 *        the ID type of @p indices.
 * @return The array produced by impl::Scatter.
 */
NDArray Scatter(NDArray array, IdArray indices) {
  NDArray scattered;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Scatter", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(indices->dtype, IdType, {
        scattered = impl::Scatter<XPU, DType, IdType>(array, indices);
      });
    });
  });
  return scattered;
}

177
178
179
180
181
/**
 * @brief Dispatch to impl::Scatter_, which operates on @p out.
 *
 * Checks that @p value and @p out share a dtype, that all three arrays share
 * a context, and that @p value and @p index have equal shape[0]. Returns
 * immediately when @p index has zero rows.
 */
void Scatter_(IdArray index, NDArray value, NDArray out) {
  CHECK_SAME_DTYPE(value, out);
  CHECK_SAME_CONTEXT(index, value);
  CHECK_SAME_CONTEXT(index, out);
  CHECK_EQ(value->shape[0], index->shape[0]);
  if (index->shape[0] == 0) {
    return;  // nothing to dispatch
  }
  ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, "Scatter_", {
    ATEN_DTYPE_SWITCH(value->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
        impl::Scatter_<XPU, DType, IdType>(index, value, out);
      });
    });
  });
}

192
193
/**
 * @brief Dispatch to impl::Repeat by the device and dtype of @p array and
 *        the ID type of @p repeats.
 * @return The array produced by impl::Repeat.
 */
NDArray Repeat(NDArray array, IdArray repeats) {
  NDArray repeated;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Repeat", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(repeats->dtype, IdType, {
        repeated = impl::Repeat<XPU, DType, IdType>(array, repeats);
      });
    });
  });
  return repeated;
}

204
205
/**
 * @brief Dispatch to impl::Relabel_ using the device and ID type of the
 *        first array.
 *
 * The first array selects the device and ID type, so an empty input vector
 * is rejected up front instead of reading arrays[0] out of bounds.
 *
 * @param arrays Arrays forwarded to impl::Relabel_; must not be empty.
 * @return The array produced by impl::Relabel_.
 */
IdArray Relabel_(const std::vector<IdArray>& arrays) {
  CHECK(!arrays.empty()) << "Relabel_ requires at least one input array.";
  IdArray relabel_result;
  ATEN_XPU_SWITCH_CUDA(arrays[0]->ctx.device_type, XPU, "Relabel_", {
    ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
      relabel_result = impl::Relabel_<XPU, IdType>(arrays);
    });
  });
  return relabel_result;
}

214
215
216
217
218
219
220
221
222
223
/**
 * @brief Concatenate 1-D arrays of identical dtype and context into one new
 *        array.
 *
 * Every array is checked against arrays[0] for dtype and context. The result
 * length is the sum of all shape[0] values, and the inputs are copied in
 * order through the device API of arrays[0]'s context.
 *
 * @param arrays Arrays to concatenate; must not be empty, because the first
 *        entry defines the dtype and context of the result.
 * @return A new array holding all inputs back to back.
 */
NDArray Concat(const std::vector<IdArray>& arrays) {
  CHECK(!arrays.empty()) << "Concat requires at least one input array.";

  int64_t len = 0;
  for (size_t i = 0; i < arrays.size(); ++i) {
    len += arrays[i]->shape[0];
    CHECK_SAME_DTYPE(arrays[0], arrays[i]);
    CHECK_SAME_CONTEXT(arrays[0], arrays[i]);
  }

  NDArray ret_arr = NDArray::Empty({len}, arrays[0]->dtype, arrays[0]->ctx);

  auto device = runtime::DeviceAPI::Get(arrays[0]->ctx);
  // All inputs share one dtype (checked above), so dispatch on it once.
  ATEN_DTYPE_SWITCH(arrays[0]->dtype, DType, "array", {
    int64_t offset = 0;  // byte offset of the next write into ret_arr
    for (size_t i = 0; i < arrays.size(); ++i) {
      device->CopyDataFromTo(
          static_cast<DType*>(arrays[i]->data), 0,
          static_cast<DType*>(ret_arr->data), offset,
          arrays[i]->shape[0] * sizeof(DType), arrays[i]->ctx, ret_arr->ctx,
          arrays[i]->dtype);
      offset += arrays[i]->shape[0] * sizeof(DType);
    }
  });

  return ret_arr;
}

242
template <typename ValueType>
243
244
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
  std::tuple<NDArray, IdArray, IdArray> ret;
245
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Pack", {
246
247
248
249
250
251
252
253
254
    ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
      ret = impl::Pack<XPU, DType>(array, static_cast<DType>(pad_value));
    });
  });
  return ret;
}

template std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);
255
256
257
258
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(
    NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(
    NDArray, uint64_t);
259
260
261
262
263
template std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);

/**
 * @brief Dispatch to impl::ConcatSlices by the device and dtype of @p array
 *        and the ID type of @p lengths.
 * @return The pair produced by impl::ConcatSlices.
 */
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
  std::pair<NDArray, IdArray> slices;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "ConcatSlices", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
      ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, {
        slices = impl::ConcatSlices<XPU, DType, IdType>(array, lengths);
      });
    });
  });
  return slices;
}

274
275
276
277
278
279
280
281
282
283
/**
 * @brief Dispatch to impl::CumSum by the device and ID type of @p array.
 * @param array Input ID array.
 * @param prepend_zero Flag forwarded unchanged to impl::CumSum.
 * @return The array produced by impl::CumSum.
 */
IdArray CumSum(IdArray array, bool prepend_zero) {
  IdArray sums;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "CumSum", {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      sums = impl::CumSum<XPU, IdType>(array, prepend_zero);
    });
  });
  return sums;
}

284
285
286
/**
 * @brief Dispatch to impl::NonZero by the device of @p array; the element
 *        type is resolved with the ID-type switch.
 * @return The array produced by impl::NonZero.
 */
IdArray NonZero(NDArray array) {
  IdArray nonzero;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "NonZero", {
    ATEN_ID_TYPE_SWITCH(array->dtype, DType, {
      nonzero = impl::NonZero<XPU, DType>(array);
    });
  });
  return nonzero;
}

293
/**
 * @brief Dispatch to impl::Sort by the device and ID type of @p array.
 *
 * An array with zero elements skips the dispatch: the input itself is
 * returned together with a new zero-length 64-bit ID array.
 *
 * @param array Input ID array.
 * @param num_bits Forwarded unchanged to impl::Sort.
 * @return The pair produced by impl::Sort, or the empty-input pair above.
 */
std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
  if (array.NumElements() == 0) {
    IdArray empty_idx = NewIdArray(0, array->ctx, 64);
    return std::make_pair(array, empty_idx);
  }
  std::pair<IdArray, IdArray> sorted;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      sorted = impl::Sort<XPU, IdType>(array, num_bits);
    });
  });
  return sorted;
}

307
308
/**
 * @brief Format an array as a short, human-readable string.
 *
 * The array is first copied to the CPU context. At most the first 10
 * elements are printed, each followed by ", ", and "..." is appended when
 * the array holds more than 10 elements. The dtype and context of the
 * original array close the string.
 *
 * @return Text of the form "array([...], dtype=..., ctx=...)".
 */
std::string ToDebugString(NDArray array) {
  std::ostringstream oss;
  NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});
  const int64_t num_shown = std::min<int64_t>(a.NumElements(), 10L);
  oss << "array([";
  ATEN_DTYPE_SWITCH(a->dtype, DType, "array", {
    for (int64_t pos = 0; pos < num_shown; ++pos) {
      oss << a.Ptr<DType>()[pos] << ", ";
    }
  });
  if (a.NumElements() > 10) {
    oss << "...";
  }
  oss << "], dtype=" << array->dtype << ", ctx=" << array->ctx << ")";
  return oss.str();
}

321
322
323
///////////////////////// CSR routines //////////////////////////

/**
 * @brief Dispatch to impl::CSRIsNonZero for a single (row, col) pair.
 *
 * @p row must lie in [0, csr.num_rows) and @p col in [0, csr.num_cols)
 * (both checked).
 *
 * @return The value produced by impl::CSRIsNonZero.
 */
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
  bool is_nonzero = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsNonZero", {
    is_nonzero = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return is_nonzero;
}

/**
 * @brief Dispatch to impl::CSRIsNonZero for arrays of rows and columns.
 *
 * @p row and @p col must match the dtype of csr.indices and share one
 * context (checked).
 *
 * @return The array produced by impl::CSRIsNonZero.
 */
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  NDArray flags;
  CHECK_SAME_DTYPE(csr.indices, row);
  CHECK_SAME_DTYPE(csr.indices, col);
  CHECK_SAME_CONTEXT(row, col);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, "CSRIsNonZero", {
    flags = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return flags;
}

/**
 * @brief Dispatch to impl::CSRHasDuplicate for @p csr.
 * @return The value produced by impl::CSRHasDuplicate.
 */
bool CSRHasDuplicate(CSRMatrix csr) {
  bool has_duplicate = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRHasDuplicate", {
    has_duplicate = impl::CSRHasDuplicate<XPU, IdType>(csr);
  });
  return has_duplicate;
}

/**
 * @brief Dispatch to impl::CSRGetRowNNZ for a single row.
 * @param csr Input CSR matrix.
 * @param row Row index; must lie in [0, csr.num_rows) (checked).
 * @return The value produced by impl::CSRGetRowNNZ.
 */
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  int64_t nnz = 0;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowNNZ", {
    nnz = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return nnz;
}

/**
 * @brief Dispatch to impl::CSRGetRowNNZ for an array of rows.
 * @param csr Input CSR matrix.
 * @param row Row array; must match the dtype of csr.indices (checked).
 * @return The array produced by impl::CSRGetRowNNZ.
 */
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
  NDArray nnz;
  CHECK_SAME_DTYPE(csr.indices, row);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, "CSRGetRowNNZ", {
    nnz = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return nnz;
}

/**
 * @brief Dispatch to impl::CSRGetRowColumnIndices for a single row.
 * @param csr Input CSR matrix.
 * @param row Row index; must lie in [0, csr.num_rows) (checked).
 * @return The array produced by impl::CSRGetRowColumnIndices.
 */
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  NDArray col_indices;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowColumnIndices", {
    col_indices = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
  });
  return col_indices;
}

/**
 * @brief Dispatch to impl::CSRGetRowData for a single row.
 * @param csr Input CSR matrix.
 * @param row Row index; must lie in [0, csr.num_rows) (checked).
 * @return The array produced by impl::CSRGetRowData.
 */
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  NDArray row_data;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowData", {
    row_data = impl::CSRGetRowData<XPU, IdType>(csr, row);
  });
  return row_data;
}

388
/**
 * @brief Dispatch to impl::CSRIsSorted for @p csr.
 *
 * Returns true without dispatching when csr.indices has at most one entry.
 *
 * @return true for the short-circuit case, otherwise the value produced by
 *         impl::CSRIsSorted.
 */
bool CSRIsSorted(CSRMatrix csr) {
  if (csr.indices->shape[0] <= 1) {
    return true;
  }
  bool is_sorted = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsSorted", {
    is_sorted = impl::CSRIsSorted<XPU, IdType>(csr);
  });
  return is_sorted;
}

397
398
/**
 * @brief Dispatch to impl::CSRGetData for arrays of rows and columns.
 *
 * @p rows and @p cols must match the dtype of csr.indices and share one
 * context (checked).
 *
 * @return The array produced by impl::CSRGetData.
 */
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  NDArray data;
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
    data = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
  });
  return data;
}

408
template <typename DType>
409
410
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
411
412
413
  NDArray ret;
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
414
415
416
  CHECK_SAME_CONTEXT(rows, cols);
  CHECK_SAME_CONTEXT(rows, weights);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
417
418
    ret =
        impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
419
420
421
422
423
424
425
426
427
  });
  return ret;
}

template NDArray CSRGetData<float>(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<double>(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);

428
429
/**
 * @brief Dispatch to impl::CSRGetDataAndIndices for arrays of rows and
 *        columns.
 *
 * @p rows and @p cols must match the dtype of csr.indices and share one
 * context (checked).
 *
 * @return The vector of arrays produced by impl::CSRGetDataAndIndices.
 */
std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  std::vector<NDArray> data_and_indices;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetDataAndIndices", {
    data_and_indices = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
  });
  return data_and_indices;
}

/**
 * @brief Dispatch to impl::CSRTranspose by the device and ID type of
 *        csr.indptr.
 * @return The matrix produced by impl::CSRTranspose.
 */
CSRMatrix CSRTranspose(CSRMatrix csr) {
  CSRMatrix transposed;
  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRTranspose", {
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
      transposed = impl::CSRTranspose<XPU, IdType>(csr);
    });
  });
  return transposed;
}

/**
 * @brief Convert a CSR matrix to COO form through one of two implementations.
 *
 * Dispatch uses the device and ID type of csr.indptr.
 *
 * @param csr Input CSR matrix.
 * @param data_as_order Selects impl::CSRToCOODataAsOrder when true and
 *        impl::CSRToCOO when false.
 * @return The matrix produced by the selected implementation.
 */
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
  COOMatrix coo;
  if (!data_as_order) {
    ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOO", {
      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
        coo = impl::CSRToCOO<XPU, IdType>(csr);
      });
    });
  } else {
    ATEN_XPU_SWITCH_CUDA(
        csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
          ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
            coo = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
          });
        });
  }
  return coo;
}

/**
 * @brief Dispatch to impl::CSRSliceRows for the row range [start, end).
 *
 * Checks 0 <= start < csr.num_rows, 0 <= end <= csr.num_rows and
 * end >= start.
 *
 * @return The matrix produced by impl::CSRSliceRows.
 */
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  CHECK(start >= 0 && start < csr.num_rows) << "Invalid start index: " << start;
  CHECK(end >= 0 && end <= csr.num_rows) << "Invalid end index: " << end;
  CHECK_GE(end, start);
  CSRMatrix sliced;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRSliceRows", {
    sliced = impl::CSRSliceRows<XPU, IdType>(csr, start, end);
  });
  return sliced;
}

/**
 * @brief Dispatch to impl::CSRSliceRows for an array of rows.
 * @param csr Input CSR matrix.
 * @param rows Row array; must match the dtype of csr.indices (checked).
 * @return The matrix produced by impl::CSRSliceRows.
 */
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CSRMatrix sliced;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRSliceRows", {
    sliced = impl::CSRSliceRows<XPU, IdType>(csr, rows);
  });
  return sliced;
}

/**
 * @brief Dispatch to impl::CSRSliceMatrix for arrays of rows and columns.
 *
 * @p rows and @p cols must match the dtype of csr.indices and share one
 * context (checked).
 *
 * @return The matrix produced by impl::CSRSliceMatrix.
 */
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  CSRMatrix sliced;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRSliceMatrix", {
    sliced = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
  });
  return sliced;
}

500
/**
 * @brief Dispatch to impl::CSRSort_ on the matrix behind @p csr.
 *
 * Returns immediately when csr->sorted is already set.
 *
 * @param csr Pointer to the matrix handed to impl::CSRSort_.
 */
void CSRSort_(CSRMatrix* csr) {
  if (csr->sorted) {
    return;
  }
  ATEN_CSR_SWITCH_CUDA(*csr, XPU, IdType, "CSRSort_", {
    impl::CSRSort_<XPU, IdType>(csr);
  });
}

506
std::pair<CSRMatrix, NDArray> CSRSortByTag(
507
    const CSRMatrix& csr, IdArray tag, int64_t num_tags) {
508
  CHECK_EQ(csr.indices->shape[0], tag->shape[0])
509
510
      << "The length of the tag array should be equal to the number of "
         "non-zero data.";
511
512
513
514
515
516
517
518
519
520
521
  CHECK_SAME_CONTEXT(csr.indices, tag);
  CHECK_INT(tag, "tag");
  std::pair<CSRMatrix, NDArray> ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSortByTag", {
    ATEN_ID_TYPE_SWITCH(tag->dtype, TagType, {
      ret = impl::CSRSortByTag<XPU, IdType, TagType>(csr, tag, num_tags);
    });
  });
  return ret;
}

522
523
/**
 * @brief Dispatch to impl::CSRReorder with the given row and column ID
 *        arrays.
 * @return The matrix produced by impl::CSRReorder.
 */
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  CSRMatrix reordered;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRReorder", {
    reordered = impl::CSRReorder<XPU, IdType>(csr, new_row_ids, new_col_ids);
  });
  return reordered;
}

531
532
/**
 * @brief Dispatch to impl::CSRRemove with the given entries array.
 * @return The matrix produced by impl::CSRRemove.
 */
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
  CSRMatrix remaining;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", {
    remaining = impl::CSRRemove<XPU, IdType>(csr, entries);
  });
  return remaining;
}

539
/**
 * @brief Row-wise sampling on a CSR matrix, uniform or weighted/masked.
 *
 * A null @p prob_or_mask selects impl::CSRRowWiseSamplingUniform. Otherwise
 * the call goes to impl::CSRRowWiseSampling, typed by the dtype of
 * @p prob_or_mask. An 8-bit @p prob_or_mask combined with the CUDA device is
 * rejected by a check.
 *
 * @param mat Input CSR matrix.
 * @param rows Rows forwarded to the implementation.
 * @param num_samples Forwarded unchanged to the implementation.
 * @param prob_or_mask Probability or mask array; may be a null array.
 * @param replace Forwarded unchanged to the implementation.
 * @return The COO matrix produced by the selected implementation.
 */
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix sampled;
  if (!IsNullArray(prob_or_mask)) {
    // prob_or_mask is pinned and rows on GPU is valid
    CHECK_VALID_CONTEXT(prob_or_mask, rows);
    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
      CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA))
          << "GPU sampling with masks is currently not supported yet.";
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, FloatType, "probability or mask", {
            sampled = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    });
  } else {
    ATEN_CSR_SWITCH_CUDA_UVA(
        mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
          sampled = impl::CSRRowWiseSamplingUniform<XPU, IdType>(
              mat, rows, num_samples, replace);
        });
  }
  return sampled;
}

565
/**
 * @brief Per-edge-type row-wise sampling on a CSR matrix.
 *
 * @p prob_or_mask must not be empty (checked). When every entry is a null
 * array, impl::CSRRowWisePerEtypeSamplingUniform is called. Otherwise
 * impl::CSRRowWisePerEtypeSampling is called, typed by the dtype of
 * prob_or_mask[0].
 *
 * The remaining arguments are forwarded unchanged to the implementation.
 *
 * @return The COO matrix produced by the selected implementation.
 */
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  COOMatrix sampled;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
    const bool all_null =
        std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray);
    if (all_null) {
      sampled = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace,
          rowwise_etype_sorted);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
            sampled = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace,
                rowwise_etype_sorted);
          });
    }
  });
  return sampled;
}

589
/**
 * @brief Dispatch to impl::CSRRowWiseTopk by the CSR matrix and the dtype of
 *        @p weight.
 *
 * @p rows, @p k, @p weight and @p ascending are forwarded unchanged.
 *
 * @return The COO matrix produced by impl::CSRRowWiseTopk.
 */
COOMatrix CSRRowWiseTopk(
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
  COOMatrix topk;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseTopk", {
    ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
      topk = impl::CSRRowWiseTopk<XPU, IdType, DType>(
          mat, rows, k, weight, ascending);
    });
  });
  return topk;
}

601
/**
 * @brief Dispatch to impl::CSRRowWiseSamplingBiased by the CSR matrix and
 *        the float type of @p bias.
 *
 * All other arguments are forwarded unchanged to the implementation.
 *
 * @return The COO matrix produced by impl::CSRRowWiseSamplingBiased.
 */
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
  COOMatrix sampled;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSamplingBiased", {
    ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, "bias", {
      sampled = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
          mat, rows, num_samples, tag_offset, bias, replace);
    });
  });
  return sampled;
}

614
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
615
616
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
617
618
619
620
621
622
623
624
625
626
  CHECK_GT(num_samples, 0) << "Number of samples must be positive";
  CHECK_GT(num_trials, 0) << "Number of sampling trials must be positive";
  std::pair<IdArray, IdArray> result;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGlobalUniformNegativeSampling", {
    result = impl::CSRGlobalUniformNegativeSampling<XPU, IdType>(
        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
  });
  return result;
}

627
628
/**
 * @brief Dispatch to impl::UnionCsr after validating the inputs.
 *
 * Requires more than one matrix. Every matrix must agree with csrs[0] on
 * num_rows and num_cols, and on the context and dtype of indptr (all
 * checked).
 *
 * @return The matrix produced by impl::UnionCsr.
 */
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  CSRMatrix merged;
  CHECK_GT(csrs.size(), 1)
      << "UnionCsr creates a union of multiple CSRMatrixes";
  // Validate every other matrix against the first one.
  for (size_t i = 1; i < csrs.size(); ++i) {
    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows)
        << "UnionCsr requires both CSRMatrix have same number of rows";
    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols)
        << "UnionCsr requires both CSRMatrix have same number of cols";
    CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
  }

  ATEN_CSR_SWITCH(csrs[0], XPU, IdType, "UnionCsr", {
    merged = impl::UnionCsr<XPU, IdType>(csrs);
  });
  return merged;
}

647
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr) {
648
649
650
651
652
653
654
655
656
  std::tuple<CSRMatrix, IdArray, IdArray> ret;

  CSRMatrix sorted_csr = (CSRIsSorted(csr)) ? csr : CSRSort(csr);
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRToSimple", {
    ret = impl::CSRToSimple<XPU, IdType>(sorted_csr);
  });
  return ret;
}

657
658
///////////////////////// COO routines //////////////////////////

659
660
/**
 * @brief Dispatch to impl::COOIsNonZero for a single (row, col) pair.
 * @return The value produced by impl::COOIsNonZero.
 */
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
  bool is_nonzero = false;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
    is_nonzero = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
  });
  return is_nonzero;
}

/**
 * @brief Dispatch to impl::COOIsNonZero for arrays of rows and columns.
 * @return The array produced by impl::COOIsNonZero.
 */
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
  NDArray flags;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
    flags = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
  });
  return flags;
}

675
676
/**
 * @brief Dispatch to impl::COOHasDuplicate for @p coo.
 * @return The value produced by impl::COOHasDuplicate.
 */
bool COOHasDuplicate(COOMatrix coo) {
  bool has_duplicate = false;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOHasDuplicate", {
    has_duplicate = impl::COOHasDuplicate<XPU, IdType>(coo);
  });
  return has_duplicate;
}

683
684
/**
 * @brief Dispatch to impl::COOGetRowNNZ for a single row.
 * @return The value produced by impl::COOGetRowNNZ.
 */
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  int64_t nnz = 0;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOGetRowNNZ", {
    nnz = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
  });
  return nnz;
}

/**
 * @brief Dispatch to impl::COOGetRowNNZ for an array of rows.
 * @return The array produced by impl::COOGetRowNNZ.
 */
NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
  NDArray nnz;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOGetRowNNZ", {
    nnz = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
  });
  return nnz;
}

699
700
/**
 * @brief Dispatch to impl::COOGetRowDataAndIndices for a single row.
 * @return The pair of arrays produced by impl::COOGetRowDataAndIndices.
 */
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  std::pair<NDArray, NDArray> data_and_indices;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowDataAndIndices", {
    data_and_indices = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
  });
  return data_and_indices;
}

/**
 * @brief Dispatch to impl::COOGetDataAndIndices for arrays of rows and
 *        columns.
 * @return The vector of arrays produced by impl::COOGetDataAndIndices.
 */
std::vector<NDArray> COOGetDataAndIndices(
    COOMatrix coo, NDArray rows, NDArray cols) {
  std::vector<NDArray> data_and_indices;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetDataAndIndices", {
    data_and_indices = impl::COOGetDataAndIndices<XPU, IdType>(coo, rows, cols);
  });
  return data_and_indices;
}

717
718
719
720
721
722
723
724
/**
 * @brief Dispatch to impl::COOGetData for arrays of rows and columns.
 * @return The array produced by impl::COOGetData.
 */
NDArray COOGetData(COOMatrix coo, NDArray rows, NDArray cols) {
  NDArray data;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetData", {
    data = impl::COOGetData<XPU, IdType>(coo, rows, cols);
  });
  return data;
}

725
/**
 * @brief Build the transpose of a COO matrix by swapping its two axes.
 *
 * num_rows/num_cols and row/col trade places; the data array is passed
 * through as is.
 *
 * @return A COOMatrix constructed from the swapped fields.
 */
COOMatrix COOTranspose(COOMatrix coo) {
  COOMatrix transposed(
      coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data);
  return transposed;
}

729
730
/**
 * @brief Dispatch to impl::COOToCSR by the device and ID type of coo.row.
 * @return The CSR matrix produced by impl::COOToCSR.
 */
CSRMatrix COOToCSR(COOMatrix coo) {
  CSRMatrix csr;
  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "COOToCSR", {
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
      csr = impl::COOToCSR<XPU, IdType>(coo);
    });
  });
  return csr;
}

738
739
/**
 * @brief Dispatch to impl::COOSliceRows with a start and end row.
 * @return The matrix produced by impl::COOSliceRows.
 */
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
  COOMatrix sliced;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
    sliced = impl::COOSliceRows<XPU, IdType>(coo, start, end);
  });
  return sliced;
}

/**
 * @brief Dispatch to impl::COOSliceRows with an array of rows.
 * @return The matrix produced by impl::COOSliceRows.
 */
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
  COOMatrix sliced;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
    sliced = impl::COOSliceRows<XPU, IdType>(coo, rows);
  });
  return sliced;
}

/**
 * @brief Dispatch to impl::COOSliceMatrix with arrays of rows and columns.
 * @return The matrix produced by impl::COOSliceMatrix.
 */
COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
  COOMatrix sliced;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceMatrix", {
    sliced = impl::COOSliceMatrix<XPU, IdType>(coo, rows, cols);
  });
  return sliced;
}

762
/**
 * @brief Dispatch to impl::COOSort_ on the matrix behind @p mat.
 *
 * Nothing is done when mat->col_sorted is set, or when mat->row_sorted is
 * set and @p sort_column is false.
 *
 * @param mat Pointer to the matrix handed to impl::COOSort_.
 * @param sort_column Forwarded unchanged to impl::COOSort_.
 */
void COOSort_(COOMatrix* mat, bool sort_column) {
  if (mat->col_sorted) {
    return;
  }
  if (mat->row_sorted && !sort_column) {
    return;
  }
  ATEN_XPU_SWITCH_CUDA(mat->row->ctx.device_type, XPU, "COOSort_", {
    ATEN_ID_TYPE_SWITCH(mat->row->dtype, IdType, {
      impl::COOSort_<XPU, IdType>(mat, sort_column);
    });
  });
}

/**
 * @brief Dispatch to impl::COOIsSorted for @p coo.
 *
 * Returns {true, true} without dispatching when coo.row has at most one
 * entry.
 *
 * @return The pair produced by impl::COOIsSorted, or {true, true} for the
 *         short-circuit case.
 */
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  if (coo.row->shape[0] <= 1) {
    return std::make_pair(true, true);
  }
  std::pair<bool, bool> sorted_flags;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOIsSorted", {
    sorted_flags = impl::COOIsSorted<XPU, IdType>(coo);
  });
  return sorted_flags;
}

780
781
/**
 * @brief Dispatch to impl::COOReorder with the given row and column ID
 *        arrays.
 * @return The matrix produced by impl::COOReorder.
 */
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  COOMatrix reordered;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
    reordered = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
  });
  return reordered;
}

789
790
/**
 * @brief Dispatch to impl::COORemove with the given entries array.
 * @return The matrix produced by impl::COORemove.
 */
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
  COOMatrix remaining;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", {
    remaining = impl::COORemove<XPU, IdType>(coo, entries);
  });
  return remaining;
}

797
/**
 * @brief Row-wise sampling on a COO matrix, uniform or weighted/masked.
 *
 * A null @p prob_or_mask selects impl::COORowWiseSamplingUniform. Otherwise
 * impl::COORowWiseSampling is called, typed by the dtype of @p prob_or_mask.
 * The remaining arguments are forwarded unchanged.
 *
 * @return The COO matrix produced by the selected implementation.
 */
COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix sampled;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
    if (!IsNullArray(prob_or_mask)) {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, DType, "probability or mask", {
            sampled = impl::COORowWiseSampling<XPU, IdType, DType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    } else {
      sampled = impl::COORowWiseSamplingUniform<XPU, IdType>(
          mat, rows, num_samples, replace);
    }
  });
  return sampled;
}

816
/**
 * @brief Per-edge-type row-wise sampling on a COO matrix.
 *
 * @p prob_or_mask must not be empty (checked). When every entry is a null
 * array, impl::COORowWisePerEtypeSamplingUniform is called. Otherwise
 * impl::COORowWisePerEtypeSampling is called, typed by the dtype of
 * prob_or_mask[0].
 *
 * The remaining arguments are forwarded unchanged to the implementation.
 *
 * @return The COO matrix produced by the selected implementation.
 */
COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  COOMatrix sampled;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
    const bool all_null =
        std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray);
    if (all_null) {
      sampled = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
            sampled = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask,
                replace);
          });
    }
  });
  return sampled;
}

838
839
840
/**
 * @brief Dispatch to impl::COORowWiseTopk by the COO matrix and the dtype of
 *        @p weight.
 *
 * @p rows, @p k, @p weight and @p ascending are forwarded unchanged.
 *
 * @return The COO matrix produced by impl::COORowWiseTopk.
 */
COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
  COOMatrix topk;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseTopk", {
    ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
      topk = impl::COORowWiseTopk<XPU, IdType, DType>(
          mat, rows, k, weight, ascending);
    });
  });
  return topk;
}

850
851
/**
 * @brief Dispatch to impl::COOCoalesce for @p coo.
 * @return The (matrix, ID array) pair produced by impl::COOCoalesce.
 */
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
  std::pair<COOMatrix, IdArray> coalesced;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOCoalesce", {
    coalesced = impl::COOCoalesce<XPU, IdType>(coo);
  });
  return coalesced;
}

858
859
860
861
862
863
864
865
866
867
/**
 * @brief Dispatch to impl::DisjointUnionCoo using the device and ID type of
 *        the first matrix's row array.
 *
 * The first matrix selects the device and ID type, so an empty input vector
 * is rejected up front instead of reading coos[0] out of bounds.
 *
 * @param coos Matrices forwarded to impl::DisjointUnionCoo; must not be
 *        empty.
 * @return The matrix produced by impl::DisjointUnionCoo.
 */
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
  CHECK(!coos.empty())
      << "DisjointUnionCoo requires at least one input COOMatrix.";
  COOMatrix merged;
  ATEN_XPU_SWITCH_CUDA(coos[0].row->ctx.device_type, XPU, "DisjointUnionCoo", {
    ATEN_ID_TYPE_SWITCH(coos[0].row->dtype, IdType, {
      merged = impl::DisjointUnionCoo<XPU, IdType>(coos);
    });
  });
  return merged;
}

868
/**
 * @brief Dispatch to impl::COOLineGraph for @p coo.
 * @param coo Input COO matrix.
 * @param backtracking Forwarded unchanged to impl::COOLineGraph.
 * @return The matrix produced by impl::COOLineGraph.
 */
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
  COOMatrix line_graph;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOLineGraph", {
    line_graph = impl::COOLineGraph<XPU, IdType>(coo, backtracking);
  });
  return line_graph;
}
875
876
877

/**
 * @brief Concatenate several COO matrices of identical shape into one.
 *
 * Requires more than one matrix. Every matrix must agree with coos[0] on
 * num_rows and num_cols, and on the context and dtype of row (all checked).
 * The row and col arrays are concatenated in input order.
 *
 * The data array of the result stays a null array unless at least one input
 * carries data. In that case each input contributes its own data shifted by
 * the number of edges that precede it (the first input is unshifted). An
 * input without data contributes a Range over its edge positions instead.
 *
 * @return A COOMatrix with the shape of coos[0] and both sorted flags false.
 */
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  COOMatrix ret;
  CHECK_GT(coos.size(), 1)
      << "UnionCoo creates a union of multiple COOMatrixes";
  // Validate every other matrix against the first one.
  for (size_t i = 1; i < coos.size(); ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows)
        << "UnionCoo requires both COOMatrix have same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols)
        << "UnionCoo requires both COOMatrix have same number of cols";
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
  }

  // We assume the number of coos is not large in common cases.
  std::vector<IdArray> all_rows;
  std::vector<IdArray> all_cols;
  bool any_has_data = false;
  for (const COOMatrix& part : coos) {
    all_rows.push_back(part.row);
    all_cols.push_back(part.col);
    any_has_data |= COOHasData(part);
  }

  IdArray row = Concat(all_rows);
  IdArray col = Concat(all_cols);
  IdArray data = NullArray();

  if (any_has_data) {
    std::vector<IdArray> eid_parts;

    // The first matrix keeps its edge IDs as they are.
    if (COOHasData(coos[0])) {
      eid_parts.push_back(coos[0].data);
    } else {
      eid_parts.push_back(Range(
          0, coos[0].row->shape[0], coos[0].row->dtype.bits,
          coos[0].row->ctx));
    }

    // Later matrices are shifted by the edges accumulated so far.
    int64_t edges_so_far = coos[0].row->shape[0];
    for (size_t i = 1; i < coos.size(); ++i) {
      if (COOHasData(coos[i])) {
        eid_parts.push_back(coos[i].data + edges_so_far);
      } else {
        eid_parts.push_back(Range(
            edges_so_far, edges_so_far + coos[i].row->shape[0],
            coos[i].row->dtype.bits, coos[i].row->ctx));
      }
      edges_so_far += coos[i].row->shape[0];
    }

    data = Concat(eid_parts);
  }

  return COOMatrix(
      coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);
}

930
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo) {
931
932
  // coo column sorted
  const COOMatrix sorted_coo = COOSort(coo, true);
933
934
935
936
937
938
939
940
941
  const IdArray eids_shuffled =
      COOHasData(sorted_coo)
          ? sorted_coo.data
          : Range(
                0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits,
                sorted_coo.row->ctx);
  const auto& coalesced_result = COOCoalesce(sorted_coo);
  const COOMatrix& coalesced_adj = coalesced_result.first;
  const IdArray& count = coalesced_result.second;
942

943
  /**
944
945
   * eids_shuffled actually already contains the mapping from old edge space to
   * the new one:
946
   *
947
948
949
950
951
952
   * * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced
   * into new edge #0.
   * * eids_shuffled[count[0]:count[0] + count[1]] indicates those that
   * coalesced into new edge #1.
   * * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]]
   * indicates those that coalesced into new edge #2.
953
954
   * * etc.
   *
955
956
957
   * Here, we need to translate eids_shuffled to an array "eids_remapped" such
   * that eids_remapped[i] indicates the new edge ID the old edge #i is mapped
   * to.  The translation can simply be achieved by (in numpy code):
958
959
960
961
962
963
   *
   *     new_eid_for_eids_shuffled = np.range(len(count)).repeat(count)
   *     eids_remapped = np.zeros_like(new_eid_for_eids_shuffled)
   *     eids_remapped[eids_shuffled] = new_eid_for_eids_shuffled
   */
  const IdArray new_eids = Range(
964
965
      0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits,
      coalesced_adj.row->ctx);
966
967
968
  const IdArray eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled);

  COOMatrix ret = COOMatrix(
969
970
      coalesced_adj.num_rows, coalesced_adj.num_cols, coalesced_adj.row,
      coalesced_adj.col, NullArray(), true, true);
971
972
973
  return std::make_tuple(ret, count, eids_remapped);
}

///////////////////////// Graph Traverse routines //////////////////////////
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  // Validate the inputs: matching device type, matching ID dtype, and a
  // square adjacency matrix.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  // Dispatch to the implementation selected by the source array's device
  // type and ID dtype.
  Frontiers frontiers;
  const auto device_type = source->ctx.device_type;
  ATEN_XPU_SWITCH(device_type, XPU, "BFSNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}

Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  // Validate the inputs: matching device type, matching ID dtype, and a
  // square adjacency matrix.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  // Dispatch to the implementation selected by the source array's device
  // type and ID dtype.
  Frontiers frontiers;
  const auto device_type = source->ctx.device_type;
  ATEN_XPU_SWITCH(device_type, XPU, "BFSEdgesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}

Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  // Traversal requires a square adjacency matrix.
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  // Dispatch on the device type of indptr and the dtype of indices.
  Frontiers frontiers;
  const auto device_type = csr.indptr->ctx.device_type;
  ATEN_XPU_SWITCH(device_type, XPU, "TopologicalNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
      frontiers = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
    });
  });
  return frontiers;
}

Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  // Validate the inputs: matching device type, matching ID dtype, and a
  // square adjacency matrix.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  // Dispatch to the implementation selected by the source array's device
  // type and ID dtype.
  Frontiers frontiers;
  const auto device_type = source->ctx.device_type;
  ATEN_XPU_SWITCH(device_type, XPU, "DGLDFSEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSEdges<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}
1035

1036
1037
1038
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels) {
  // Validate the inputs: matching device type, matching ID dtype, and a
  // square adjacency matrix.
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  // Dispatch to the implementation selected by the source array's device
  // type and ID dtype; the three flags are forwarded untouched.
  Frontiers frontiers;
  const auto device_type = source->ctx.device_type;
  ATEN_XPU_SWITCH(device_type, XPU, "DGLDFSLabeledEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSLabeledEdges<XPU, IdType>(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
    });
  });
  return frontiers;
}

///////////////////////// C APIs /////////////////////////
// Returns the `format` field of the sparse matrix passed as args[0].
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef matrix = args[0];
      *rv = matrix->format;
    });
1061
1062

// Returns the `num_rows` field of the sparse matrix passed as args[0].
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumRows")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef matrix = args[0];
      *rv = matrix->num_rows;
    });
1067
1068

// Returns the `num_cols` field of the sparse matrix passed as args[0].
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumCols")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef matrix = args[0];
      *rv = matrix->num_cols;
    });
1073
1074

// Returns the i-th index array of the sparse matrix, where the matrix is
// args[0] and i is args[1].
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetIndices")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef spmat = args[0];
      const int64_t i = args[1];
      // The index comes straight from the caller; reject out-of-range values
      // instead of subscripting past the end of the container.
      CHECK(i >= 0 && static_cast<size_t>(i) < spmat->indices.size())
          << "Index " << i << " is out of range; the sparse matrix has "
          << spmat->indices.size() << " index array(s).";
      *rv = spmat->indices[i];
    });
1080
1081

// Returns the sparse matrix's `flags` (matrix is args[0]) as a List<Value>,
// one boxed entry per flag, in the original order.
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFlags")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef matrix = args[0];
      List<Value> boxed_flags;
      for (const bool flag : matrix->flags) {
        boxed_flags.push_back(Value(MakeValue(flag)));
      }
      *rv = boxed_flags;
    });
1090
1091

// Builds a SparseMatrix from its raw components and returns a reference to
// it. Arguments, in order: format code, number of rows, number of columns,
// list of index arrays, list of boolean flags.
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t format = args[0];
      const int64_t nrows = args[1];
      const int64_t ncols = args[2];
      const List<Value> indices = args[3];
      const List<Value> flags = args[4];
      // make_shared replaces the naked `new`: a single allocation and no
      // window in which the raw pointer is unowned.
      auto spmat = std::make_shared<SparseMatrix>(
          format, nrows, ncols, ListValueToVector<IdArray>(indices),
          ListValueToVector<bool>(flags));
      *rv = SparseMatrixRef(spmat);
    });
1103

1104
// Reports whether a shared-memory region with the given name (args[0])
// exists. On Windows builds the answer is always false.
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string name = args[0];
#ifdef _WIN32
      *rv = false;
#else
      *rv = SharedMemory::Exist(name);
#endif  // _WIN32
    });
1113

1114
// Returns a view of the array in args[0] that keeps its shape and dtype
// fields except for the type code, which is switched from kDGLUInt to
// kDGLInt. The input must have type code kDGLUInt.
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray array = args[0];
      CHECK_EQ(array->dtype.code, kDGLUInt);

      // Same dtype as the input, with only the type code changed.
      DGLDataType signed_dtype = array->dtype;
      signed_dtype.code = kDGLInt;

      // Copy the dimensions out of the raw shape pointer.
      std::vector<int64_t> dims(array->shape, array->shape + array->ndim);
      *rv = array.CreateView(dims, signed_dtype, 0);
    });
1123

1124
1125
}  // namespace aten
}  // namespace dgl
1126

1127
std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array) {
1128
1129
  return os << dgl::aten::ToDebugString(array);
}