/**
 *  Copyright (c) 2019-2022 by Contributors
 * @file array/array.cc
 * @brief DGL array utilities implementation
 */
#include <dgl/array.h>
7
#include <dgl/graph_traversal.h>
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
10
#include <dgl/runtime/device_api.h>
11
12
#include <dgl/runtime/shared_mem.h>

13
#include <sstream>
14

15
16
#include "../c_api_common.h"
#include "./arith.h"
17
#include "./array_op.h"
18

19
using namespace dgl::runtime;
20

21
namespace dgl {
22
23
namespace aten {

24
25
/**
 * @brief Allocate an uninitialized 1D integer array.
 * @param length Number of elements.
 * @param ctx Device context of the new array.
 * @param nbits Bit width of each integer element.
 */
IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
  const DGLDataType id_dtype{kDGLInt, nbits, 1};
  return IdArray::Empty({length}, id_dtype, ctx);
}

28
29
30
31
/**
 * @brief Allocate an uninitialized 1D floating-point array.
 * @param length Number of elements.
 * @param ctx Device context of the new array.
 * @param nbits Bit width of each floating-point element.
 */
FloatArray NewFloatArray(int64_t length, DGLContext ctx, uint8_t nbits) {
  const DGLDataType float_dtype{kDGLFloat, nbits, 1};
  return FloatArray::Empty({length}, float_dtype, ctx);
}

32
33
34
35
36
37
/**
 * @brief Deep-copy an id array.
 *
 * A fresh array of the same length, context and bit width is allocated and
 * the contents of @p arr are copied into it.
 */
IdArray Clone(IdArray arr) {
  IdArray copy = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits);
  copy.CopyFrom(arr);
  return copy;
}

38
/**
 * @brief Build an integer range array on @p ctx via impl::Range
 *        (presumably the values [low, high)).
 * @param nbits Element width; only 32 and 64 are accepted, anything else is
 *        fatal.
 */
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
  IdArray range;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
    switch (nbits) {
      case 32:
        range = impl::Range<XPU, int32_t>(low, high, ctx);
        break;
      case 64:
        range = impl::Range<XPU, int64_t>(low, high, ctx);
        break;
      default:
        LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return range;
}

52
/**
 * @brief Create an integer array of @p length elements filled with @p val
 *        (via impl::Full).
 * @param nbits Element width; only 32 and 64 are accepted, anything else is
 *        fatal.
 */
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
  IdArray filled;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
    switch (nbits) {
      case 32:
        filled = impl::Full<XPU, int32_t>(val, length, ctx);
        break;
      case 64:
        filled = impl::Full<XPU, int64_t>(val, length, ctx);
        break;
      default:
        LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return filled;
}

66
template <typename DType>
67
NDArray Full(DType val, int64_t length, DGLContext ctx) {
68
69
70
71
72
73
74
  NDArray ret;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
    ret = impl::Full<XPU, DType>(val, length, ctx);
  });
  return ret;
}

75
76
77
78
template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
79

80
/**
 * @brief Convert an id array to the requested integer width.
 *
 * @param arr Source id array.
 * @param bits Target width; must be 32 or 64 (fatal otherwise).
 * @return @p arr itself when it already has @p bits, a fresh empty array when
 *         @p arr has no elements, otherwise the result of impl::AsNumBits.
 */
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  CHECK(bits == 32 || bits == 64)
      << "Invalid ID type. Must be int32 or int64, but got int"
      << static_cast<int>(bits) << ".";
  // Already the requested width: hand back the same array.
  if (arr->dtype.bits == bits) return arr;
  // Nothing to convert; only allocate an empty array of the new width.
  if (arr.NumElements() == 0) return NewIdArray(arr->shape[0], arr->ctx, bits);

  IdArray converted;
  ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
    ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
      converted = impl::AsNumBits<XPU, IdType>(arr, bits);
    });
  });
  return converted;
}

/**
 * @brief Concatenate two equally long id arrays end to end.
 *
 * Both inputs must share context, dtype and length. The result has
 * 2 * len elements: the contents of @p lhs followed by those of @p rhs.
 */
IdArray HStack(IdArray lhs, IdArray rhs) {
  CHECK_SAME_CONTEXT(lhs, rhs);
  CHECK_SAME_DTYPE(lhs, rhs);
  CHECK_EQ(lhs->shape[0], rhs->shape[0]);

  IdArray stacked;
  auto device = runtime::DeviceAPI::Get(lhs->ctx);
  const auto& ctx = lhs->ctx;
  ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
    const int64_t len = lhs->shape[0];
    // Byte size of one input, which is also the byte offset of the second half.
    const size_t half_bytes = len * sizeof(IdType);
    stacked = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);
    device->CopyDataFromTo(
        lhs.Ptr<IdType>(), 0, stacked.Ptr<IdType>(), 0, half_bytes, ctx, ctx,
        lhs->dtype);
    device->CopyDataFromTo(
        rhs.Ptr<IdType>(), 0, stacked.Ptr<IdType>(), half_bytes, half_bytes,
        ctx, ctx, lhs->dtype);
  });
  return stacked;
}

114
115
/**
 * @brief Gather rows of @p array at the positions listed in @p index.
 *
 * Dispatch follows the device of @p index: when @p array is not pinned both
 * must share a context; when @p array is pinned, @p index decides which
 * kernel runs.
 */
NDArray IndexSelect(NDArray array, IdArray index) {
  CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension";
  CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
  CHECK_VALID_CONTEXT(array, index);

  NDArray selected;
  ATEN_XPU_SWITCH_CUDA(index->ctx.device_type, XPU, "IndexSelect", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
        selected = impl::IndexSelect<XPU, DType, IdType>(array, index);
      });
    });
  });
  return selected;
}

131
template <typename ValueType>
132
ValueType IndexSelect(NDArray array, int64_t index) {
133
  CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
134
  CHECK(index >= 0 && index < array.NumElements())
135
      << "Index " << index << " is out of bound.";
136
  ValueType ret = 0;
137
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
138
139
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ret = impl::IndexSelect<XPU, DType>(array, index);
140
141
142
143
    });
  });
  return ret;
}
144
145
146
147
148
149
150
151
152
153
template int32_t IndexSelect<int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<uint64_t>(NDArray array, int64_t index);
template float IndexSelect<float>(NDArray array, int64_t index);
template double IndexSelect<double>(NDArray array, int64_t index);

/**
 * @brief Copy the contiguous slice [start, end) of a 1D array into a new
 *        array on the same context with the same dtype.
 *
 * Any empty slice (start == end) with 0 <= start <= NumElements() is
 * accepted. That includes a slice at the very end of the array and a slice
 * of an empty array; both previously tripped the start-bound check.
 *
 * @param array 1D source array.
 * @param start First element to copy (inclusive).
 * @param end One past the last element to copy (exclusive).
 * @return A new array holding array[start:end].
 */
NDArray IndexSelect(NDArray array, int64_t start, int64_t end) {
  CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
  CHECK(start >= 0 && start <= array.NumElements())
      << "Index " << start << " is out of bound.";
  CHECK(end >= 0 && end <= array.NumElements())
      << "Index " << end << " is out of bound.";
  CHECK_LE(start, end);
  const int64_t len = end - start;
  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  // Nothing to copy. Skipping the device call also avoids handing a possibly
  // null data pointer of an empty array to a zero-byte copy.
  if (len == 0) return ret;
  auto device = runtime::DeviceAPI::Get(array->ctx);
  ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
    device->CopyDataFromTo(
        array->data, start * sizeof(DType), ret->data, 0, len * sizeof(DType),
        array->ctx, ret->ctx, array->dtype);
  });
  return ret;
}
168

169
170
/**
 * @brief Dispatch to impl::Scatter for the device, value dtype and index
 *        dtype of the inputs.
 */
NDArray Scatter(NDArray array, IdArray indices) {
  NDArray scattered;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Scatter", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(indices->dtype, IdType, {
        scattered = impl::Scatter<XPU, DType, IdType>(array, indices);
      });
    });
  });
  return scattered;
}

181
182
183
184
185
/**
 * @brief In-place scatter of @p value into @p out at positions @p index
 *        (via impl::Scatter_).
 *
 * @p value and @p out must share a dtype, and all three arrays must share a
 * context. @p value and @p index must have equal length. An empty @p index
 * is a no-op.
 */
void Scatter_(IdArray index, NDArray value, NDArray out) {
  CHECK_SAME_DTYPE(value, out);
  CHECK_SAME_CONTEXT(index, value);
  CHECK_SAME_CONTEXT(index, out);
  CHECK_EQ(value->shape[0], index->shape[0]);
  if (index->shape[0] != 0) {
    ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, "Scatter_", {
      ATEN_DTYPE_SWITCH(value->dtype, DType, "values", {
        ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
          impl::Scatter_<XPU, DType, IdType>(index, value, out);
        });
      });
    });
  }
}

196
197
/**
 * @brief Dispatch to impl::Repeat using the device and dtype of @p array and
 *        the id type of @p repeats.
 */
NDArray Repeat(NDArray array, IdArray repeats) {
  NDArray repeated;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Repeat", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(repeats->dtype, IdType, {
        repeated = impl::Repeat<XPU, DType, IdType>(array, repeats);
      });
    });
  });
  return repeated;
}

208
209
/**
 * @brief Dispatch to impl::Relabel_ for a list of id arrays.
 *
 * The device and id type are taken from arrays[0]. NOTE(review): the other
 * arrays are presumably expected to match them; this is not checked here.
 *
 * @param arrays Non-empty list of id arrays (fatal if empty).
 * @return The array produced by impl::Relabel_.
 */
IdArray Relabel_(const std::vector<IdArray>& arrays) {
  // arrays[0] drives dispatch below; reading it from an empty vector is UB.
  CHECK(!arrays.empty()) << "Relabel_ requires at least one input array.";
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(arrays[0]->ctx.device_type, XPU, "Relabel_", {
    ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
      ret = impl::Relabel_<XPU, IdType>(arrays);
    });
  });
  return ret;
}

218
219
220
221
222
223
224
225
226
227
/**
 * @brief Concatenate 1D arrays end to end into a single new array.
 *
 * All inputs must share dtype and context with arrays[0]. Each input's
 * shape[0] elements are copied in order into a freshly allocated array on
 * that context.
 *
 * @param arrays Non-empty list of arrays (fatal if empty).
 * @return The concatenated array.
 */
NDArray Concat(const std::vector<IdArray>& arrays) {
  // arrays[0] supplies the dtype/context below; reject an empty list instead
  // of reading past the end of the vector.
  CHECK(!arrays.empty()) << "Concat requires at least one input array.";

  int64_t len = 0;
  for (size_t i = 0; i < arrays.size(); ++i) {
    len += arrays[i]->shape[0];
    CHECK_SAME_DTYPE(arrays[0], arrays[i]);
    CHECK_SAME_CONTEXT(arrays[0], arrays[i]);
  }

  NDArray ret_arr = NDArray::Empty({len}, arrays[0]->dtype, arrays[0]->ctx);

  auto device = runtime::DeviceAPI::Get(arrays[0]->ctx);
  // Every input has the same dtype (checked above), so dispatch on it once
  // instead of once per input.
  ATEN_DTYPE_SWITCH(arrays[0]->dtype, DType, "array", {
    size_t offset = 0;  // write position inside ret_arr, in bytes
    for (size_t i = 0; i < arrays.size(); ++i) {
      const size_t nbytes = arrays[i]->shape[0] * sizeof(DType);
      device->CopyDataFromTo(
          static_cast<DType*>(arrays[i]->data), 0,
          static_cast<DType*>(ret_arr->data), offset, nbytes, arrays[i]->ctx,
          ret_arr->ctx, arrays[i]->dtype);
      offset += nbytes;
    }
  });

  return ret_arr;
}

246
template <typename ValueType>
247
248
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
  std::tuple<NDArray, IdArray, IdArray> ret;
249
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Pack", {
250
251
252
253
254
255
256
257
258
    ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
      ret = impl::Pack<XPU, DType>(array, static_cast<DType>(pad_value));
    });
  });
  return ret;
}

template std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);
259
260
261
262
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(
    NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(
    NDArray, uint64_t);
263
264
265
266
267
template std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);

/**
 * @brief Dispatch to impl::ConcatSlices using the device/dtype of @p array
 *        and the id type of @p lengths.
 */
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
  std::pair<NDArray, IdArray> concatenated;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "ConcatSlices", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
      ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, {
        concatenated = impl::ConcatSlices<XPU, DType, IdType>(array, lengths);
      });
    });
  });
  return concatenated;
}

278
279
280
281
282
283
284
285
286
287
/**
 * @brief Dispatch to impl::CumSum for the device and id type of @p array.
 * @param prepend_zero Forwarded unchanged to the implementation.
 */
IdArray CumSum(IdArray array, bool prepend_zero) {
  IdArray prefix_sum;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "CumSum", {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      prefix_sum = impl::CumSum<XPU, IdType>(array, prepend_zero);
    });
  });
  return prefix_sum;
}

288
289
290
/**
 * @brief Dispatch to impl::NonZero. Only integer (id-typed) arrays are
 *        supported, because dispatch uses the id-type switch.
 */
IdArray NonZero(NDArray array) {
  IdArray positions;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "NonZero", {
    ATEN_ID_TYPE_SWITCH(array->dtype, DType, {
      positions = impl::NonZero<XPU, DType>(array);
    });
  });
  return positions;
}

297
/**
 * @brief Dispatch to impl::Sort for the device and id type of @p array.
 *
 * An empty input short-circuits: it is returned unchanged, paired with a new
 * empty 64-bit index array on the same context.
 */
std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
  if (array.NumElements() == 0) {
    return std::make_pair(array, NewIdArray(0, array->ctx, 64));
  }

  std::pair<IdArray, IdArray> sorted;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      sorted = impl::Sort<XPU, IdType>(array, num_bits);
    });
  });
  return sorted;
}

311
312
/**
 * @brief Render a short, numpy-like description of @p array.
 *
 * The array is copied to CPU. At most the first 10 elements are printed,
 * followed by "..." when more exist, then the dtype and context of the
 * original array.
 */
std::string ToDebugString(NDArray array) {
  NDArray host = array.CopyTo(DGLContext{kDGLCPU, 0});
  const int64_t total = host.NumElements();
  const int64_t shown = std::min<int64_t>(total, 10);

  std::ostringstream out;
  out << "array([";
  ATEN_DTYPE_SWITCH(host->dtype, DType, "array", {
    const DType* values = host.Ptr<DType>();
    for (int64_t k = 0; k < shown; ++k) out << values[k] << ", ";
  });
  if (total > 10) out << "...";
  out << "], dtype=" << array->dtype << ", ctx=" << array->ctx << ")";
  return out.str();
}

325
326
327
///////////////////////// CSR routines //////////////////////////

/**
 * @brief Scalar query: dispatch to impl::CSRIsNonZero for a single
 *        (row, col).
 *
 * Both indices are bounds-checked against the matrix shape first.
 */
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
  bool is_nonzero = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsNonZero", {
    is_nonzero = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return is_nonzero;
}

/**
 * @brief Batched query: dispatch to impl::CSRIsNonZero for arrays of rows
 *        and cols.
 *
 * @p row and @p col must match the dtype of csr.indices and share one
 * context. Dispatch may use UVA depending on the context of @p row.
 */
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  CHECK_SAME_DTYPE(csr.indices, row);
  CHECK_SAME_DTYPE(csr.indices, col);
  CHECK_SAME_CONTEXT(row, col);
  NDArray flags;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, "CSRIsNonZero", {
    flags = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return flags;
}

/** @brief Dispatch to impl::CSRHasDuplicate for the device/id type of @p csr. */
bool CSRHasDuplicate(CSRMatrix csr) {
  bool has_duplicate = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRHasDuplicate", {
    has_duplicate = impl::CSRHasDuplicate<XPU, IdType>(csr);
  });
  return has_duplicate;
}

/**
 * @brief Scalar query: dispatch to impl::CSRGetRowNNZ for one row, after
 *        bounds-checking @p row.
 */
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  int64_t nnz = 0;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowNNZ", {
    nnz = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return nnz;
}

/**
 * @brief Batched query: dispatch to impl::CSRGetRowNNZ for an array of rows.
 *
 * @p row must have the same dtype as csr.indices.
 */
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
  CHECK_SAME_DTYPE(csr.indices, row);
  NDArray nnz_per_row;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, "CSRGetRowNNZ", {
    nnz_per_row = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return nnz_per_row;
}

/**
 * @brief Dispatch to impl::CSRGetRowColumnIndices for one row, after
 *        bounds-checking @p row.
 */
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  NDArray column_ids;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowColumnIndices", {
    column_ids = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
  });
  return column_ids;
}

/**
 * @brief Dispatch to impl::CSRGetRowData for one row, after bounds-checking
 *        @p row.
 */
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  NDArray row_data;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowData", {
    row_data = impl::CSRGetRowData<XPU, IdType>(csr, row);
  });
  return row_data;
}

392
/**
 * @brief Dispatch to impl::CSRIsSorted.
 *
 * A matrix with at most one stored column index is reported as sorted
 * without dispatching.
 */
bool CSRIsSorted(CSRMatrix csr) {
  bool sorted = true;
  if (csr.indices->shape[0] > 1) {
    ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsSorted", {
      sorted = impl::CSRIsSorted<XPU, IdType>(csr);
    });
  }
  return sorted;
}

401
402
/**
 * @brief Dispatch to impl::CSRGetData for arrays of (rows, cols).
 *
 * @p rows and @p cols must match the dtype of csr.indices and share one
 * context.
 */
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  NDArray entry_data;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
    entry_data = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
  });
  return entry_data;
}

/**
 * @brief Weighted variant: dispatch to impl::CSRGetData with @p weights and
 *        a @p filler value of type DType.
 *
 * @p weights must share the context of @p rows and @p cols.
 */
template <typename DType>
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  CHECK_SAME_CONTEXT(rows, weights);
  NDArray entry_data;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
    entry_data =
        impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
  });
  return entry_data;
}

// Explicit instantiations for floating-point weights.
template NDArray CSRGetData<float>(CSRMatrix, NDArray, NDArray, NDArray, float);
template NDArray CSRGetData<double>(
    CSRMatrix, NDArray, NDArray, NDArray, double);

432
433
/**
 * @brief Dispatch to impl::CSRGetDataAndIndices for arrays of (rows, cols).
 *
 * @p rows and @p cols must match the dtype of csr.indices and share one
 * context.
 */
std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  std::vector<NDArray> outputs;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetDataAndIndices", {
    outputs = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
  });
  return outputs;
}

/**
 * @brief Dispatch to impl::CSRTranspose using the device and id type of
 *        csr.indptr.
 */
CSRMatrix CSRTranspose(CSRMatrix csr) {
  CSRMatrix transposed;
  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRTranspose", {
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
      transposed = impl::CSRTranspose<XPU, IdType>(csr);
    });
  });
  return transposed;
}

/**
 * @brief Convert a CSR matrix to COO.
 *
 * @param data_as_order When true, dispatch to impl::CSRToCOODataAsOrder;
 *        otherwise to impl::CSRToCOO.
 */
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
  COOMatrix coo;
  if (!data_as_order) {
    ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOO", {
      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
        coo = impl::CSRToCOO<XPU, IdType>(csr);
      });
    });
  } else {
    ATEN_XPU_SWITCH_CUDA(
        csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
          ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
            coo = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
          });
        });
  }
  return coo;
}

/**
 * @brief Dispatch to impl::CSRSliceRows for the row range [start, end).
 *
 * @p start must lie in [0, num_rows), @p end in [0, num_rows], and
 * end >= start.
 */
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  CHECK(start >= 0 && start < csr.num_rows) << "Invalid start index: " << start;
  CHECK(end >= 0 && end <= csr.num_rows) << "Invalid end index: " << end;
  CHECK_GE(end, start);
  CSRMatrix sliced;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRSliceRows", {
    sliced = impl::CSRSliceRows<XPU, IdType>(csr, start, end);
  });
  return sliced;
}

/**
 * @brief Dispatch to impl::CSRSliceRows for an explicit list of rows.
 *
 * @p rows must have the same dtype as csr.indices.
 */
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CSRMatrix sliced;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRSliceRows", {
    sliced = impl::CSRSliceRows<XPU, IdType>(csr, rows);
  });
  return sliced;
}

/**
 * @brief Dispatch to impl::CSRSliceMatrix for the given rows and cols.
 *
 * @p rows and @p cols must match the dtype of csr.indices and share one
 * context.
 */
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  CSRMatrix submatrix;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRSliceMatrix", {
    submatrix = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
  });
  return submatrix;
}

504
/**
 * @brief Dispatch to impl::CSRSort_ on @p csr in place; does nothing when
 *        csr->sorted is already set.
 */
void CSRSort_(CSRMatrix* csr) {
  if (!csr->sorted) {
    ATEN_CSR_SWITCH_CUDA(*csr, XPU, IdType, "CSRSort_", {
      impl::CSRSort_<XPU, IdType>(csr);
    });
  }
}

510
std::pair<CSRMatrix, NDArray> CSRSortByTag(
511
    const CSRMatrix& csr, IdArray tag, int64_t num_tags) {
512
  CHECK_EQ(csr.indices->shape[0], tag->shape[0])
513
514
      << "The length of the tag array should be equal to the number of "
         "non-zero data.";
515
516
517
518
519
520
521
522
523
524
525
  CHECK_SAME_CONTEXT(csr.indices, tag);
  CHECK_INT(tag, "tag");
  std::pair<CSRMatrix, NDArray> ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSortByTag", {
    ATEN_ID_TYPE_SWITCH(tag->dtype, TagType, {
      ret = impl::CSRSortByTag<XPU, IdType, TagType>(csr, tag, num_tags);
    });
  });
  return ret;
}

526
527
/** @brief Dispatch to impl::CSRReorder with the given row/col id mappings. */
CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  CSRMatrix reordered;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRReorder", {
    reordered = impl::CSRReorder<XPU, IdType>(csr, new_row_ids, new_col_ids);
  });
  return reordered;
}

535
536
/** @brief Dispatch to impl::CSRRemove with the given @p entries array. */
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
  CSRMatrix pruned;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", {
    pruned = impl::CSRRemove<XPU, IdType>(csr, entries);
  });
  return pruned;
}

543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
/**
 * @brief Dispatch to impl::CSRLaborSampling.
 *
 * The floating-point type comes from @p prob; when @p prob is a null array,
 * float is used.
 */
std::pair<COOMatrix, FloatArray> CSRLaborSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, IdArray NIDs) {
  std::pair<COOMatrix, FloatArray> sampled;
  ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRLaborSampling", {
    DGLDataType prob_dtype = DGLDataTypeTraits<float>::dtype;
    if (!IsNullArray(prob)) prob_dtype = prob->dtype;
    ATEN_FLOAT_TYPE_SWITCH(prob_dtype, FloatType, "probability", {
      sampled = impl::CSRLaborSampling<XPU, IdType, FloatType>(
          mat, rows, num_samples, prob, importance_sampling, random_seed, NIDs);
    });
  });
  return sampled;
}

559
/**
 * @brief Row-wise sampling on a CSR matrix.
 *
 * A null @p prob_or_mask dispatches to impl::CSRRowWiseSamplingUniform.
 * Otherwise dispatch goes to impl::CSRRowWiseSampling, typed by
 * prob_or_mask's dtype (float, int8 or uint8). 8-bit masks are rejected on
 * CUDA.
 */
COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix sampled;
  if (!IsNullArray(prob_or_mask)) {
    // A pinned prob_or_mask combined with rows on GPU is a valid pairing.
    CHECK_VALID_CONTEXT(prob_or_mask, rows);
    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
      CHECK(!(prob_or_mask->dtype.bits == 8 && XPU == kDGLCUDA))
          << "GPU sampling with masks is currently not supported yet.";
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, FloatType, "probability or mask", {
            sampled = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    });
  } else {
    ATEN_CSR_SWITCH_CUDA_UVA(
        mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
          sampled = impl::CSRRowWiseSamplingUniform<XPU, IdType>(
              mat, rows, num_samples, replace);
        });
  }
  return sampled;
}

585
/**
 * @brief Row-wise sampling with per-edge-type sample counts.
 *
 * If every entry of @p prob_or_mask is a null array, dispatch goes to
 * impl::CSRRowWisePerEtypeSamplingUniform. Otherwise dispatch goes to
 * impl::CSRRowWisePerEtypeSampling, typed by the dtype of the first non-null
 * entry.
 *
 * @param prob_or_mask Per-etype probability/mask arrays; must be non-empty.
 *        Entries may be null arrays.
 */
COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  COOMatrix ret;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
    // Locate the first entry that actually carries probabilities or a mask.
    const auto first_valid = std::find_if_not(
        prob_or_mask.begin(), prob_or_mask.end(), IsNullArray);
    if (first_valid == prob_or_mask.end()) {
      ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace,
          rowwise_etype_sorted);
    } else {
      // Dispatch on a non-null entry rather than prob_or_mask[0]: a null
      // placeholder's dtype need not match the real probability arrays.
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          (*first_valid)->dtype, DType, "probability or mask", {
            ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace,
                rowwise_etype_sorted);
          });
    }
  });
  return ret;
}

609
/**
 * @brief Dispatch to impl::CSRRowWiseTopk, typed by the dtype of @p weight.
 */
COOMatrix CSRRowWiseTopk(
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
  COOMatrix topk;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseTopk", {
    ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
      topk =
          impl::CSRRowWiseTopk<XPU, IdType, DType>(mat, rows, k, weight, ascending);
    });
  });
  return topk;
}

621
/**
 * @brief Dispatch to impl::CSRRowWiseSamplingBiased, typed by the
 *        floating-point dtype of @p bias.
 */
COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
  COOMatrix sampled;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSamplingBiased", {
    ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, "bias", {
      sampled = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
          mat, rows, num_samples, tag_offset, bias, replace);
    });
  });
  return sampled;
}

634
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
635
636
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
637
638
639
640
641
642
643
644
645
646
  CHECK_GT(num_samples, 0) << "Number of samples must be positive";
  CHECK_GT(num_trials, 0) << "Number of sampling trials must be positive";
  std::pair<IdArray, IdArray> result;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGlobalUniformNegativeSampling", {
    result = impl::CSRGlobalUniformNegativeSampling<XPU, IdType>(
        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
  });
  return result;
}

647
648
/**
 * @brief Dispatch to impl::UnionCsr over two or more CSR matrices.
 *
 * Every matrix must match csrs[0] in shape, and in the context and dtype of
 * its indptr.
 */
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  CHECK_GT(csrs.size(), 1)
      << "UnionCsr creates a union of multiple CSRMatrixes";
  // Every matrix is validated against the first one.
  for (size_t i = 1; i < csrs.size(); ++i) {
    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows)
        << "UnionCsr requires both CSRMatrix have same number of rows";
    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols)
        << "UnionCsr requires both CSRMatrix have same number of cols";
    CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
  }

  CSRMatrix merged;
  ATEN_CSR_SWITCH(csrs[0], XPU, IdType, "UnionCsr", {
    merged = impl::UnionCsr<XPU, IdType>(csrs);
  });
  return merged;
}

667
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr) {
668
669
670
671
672
673
674
675
676
  std::tuple<CSRMatrix, IdArray, IdArray> ret;

  CSRMatrix sorted_csr = (CSRIsSorted(csr)) ? csr : CSRSort(csr);
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRToSimple", {
    ret = impl::CSRToSimple<XPU, IdType>(sorted_csr);
  });
  return ret;
}

677
678
///////////////////////// COO routines //////////////////////////

679
680
/** @brief Scalar query: dispatch to impl::COOIsNonZero for one (row, col). */
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
  bool is_nonzero = false;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
    is_nonzero = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
  });
  return is_nonzero;
}

/**
 * @brief Batched query: dispatch to impl::COOIsNonZero for arrays of rows
 *        and cols.
 */
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
  NDArray flags;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
    flags = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
  });
  return flags;
}

695
696
/** @brief Dispatch to impl::COOHasDuplicate for the id type of @p coo. */
bool COOHasDuplicate(COOMatrix coo) {
  bool has_duplicate = false;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOHasDuplicate", {
    has_duplicate = impl::COOHasDuplicate<XPU, IdType>(coo);
  });
  return has_duplicate;
}

703
704
/** @brief Scalar query: dispatch to impl::COOGetRowNNZ for a single row. */
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  int64_t nnz = 0;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOGetRowNNZ", {
    nnz = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
  });
  return nnz;
}

/** @brief Batched query: dispatch to impl::COOGetRowNNZ for an array of rows. */
NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
  NDArray nnz_per_row;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOGetRowNNZ", {
    nnz_per_row = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
  });
  return nnz_per_row;
}

719
720
/** @brief Dispatch to impl::COOGetRowDataAndIndices for a single row. */
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  std::pair<NDArray, NDArray> row_entries;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowDataAndIndices", {
    row_entries = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
  });
  return row_entries;
}

/** @brief Dispatch to impl::COOGetDataAndIndices for arrays of (rows, cols). */
std::vector<NDArray> COOGetDataAndIndices(
    COOMatrix coo, NDArray rows, NDArray cols) {
  std::vector<NDArray> outputs;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetDataAndIndices", {
    outputs = impl::COOGetDataAndIndices<XPU, IdType>(coo, rows, cols);
  });
  return outputs;
}

737
738
739
740
741
742
743
744
/** @brief Dispatch to impl::COOGetData for arrays of (rows, cols). */
NDArray COOGetData(COOMatrix coo, NDArray rows, NDArray cols) {
  NDArray entry_data;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetData", {
    entry_data = impl::COOGetData<XPU, IdType>(coo, rows, cols);
  });
  return entry_data;
}

745
/**
 * @brief Transpose a COO matrix without touching any element data.
 *
 * The dimensions are swapped and the row/col index arrays exchange roles;
 * the data array is passed through unchanged.
 */
COOMatrix COOTranspose(COOMatrix coo) {
  const auto transposed_rows = coo.num_cols;
  const auto transposed_cols = coo.num_rows;
  return COOMatrix(transposed_rows, transposed_cols, coo.col, coo.row, coo.data);
}

749
750
/**
 * @brief Dispatch to impl::COOToCSR using the device and id type of coo.row.
 */
CSRMatrix COOToCSR(COOMatrix coo) {
  CSRMatrix csr;
  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "COOToCSR", {
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
      csr = impl::COOToCSR<XPU, IdType>(coo);
    });
  });
  return csr;
}

758
759
/** @brief Dispatch to impl::COOSliceRows for the row range [start, end). */
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
  COOMatrix sliced;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
    sliced = impl::COOSliceRows<XPU, IdType>(coo, start, end);
  });
  return sliced;
}

/** @brief Dispatch to impl::COOSliceRows for an explicit list of rows. */
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
  COOMatrix sliced;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
    sliced = impl::COOSliceRows<XPU, IdType>(coo, rows);
  });
  return sliced;
}

/** @brief Dispatch to impl::COOSliceMatrix for the given rows and cols. */
COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
  COOMatrix submatrix;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceMatrix", {
    submatrix = impl::COOSliceMatrix<XPU, IdType>(coo, rows, cols);
  });
  return submatrix;
}

782
void COOSort_(COOMatrix* mat, bool sort_column) {
783
  if ((mat->row_sorted && !sort_column) || mat->col_sorted) return;
784
785
786
  ATEN_XPU_SWITCH_CUDA(mat->row->ctx.device_type, XPU, "COOSort_", {
    ATEN_ID_TYPE_SWITCH(mat->row->dtype, IdType, {
      impl::COOSort_<XPU, IdType>(mat, sort_column);
787
    });
788
  });
789
790
791
}

/**
 * @brief Report the sortedness of a COO matrix.
 *
 * @return The pair computed by impl::COOIsSorted; a matrix with at most one
 *         entry short-circuits to {true, true}.
 */
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  // Zero or one entry is trivially sorted in every sense.
  const bool trivially_sorted = coo.row->shape[0] <= 1;
  if (trivially_sorted) {
    return std::make_pair(true, true);
  }
  std::pair<bool, bool> sorted_flags;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOIsSorted", {
    sorted_flags = impl::COOIsSorted<XPU, IdType>(coo);
  });
  return sorted_flags;
}

800
801
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
802
803
804
805
806
807
808
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
    ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
  });
  return ret;
}

809
810
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
  COOMatrix ret;
811
  ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", {
812
813
814
815
816
    ret = impl::COORemove<XPU, IdType>(coo, entries);
  });
  return ret;
}

817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
std::pair<COOMatrix, FloatArray> COOLaborSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, IdArray NIDs) {
  std::pair<COOMatrix, FloatArray> ret;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COOLaborSampling", {
    const auto dtype = IsNullArray(prob)
                           ? DGLDataTypeTraits<float>::dtype
                           : prob->dtype;
    ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "probability", {
      ret = impl::COOLaborSampling<XPU, IdType, FloatType>(
          mat, rows, num_samples, prob, importance_sampling, random_seed, NIDs);
    });
  });
  return ret;
}

833
COOMatrix COORowWiseSampling(
834
835
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
836
  COOMatrix ret;
837
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
838
    if (IsNullArray(prob_or_mask)) {
839
840
      ret = impl::COORowWiseSamplingUniform<XPU, IdType>(
          mat, rows, num_samples, replace);
841
    } else {
842
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
843
          prob_or_mask->dtype, DType, "probability or mask", {
844
845
846
            ret = impl::COORowWiseSampling<XPU, IdType, DType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
847
848
849
850
851
    }
  });
  return ret;
}

852
COOMatrix COORowWisePerEtypeSampling(
853
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
854
855
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
856
  COOMatrix ret;
857
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
858
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
859
    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
860
      ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
861
          mat, rows, eid2etype_offset, num_samples, replace);
862
    } else {
863
864
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
865
866
867
868
            ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask,
                replace);
          });
869
870
871
872
873
    }
  });
  return ret;
}

874
875
876
COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
  COOMatrix ret;
877
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseTopk", {
878
879
    ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
      ret = impl::COORowWiseTopk<XPU, IdType, DType>(
880
881
          mat, rows, k, weight, ascending);
    });
882
883
884
885
  });
  return ret;
}

886
887
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
  std::pair<COOMatrix, IdArray> ret;
888
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOCoalesce", {
889
890
891
892
893
    ret = impl::COOCoalesce<XPU, IdType>(coo);
  });
  return ret;
}

894
895
896
897
898
899
900
901
902
903
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
  COOMatrix ret;
  ATEN_XPU_SWITCH_CUDA(coos[0].row->ctx.device_type, XPU, "DisjointUnionCoo", {
    ATEN_ID_TYPE_SWITCH(coos[0].row->dtype, IdType, {
      ret = impl::DisjointUnionCoo<XPU, IdType>(coos);
    });
  });
  return ret;
}

904
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
905
906
907
908
909
910
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOLineGraph", {
    ret = impl::COOLineGraph<XPU, IdType>(coo, backtracking);
  });
  return ret;
}
911
912
913

/**
 * @brief Union several COO matrices of identical shape into one matrix.
 *
 * Row and column arrays are concatenated in input order. If any input has a
 * data array, the result gets one as well, built per input as follows:
 *   - an input with data contributes that data shifted by the number of
 *     entries in all preceding inputs;
 *   - an input without data contributes a consecutive Range starting at
 *     that same offset.
 * The result is flagged as neither row- nor column-sorted.
 *
 * @param coos At least two matrices sharing num_rows, num_cols, device
 *        context and index dtype.
 */
COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  CHECK_GT(coos.size(), 1)
      << "UnionCoo creates a union of multiple COOMatrixes";

  // sanity check
  for (size_t i = 1, n = coos.size(); i < n; ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows)
        << "UnionCoo requires both COOMatrix have same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols)
        << "UnionCoo requires both COOMatrix have same number of cols";
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
  }

  // The number of inputs is expected to be small, so plain gathering loops
  // are sufficient.
  std::vector<IdArray> rows_to_join;
  std::vector<IdArray> cols_to_join;
  rows_to_join.reserve(coos.size());
  cols_to_join.reserve(coos.size());
  for (const COOMatrix& part : coos) {
    rows_to_join.push_back(part.row);
    cols_to_join.push_back(part.col);
  }
  const bool any_has_data =
      std::any_of(coos.begin(), coos.end(), [](const COOMatrix& part) {
        return COOHasData(part);
      });

  const IdArray joined_row = Concat(rows_to_join);
  const IdArray joined_col = Concat(cols_to_join);
  IdArray joined_data = NullArray();

  if (any_has_data) {
    std::vector<IdArray> eid_chunks;
    eid_chunks.reserve(coos.size());
    int64_t offset = 0;  // entries contributed by the inputs seen so far
    for (size_t k = 0; k < coos.size(); ++k) {
      const COOMatrix& part = coos[k];
      const int64_t part_len = part.row->shape[0];
      if (!COOHasData(part)) {
        eid_chunks.push_back(Range(
            offset, offset + part_len, part.row->dtype.bits, part.row->ctx));
      } else if (k == 0) {
        // The first input needs no shift; use its data array as-is.
        eid_chunks.push_back(part.data);
      } else {
        eid_chunks.push_back(part.data + offset);
      }
      offset += part_len;
    }
    joined_data = Concat(eid_chunks);
  }

  return COOMatrix(
      coos[0].num_rows, coos[0].num_cols, joined_row, joined_col, joined_data,
      false, false);
}

966
std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo) {
967
968
  // coo column sorted
  const COOMatrix sorted_coo = COOSort(coo, true);
969
970
971
972
973
974
975
976
977
  const IdArray eids_shuffled =
      COOHasData(sorted_coo)
          ? sorted_coo.data
          : Range(
                0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits,
                sorted_coo.row->ctx);
  const auto& coalesced_result = COOCoalesce(sorted_coo);
  const COOMatrix& coalesced_adj = coalesced_result.first;
  const IdArray& count = coalesced_result.second;
978

979
  /**
980
981
   * eids_shuffled actually already contains the mapping from old edge space to
   * the new one:
982
   *
983
984
985
986
987
988
   * * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced
   * into new edge #0.
   * * eids_shuffled[count[0]:count[0] + count[1]] indicates those that
   * coalesced into new edge #1.
   * * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]]
   * indicates those that coalesced into new edge #2.
989
990
   * * etc.
   *
991
992
993
   * Here, we need to translate eids_shuffled to an array "eids_remapped" such
   * that eids_remapped[i] indicates the new edge ID the old edge #i is mapped
   * to.  The translation can simply be achieved by (in numpy code):
994
995
996
997
998
999
   *
   *     new_eid_for_eids_shuffled = np.range(len(count)).repeat(count)
   *     eids_remapped = np.zeros_like(new_eid_for_eids_shuffled)
   *     eids_remapped[eids_shuffled] = new_eid_for_eids_shuffled
   */
  const IdArray new_eids = Range(
1000
1001
      0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits,
      coalesced_adj.row->ctx);
1002
1003
1004
  const IdArray eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled);

  COOMatrix ret = COOMatrix(
1005
1006
      coalesced_adj.num_rows, coalesced_adj.num_cols, coalesced_adj.row,
      coalesced_adj.col, NullArray(), true, true);
1007
1008
1009
  return std::make_tuple(ret, count, eids_remapped);
}

///////////////////////// Graph traversal routines //////////////////////////
/**
 * @brief BFS from `source`, returning node frontiers.
 *
 * Requires `csr` and `source` to share device type and index dtype, and
 * `csr` to be square. Dispatch follows `source`'s device type and dtype.
 */
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  const auto device = source->ctx.device_type;
  ATEN_XPU_SWITCH(device, XPU, "BFSNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}

/**
 * @brief BFS from `source`, returning edge frontiers.
 *
 * Requires `csr` and `source` to share device type and index dtype, and
 * `csr` to be square. Dispatch follows `source`'s device type and dtype.
 */
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  const auto device = source->ctx.device_type;
  ATEN_XPU_SWITCH(device, XPU, "BFSEdgesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}

/**
 * @brief Topological traversal of `csr`, returning node frontiers.
 *
 * Requires a square CSR. Dispatch follows the device type of `indptr` and
 * the dtype of `indices`.
 */
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  const auto device = csr.indptr->ctx.device_type;
  ATEN_XPU_SWITCH(device, XPU, "TopologicalNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
      frontiers = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
    });
  });
  return frontiers;
}

/**
 * @brief DFS from `source`, returning the visited edges as frontiers.
 *
 * Requires `csr` and `source` to share device type and index dtype, and
 * `csr` to be square. Dispatch follows `source`'s device type and dtype.
 */
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";

  Frontiers frontiers;
  const auto device = source->ctx.device_type;
  ATEN_XPU_SWITCH(device, XPU, "DGLDFSEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      frontiers = impl::DGLDFSEdges<XPU, IdType>(csr, source);
    });
  });
  return frontiers;
}

1072
1073
1074
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels) {
1075
  Frontiers ret;
1076
1077
1078
1079
1080
1081
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
1082
1083
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSLabeledEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
1084
1085
      ret = impl::DGLDFSLabeledEdges<XPU, IdType>(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
1086
1087
1088
1089
1090
    });
  });
  return ret;
}

///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // args[0]: the sparse matrix; returns its `format` field.
      SparseMatrixRef matrix = args[0];
      const auto format = matrix->format;
      *rv = format;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumRows")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // args[0]: the sparse matrix; returns its `num_rows` field.
      SparseMatrixRef matrix = args[0];
      const auto row_count = matrix->num_rows;
      *rv = row_count;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumCols")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // args[0]: the sparse matrix; returns its `num_cols` field.
      SparseMatrixRef matrix = args[0];
      const auto col_count = matrix->num_cols;
      *rv = col_count;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetIndices")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // args[0]: the sparse matrix; args[1]: position in its `indices` list.
      // Returns spmat->indices[i].
      SparseMatrixRef spmat = args[0];
      const int64_t i = args[1];
      // `i` comes straight from the frontend; an out-of-range value would be
      // undefined behavior when indexing, so validate it first.
      CHECK_GE(i, 0) << "Index into sparse matrix indices must be "
                        "non-negative.";
      CHECK_LT(i, static_cast<int64_t>(spmat->indices.size()))
          << "Index into sparse matrix indices is out of range.";
      *rv = spmat->indices[i];
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFlags")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // args[0]: the sparse matrix. Returns its boolean `flags` as a
      // List<Value>, preserving their stored order.
      SparseMatrixRef spmat = args[0];
      List<Value> flag_list;
      for (const bool flag : spmat->flags) {
        flag_list.push_back(Value(MakeValue(flag)));
      }
      *rv = flag_list;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      // args: format, nrows, ncols, list of index arrays, list of flags.
      // Builds a SparseMatrix from them and returns it wrapped in a
      // SparseMatrixRef.
      const int32_t format = args[0];
      const int64_t nrows = args[1];
      const int64_t ncols = args[2];
      const List<Value> indices = args[3];
      const List<Value> flags = args[4];
      // make_shared: one allocation for object + control block, no naked new.
      auto spmat = std::make_shared<SparseMatrix>(
          format, nrows, ncols, ListValueToVector<IdArray>(indices),
          ListValueToVector<bool>(flags));
      *rv = SparseMatrixRef(spmat);
    });

1140
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
1141
1142
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string name = args[0];
1143
#ifndef _WIN32
1144
      *rv = SharedMemory::Exist(name);
1145
#else
1146
      *rv = false;
1147
#endif  // _WIN32
1148
    });
1149

1150
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
1151
1152
1153
1154
1155
1156
1157
1158
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray array = args[0];
      CHECK_EQ(array->dtype.code, kDGLUInt);
      std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
      DGLDataType dtype = array->dtype;
      dtype.code = kDGLInt;
      *rv = array.CreateView(shape, dtype, 0);
    });
1159

1160
1161
}  // namespace aten
}  // namespace dgl
1162

1163
std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array) {
1164
1165
  return os << dgl::aten::ToDebugString(array);
}