// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019-2022 by Contributors
 * @file array/array.cc
 * @brief DGL array utilities implementation
 */
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/shared_mem.h>

#include <sstream>

#include "../c_api_common.h"
#include "arith.h"
#include "array_op.h"
#include "kernel_decl.h"

using namespace dgl::runtime;

namespace dgl {
namespace aten {

IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
  return IdArray::Empty({length}, DGLDataType{kDGLInt, nbits, 1}, ctx);
}

FloatArray NewFloatArray(int64_t length, DGLContext ctx, uint8_t nbits) {
  return FloatArray::Empty({length}, DGLDataType{kDGLFloat, nbits, 1}, ctx);
}

IdArray Clone(IdArray arr) {
  IdArray ret = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits);
  ret.CopyFrom(arr);
  return ret;
}

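// NOTE: The routines in this file share a common dispatch pattern: an
// ATEN_XPU_SWITCH / ATEN_XPU_SWITCH_CUDA macro selects the device backend at
// runtime from the array's context (CPU, or the GPU backend in this hipified
// build), and nested ATEN_*_TYPE_SWITCH macros select the concrete element
// type, so each wrapper forwards to a specialized impl::<Op><XPU, ...> kernel.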
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
    if (nbits == 32) {
      ret = impl::Range<XPU, int32_t>(low, high, ctx);
    } else if (nbits == 64) {
      ret = impl::Range<XPU, int64_t>(low, high, ctx);
    } else {
      LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return ret;
}

IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
    if (nbits == 32) {
      ret = impl::Full<XPU, int32_t>(val, length, ctx);
    } else if (nbits == 64) {
      ret = impl::Full<XPU, int64_t>(val, length, ctx);
    } else {
      LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return ret;
}

template <typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx) {
  NDArray ret;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
    ret = impl::Full<XPU, DType>(val, length, ctx);
  });
  return ret;
}

template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<double>(double val, int64_t length, DGLContext ctx);

IdArray AsNumBits(IdArray arr, uint8_t bits) {
  CHECK(bits == 32 || bits == 64)
      << "Invalid ID type. Must be int32 or int64, but got int"
      << static_cast<int>(bits) << ".";
  if (arr->dtype.bits == bits) return arr;
  if (arr.NumElements() == 0) return NewIdArray(arr->shape[0], arr->ctx, bits);
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(arr->ctx.device_type, XPU, "AsNumBits", {
    ATEN_ID_TYPE_SWITCH(
        arr->dtype, IdType, { ret = impl::AsNumBits<XPU, IdType>(arr, bits); });
  });
  return ret;
}

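// HStack concatenates two ID arrays of equal length living on the same
// device: the result has length 2 * len, with lhs copied into the first half
// and rhs into the second half via DeviceAPI::CopyDataFromTo (offsets are
// given in bytes).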
IdArray HStack(IdArray lhs, IdArray rhs) {
  IdArray ret;
  CHECK_SAME_CONTEXT(lhs, rhs);
  CHECK_SAME_DTYPE(lhs, rhs);
  CHECK_EQ(lhs->shape[0], rhs->shape[0]);
  auto device = runtime::DeviceAPI::Get(lhs->ctx);
  const auto& ctx = lhs->ctx;
  ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
    const int64_t len = lhs->shape[0];
    ret = NewIdArray(2 * len, lhs->ctx, lhs->dtype.bits);
    device->CopyDataFromTo(
        lhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), 0, len * sizeof(IdType), ctx,
        ctx, lhs->dtype);
    device->CopyDataFromTo(
        rhs.Ptr<IdType>(), 0, ret.Ptr<IdType>(), len * sizeof(IdType),
        len * sizeof(IdType), ctx, ctx, lhs->dtype);
  });
  return ret;
}

NDArray IndexSelect(NDArray array, IdArray index) {
  NDArray ret;
  CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension";
  CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
  // if array is not pinned, index has the same context as array
  // if array is pinned, op dispatching depends on the context of index
  CHECK_VALID_CONTEXT(array, index);
  ATEN_XPU_SWITCH_CUDA(index->ctx.device_type, XPU, "IndexSelect", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
        ret = impl::IndexSelect<XPU, DType, IdType>(array, index);
      });
    });
  });
  return ret;
}

template <typename ValueType>
ValueType IndexSelect(NDArray array, int64_t index) {
  CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
  CHECK(index >= 0 && index < array.NumElements())
      << "Index " << index << " is out of bound.";
  ValueType ret = 0;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ret = impl::IndexSelect<XPU, DType>(array, index);
    });
  });
  return ret;
}

template int32_t IndexSelect<int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<uint64_t>(NDArray array, int64_t index);
template float IndexSelect<float>(NDArray array, int64_t index);
template double IndexSelect<double>(NDArray array, int64_t index);

NDArray IndexSelect(NDArray array, int64_t start, int64_t end) {
  CHECK_EQ(array->ndim, 1) << "Only support select values from 1D array.";
  CHECK(start >= 0 && start < array.NumElements())
      << "Index " << start << " is out of bound.";
  CHECK(end >= 0 && end <= array.NumElements())
      << "Index " << end << " is out of bound.";
  CHECK_LE(start, end);
  auto device = runtime::DeviceAPI::Get(array->ctx);
  const int64_t len = end - start;
  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
    device->CopyDataFromTo(
        array->data, start * sizeof(DType), ret->data, 0, len * sizeof(DType),
        array->ctx, ret->ctx, array->dtype);
  });
  return ret;
}

NDArray Scatter(NDArray array, IdArray indices) {
  NDArray ret;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Scatter", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(indices->dtype, IdType, {
        ret = impl::Scatter<XPU, DType, IdType>(array, indices);
      });
    });
  });
  return ret;
}

void Scatter_(IdArray index, NDArray value, NDArray out) {
  CHECK_SAME_DTYPE(value, out);
  CHECK_SAME_CONTEXT(index, value);
  CHECK_SAME_CONTEXT(index, out);
  CHECK_EQ(value->shape[0], index->shape[0]);
  if (index->shape[0] == 0) return;
  ATEN_XPU_SWITCH_CUDA(value->ctx.device_type, XPU, "Scatter_", {
    ATEN_DTYPE_SWITCH(value->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
        impl::Scatter_<XPU, DType, IdType>(index, value, out);
      });
    });
  });
}

NDArray Repeat(NDArray array, IdArray repeats) {
  NDArray ret;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Repeat", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
      ATEN_ID_TYPE_SWITCH(repeats->dtype, IdType, {
        ret = impl::Repeat<XPU, DType, IdType>(array, repeats);
      });
    });
  });
  return ret;
}

IdArray Relabel_(const std::vector<IdArray>& arrays) {
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(arrays[0]->ctx.device_type, XPU, "Relabel_", {
    ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
      ret = impl::Relabel_<XPU, IdType>(arrays);
    });
  });
  return ret;
}

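// Concat flattens a list of 1D arrays with identical dtype and context into a
// single array. Note that CopyDataFromTo takes byte offsets, so the running
// offset below advances by shape[0] * sizeof(DType) per input.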
NDArray Concat(const std::vector<IdArray>& arrays) {
  IdArray ret;

  int64_t len = 0, offset = 0;
  for (size_t i = 0; i < arrays.size(); ++i) {
    len += arrays[i]->shape[0];
    CHECK_SAME_DTYPE(arrays[0], arrays[i]);
    CHECK_SAME_CONTEXT(arrays[0], arrays[i]);
  }

  NDArray ret_arr = NDArray::Empty({len}, arrays[0]->dtype, arrays[0]->ctx);

  auto device = runtime::DeviceAPI::Get(arrays[0]->ctx);
  for (size_t i = 0; i < arrays.size(); ++i) {
    ATEN_DTYPE_SWITCH(arrays[i]->dtype, DType, "array", {
      device->CopyDataFromTo(
          static_cast<DType*>(arrays[i]->data), 0,
          static_cast<DType*>(ret_arr->data), offset,
          arrays[i]->shape[0] * sizeof(DType), arrays[i]->ctx, ret_arr->ctx,
          arrays[i]->dtype);

      offset += arrays[i]->shape[0] * sizeof(DType);
    });
  }

  return ret_arr;
}

template <typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
  std::tuple<NDArray, IdArray, IdArray> ret;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "Pack", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
      ret = impl::Pack<XPU, DType>(array, static_cast<DType>(pad_value));
    });
  });
  return ret;
}

template std::tuple<NDArray, IdArray, IdArray> Pack<int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint32_t>(
    NDArray, uint32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<uint64_t>(
    NDArray, uint64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<double>(NDArray, double);

std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
  std::pair<NDArray, IdArray> ret;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, "ConcatSlices", {
    ATEN_DTYPE_SWITCH(array->dtype, DType, "array", {
      ATEN_ID_TYPE_SWITCH(lengths->dtype, IdType, {
        ret = impl::ConcatSlices<XPU, DType, IdType>(array, lengths);
      });
    });
  });
  return ret;
}

IdArray CumSum(IdArray array, bool prepend_zero) {
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "CumSum", {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      ret = impl::CumSum<XPU, IdType>(array, prepend_zero);
    });
  });
  return ret;
}

IdArray NonZero(NDArray array) {
  IdArray ret;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "NonZero", {
    ATEN_ID_TYPE_SWITCH(
        array->dtype, DType, { ret = impl::NonZero<XPU, DType>(array); });
  });
  return ret;
}

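// Sort returns the sorted array together with the index array that permutes
// the input into sorted order. An empty input is returned as-is with an empty
// int64 index array, skipping device dispatch. num_bits is likely a hint for
// radix-sort style backends on how many significant key bits to consider.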
std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
  if (array.NumElements() == 0) {
    IdArray idx = NewIdArray(0, array->ctx, 64);
    return std::make_pair(array, idx);
  }
  std::pair<IdArray, IdArray> ret;
  ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      ret = impl::Sort<XPU, IdType>(array, num_bits);
    });
  });
  return ret;
}

std::string ToDebugString(NDArray array) {
  std::ostringstream oss;
  NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});
  oss << "array([";
  ATEN_DTYPE_SWITCH(a->dtype, DType, "array", {
    for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) {
      oss << a.Ptr<DType>()[i] << ", ";
    }
  });
  if (a.NumElements() > 10) oss << "...";
  oss << "], dtype=" << array->dtype << ", ctx=" << array->ctx << ")";
  return oss.str();
}

///////////////////////// CSR routines //////////////////////////

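// The CSR helpers below dispatch through ATEN_CSR_SWITCH (CPU),
// ATEN_CSR_SWITCH_CUDA (CPU or GPU), and ATEN_CSR_SWITCH_CUDA_UVA; the UVA
// variant appears to pick the backend from the index array (e.g. rows) so
// that a pinned, host-resident matrix can be accessed from GPU kernels.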
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
  bool ret = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsNonZero", {
    ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return ret;
}

NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  NDArray ret;
  CHECK_SAME_DTYPE(csr.indices, row);
  CHECK_SAME_DTYPE(csr.indices, col);
  CHECK_SAME_CONTEXT(row, col);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, "CSRIsNonZero", {
    ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return ret;
}

bool CSRHasDuplicate(CSRMatrix csr) {
  bool ret = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRHasDuplicate", {
    ret = impl::CSRHasDuplicate<XPU, IdType>(csr);
  });
  return ret;
}

int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  int64_t ret = 0;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowNNZ", {
    ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return ret;
}

NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
  NDArray ret;
  CHECK_SAME_DTYPE(csr.indices, row);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, row, XPU, IdType, "CSRGetRowNNZ", {
    ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return ret;
}

NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  NDArray ret;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowColumnIndices", {
    ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
  });
  return ret;
}

NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
  NDArray ret;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGetRowData", {
    ret = impl::CSRGetRowData<XPU, IdType>(csr, row);
  });
  return ret;
}

bool CSRIsSorted(CSRMatrix csr) {
  if (csr.indices->shape[0] <= 1) return true;
  bool ret = false;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRIsSorted", {
    ret = impl::CSRIsSorted<XPU, IdType>(csr);
  });
  return ret;
}

NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  NDArray ret;
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
    ret = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
  });
  return ret;
}

template <typename DType>
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, DType filler) {
  NDArray ret;
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  CHECK_SAME_CONTEXT(rows, weights);
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetData", {
    ret =
        impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols, weights, filler);
  });
  return ret;
}

runtime::NDArray CSRGetFloatingData(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
    runtime::NDArray weights, double filler) {
  if (weights->dtype.bits == 64) {
    return CSRGetData<double>(csr, rows, cols, weights, filler);
  } else {
    CHECK(weights->dtype.bits == 32)
        << "CSRGetFloatingData only supports 32 or 64 bits floaring number";
    return CSRGetData<float>(csr, rows, cols, weights, filler);
  }
}

template NDArray CSRGetData<float>(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, float filler);
template NDArray CSRGetData<double>(
    CSRMatrix csr, NDArray rows, NDArray cols, NDArray weights, double filler);

std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  std::vector<NDArray> ret;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRGetDataAndIndices", {
    ret = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
  });
  return ret;
}

CSRMatrix CSRTranspose(CSRMatrix csr) {
  CSRMatrix ret;
  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRTranspose", {
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
      ret = impl::CSRTranspose<XPU, IdType>(csr);
    });
  });
  return ret;
}

COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
  COOMatrix ret;
  if (data_as_order) {
    ATEN_XPU_SWITCH_CUDA(
        csr.indptr->ctx.device_type, XPU, "CSRToCOODataAsOrder", {
          ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
            ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
          });
        });
  } else {
    ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "CSRToCOO", {
      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
        ret = impl::CSRToCOO<XPU, IdType>(csr);
      });
    });
  }
  return ret;
}

CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  CHECK(start >= 0 && start < csr.num_rows) << "Invalid start index: " << start;
  CHECK(end >= 0 && end <= csr.num_rows) << "Invalid end index: " << end;
  CHECK_GE(end, start);
  CSRMatrix ret;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRSliceRows", {
    ret = impl::CSRSliceRows<XPU, IdType>(csr, start, end);
  });
  return ret;
}

CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CSRMatrix ret;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRSliceRows", {
    ret = impl::CSRSliceRows<XPU, IdType>(csr, rows);
  });
  return ret;
}

CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(csr.indices, rows);
  CHECK_SAME_DTYPE(csr.indices, cols);
  CHECK_SAME_CONTEXT(rows, cols);
  CSRMatrix ret;
  ATEN_CSR_SWITCH_CUDA_UVA(csr, rows, XPU, IdType, "CSRSliceMatrix", {
    ret = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
  });
  return ret;
}

void CSRSort_(CSRMatrix* csr) {
  if (csr->sorted) return;
  ATEN_CSR_SWITCH_CUDA(
      *csr, XPU, IdType, "CSRSort_", { impl::CSRSort_<XPU, IdType>(csr); });
}

std::pair<CSRMatrix, NDArray> CSRSortByTag(
    const CSRMatrix& csr, IdArray tag, int64_t num_tags) {
  CHECK_EQ(csr.indices->shape[0], tag->shape[0])
      << "The length of the tag array should be equal to the number of "
         "non-zero data.";
  CHECK_SAME_CONTEXT(csr.indices, tag);
  CHECK_INT(tag, "tag");
  std::pair<CSRMatrix, NDArray> ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRSortByTag", {
    ATEN_ID_TYPE_SWITCH(tag->dtype, TagType, {
      ret = impl::CSRSortByTag<XPU, IdType, TagType>(csr, tag, num_tags);
    });
  });
  return ret;
}

CSRMatrix CSRReorder(
    CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRReorder", {
    ret = impl::CSRReorder<XPU, IdType>(csr, new_row_ids, new_col_ids);
  });
  return ret;
}

CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRRemove", {
    ret = impl::CSRRemove<XPU, IdType>(csr, entries);
  });
  return ret;
}

std::pair<COOMatrix, FloatArray> CSRLaborSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, float seed2_contribution,
    IdArray NIDs) {
  std::pair<COOMatrix, FloatArray> ret;
  ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRLaborSampling", {
    const auto dtype =
        IsNullArray(prob) ? DGLDataTypeTraits<float>::dtype : prob->dtype;
    ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "probability", {
      ret = impl::CSRLaborSampling<XPU, IdType, FloatType>(
          mat, rows, num_samples, prob, importance_sampling, random_seed,
          seed2_contribution, NIDs);
    });
  });
  return ret;
}

COOMatrix CSRRowWiseSampling(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix ret;
  if (IsNullArray(prob_or_mask)) {
    ATEN_CSR_SWITCH_CUDA_UVA(
        mat, rows, XPU, IdType, "CSRRowWiseSamplingUniform", {
          ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(
              mat, rows, num_samples, replace);
        });
  } else {
    // prob_or_mask is pinned and rows on GPU is valid
    CHECK_VALID_CONTEXT(prob_or_mask, rows);
    ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRRowWiseSampling", {
      CHECK(!(prob_or_mask->dtype.bits == 8 &&
              (XPU == kDGLCUDA || XPU == kDGLROCM)))
          << "GPU sampling with masks is not supported yet.";
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, FloatType, "probability or mask", {
            ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    });
  }
  return ret;
}

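// CSRRowWiseSamplingFused (below) appears to fuse neighbor sampling with
// seed-node relabeling: seed_mapping and new_seed_nodes are populated when
// map_seed_nodes is true. Explicit instantiations for int32/int64 with both
// flag values follow the template definition.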
template <typename IdType, bool map_seed_nodes>
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
    CSRMatrix mat, IdArray rows, IdArray seed_mapping,
    std::vector<IdType>* new_seed_nodes, int64_t num_samples,
    NDArray prob_or_mask, bool replace) {
  std::pair<CSRMatrix, IdArray> ret;
  if (IsNullArray(prob_or_mask)) {
    ATEN_XPU_SWITCH(
        rows->ctx.device_type, XPU, "CSRRowWiseSamplingUniformFused", {
          ret =
              impl::CSRRowWiseSamplingUniformFused<XPU, IdType, map_seed_nodes>(
                  mat, rows, seed_mapping, new_seed_nodes, num_samples,
                  replace);
        });
  } else {
    CHECK_VALID_CONTEXT(prob_or_mask, rows);
    ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, "CSRRowWiseSamplingFused", {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, FloatType, "probability or mask", {
            ret = impl::CSRRowWiseSamplingFused<
                XPU, IdType, FloatType, map_seed_nodes>(
                mat, rows, seed_mapping, new_seed_nodes, num_samples,
                prob_or_mask, replace);
          });
    });
  }
  return ret;
}

template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);

template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);

template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, true>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);

template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, false>(
    CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);

COOMatrix CSRRowWisePerEtypeSampling(
    CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace,
    bool rowwise_etype_sorted) {
  COOMatrix ret;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWisePerEtypeSampling", {
    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
      ret = impl::CSRRowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace,
          rowwise_etype_sorted);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
            ret = impl::CSRRowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask, replace,
                rowwise_etype_sorted);
          });
    }
  });
  return ret;
}

COOMatrix CSRRowWiseTopk(
    CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
  COOMatrix ret;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseTopk", {
    ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
      ret = impl::CSRRowWiseTopk<XPU, IdType, DType>(
          mat, rows, k, weight, ascending);
    });
  });
  return ret;
}

COOMatrix CSRRowWiseSamplingBiased(
    CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray tag_offset,
    FloatArray bias, bool replace) {
  COOMatrix ret;
  ATEN_CSR_SWITCH(mat, XPU, IdType, "CSRRowWiseSamplingBiased", {
    ATEN_FLOAT_TYPE_SWITCH(bias->dtype, FloatType, "bias", {
      ret = impl::CSRRowWiseSamplingBiased<XPU, IdType, FloatType>(
          mat, rows, num_samples, tag_offset, bias, replace);
    });
  });
  return ret;
}

std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  CHECK_GT(num_samples, 0) << "Number of samples must be positive";
  CHECK_GT(num_trials, 0) << "Number of sampling trials must be positive";
  std::pair<IdArray, IdArray> result;
  ATEN_CSR_SWITCH_CUDA(csr, XPU, IdType, "CSRGlobalUniformNegativeSampling", {
    result = impl::CSRGlobalUniformNegativeSampling<XPU, IdType>(
        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
  });
  return result;
}

CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
  CSRMatrix ret;
  CHECK_GT(csrs.size(), 1)
      << "UnionCsr creates a union of multiple CSRMatrixes";
  // sanity check
  for (size_t i = 1; i < csrs.size(); ++i) {
    CHECK_EQ(csrs[0].num_rows, csrs[i].num_rows)
        << "UnionCsr requires all CSRMatrix to have the same number of rows";
    CHECK_EQ(csrs[0].num_cols, csrs[i].num_cols)
        << "UnionCsr requires all CSRMatrix to have the same number of cols";
    CHECK_SAME_CONTEXT(csrs[0].indptr, csrs[i].indptr);
    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
  }

  ATEN_CSR_SWITCH(csrs[0], XPU, IdType, "UnionCsr", {
    ret = impl::UnionCsr<XPU, IdType>(csrs);
  });
  return ret;
}

std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr) {
  std::tuple<CSRMatrix, IdArray, IdArray> ret;

  CSRMatrix sorted_csr = (CSRIsSorted(csr)) ? csr : CSRSort(csr);
  ATEN_CSR_SWITCH(csr, XPU, IdType, "CSRToSimple", {
    ret = impl::CSRToSimple<XPU, IdType>(sorted_csr);
  });
  return ret;
}

///////////////////////// COO routines //////////////////////////

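// The COO routines mirror the CSR ones above, dispatching through
// ATEN_COO_SWITCH / ATEN_COO_SWITCH_CUDA. COOTranspose is the one exception
// that needs no kernel: it simply swaps the row/col arrays and dimensions.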
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
  bool ret = false;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
    ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
  });
  return ret;
}

NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
  NDArray ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOIsNonZero", {
    ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
  });
  return ret;
}

bool COOHasDuplicate(COOMatrix coo) {
  bool ret = false;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOHasDuplicate", {
    ret = impl::COOHasDuplicate<XPU, IdType>(coo);
  });
  return ret;
}

int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  int64_t ret = 0;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOGetRowNNZ", {
    ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
  });
  return ret;
}

NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
  NDArray ret;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOGetRowNNZ", {
    ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
  });
  return ret;
}

std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  std::pair<NDArray, NDArray> ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetRowDataAndIndices", {
    ret = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
  });
  return ret;
}

std::vector<NDArray> COOGetDataAndIndices(
    COOMatrix coo, NDArray rows, NDArray cols) {
  std::vector<NDArray> ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetDataAndIndices", {
    ret = impl::COOGetDataAndIndices<XPU, IdType>(coo, rows, cols);
  });
  return ret;
}

NDArray COOGetData(COOMatrix coo, NDArray rows, NDArray cols) {
  NDArray ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOGetData", {
    ret = impl::COOGetData<XPU, IdType>(coo, rows, cols);
  });
  return ret;
}

COOMatrix COOTranspose(COOMatrix coo) {
  return COOMatrix(coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data);
}

CSRMatrix COOToCSR(COOMatrix coo) {
  CSRMatrix ret;
  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "COOToCSR", {
    ATEN_ID_TYPE_SWITCH(
        coo.row->dtype, IdType, { ret = impl::COOToCSR<XPU, IdType>(coo); });
  });
  return ret;
}

COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
    ret = impl::COOSliceRows<XPU, IdType>(coo, start, end);
  });
  return ret;
}

COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceRows", {
    ret = impl::COOSliceRows<XPU, IdType>(coo, rows);
  });
  return ret;
}

COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOSliceMatrix", {
    ret = impl::COOSliceMatrix<XPU, IdType>(coo, rows, cols);
  });
  return ret;
}

void COOSort_(COOMatrix* mat, bool sort_column) {
  if ((mat->row_sorted && !sort_column) || mat->col_sorted) return;
  ATEN_XPU_SWITCH_CUDA(mat->row->ctx.device_type, XPU, "COOSort_", {
    ATEN_ID_TYPE_SWITCH(mat->row->dtype, IdType, {
      impl::COOSort_<XPU, IdType>(mat, sort_column);
    });
  });
}

std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
  if (coo.row->shape[0] <= 1) return {true, true};
  std::pair<bool, bool> ret;
  ATEN_COO_SWITCH_CUDA(coo, XPU, IdType, "COOIsSorted", {
    ret = impl::COOIsSorted<XPU, IdType>(coo);
  });
  return ret;
}

COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOReorder", {
    ret = impl::COOReorder<XPU, IdType>(coo, new_row_ids, new_col_ids);
  });
  return ret;
}

COOMatrix COORemove(COOMatrix coo, IdArray entries) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COORemove", {
    ret = impl::COORemove<XPU, IdType>(coo, entries);
  });
  return ret;
}

std::pair<COOMatrix, FloatArray> COOLaborSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob,
    int importance_sampling, IdArray random_seed, float seed2_contribution,
    IdArray NIDs) {
  std::pair<COOMatrix, FloatArray> ret;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COOLaborSampling", {
    const auto dtype =
        IsNullArray(prob) ? DGLDataTypeTraits<float>::dtype : prob->dtype;
    ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "probability", {
      ret = impl::COOLaborSampling<XPU, IdType, FloatType>(
          mat, rows, num_samples, prob, importance_sampling, random_seed,
          seed2_contribution, NIDs);
    });
  });
  return ret;
}

COOMatrix COORowWiseSampling(
    COOMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
    bool replace) {
  COOMatrix ret;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseSampling", {
    if (IsNullArray(prob_or_mask)) {
      ret = impl::COORowWiseSamplingUniform<XPU, IdType>(
          mat, rows, num_samples, replace);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask->dtype, DType, "probability or mask", {
            ret = impl::COORowWiseSampling<XPU, IdType, DType>(
                mat, rows, num_samples, prob_or_mask, replace);
          });
    }
  });
  return ret;
}

COOMatrix COORowWisePerEtypeSampling(
    COOMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
    const std::vector<int64_t>& num_samples,
    const std::vector<NDArray>& prob_or_mask, bool replace) {
  COOMatrix ret;
  CHECK(prob_or_mask.size() > 0) << "probability or mask array is empty";
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWisePerEtypeSampling", {
    if (std::all_of(prob_or_mask.begin(), prob_or_mask.end(), IsNullArray)) {
      ret = impl::COORowWisePerEtypeSamplingUniform<XPU, IdType>(
          mat, rows, eid2etype_offset, num_samples, replace);
    } else {
      ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
          prob_or_mask[0]->dtype, DType, "probability or mask", {
            ret = impl::COORowWisePerEtypeSampling<XPU, IdType, DType>(
                mat, rows, eid2etype_offset, num_samples, prob_or_mask,
                replace);
          });
    }
  });
  return ret;
}

COOMatrix COORowWiseTopk(
    COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
  COOMatrix ret;
  ATEN_COO_SWITCH(mat, XPU, IdType, "COORowWiseTopk", {
    ATEN_DTYPE_SWITCH(weight->dtype, DType, "weight", {
      ret = impl::COORowWiseTopk<XPU, IdType, DType>(
          mat, rows, k, weight, ascending);
    });
  });
  return ret;
}

std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
  std::pair<COOMatrix, IdArray> ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOCoalesce", {
    ret = impl::COOCoalesce<XPU, IdType>(coo);
  });
  return ret;
}

COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
  COOMatrix ret;
  ATEN_XPU_SWITCH_CUDA(coos[0].row->ctx.device_type, XPU, "DisjointUnionCoo", {
    ATEN_ID_TYPE_SWITCH(coos[0].row->dtype, IdType, {
      ret = impl::DisjointUnionCoo<XPU, IdType>(coos);
    });
  });
  return ret;
}

COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, "COOLineGraph", {
    ret = impl::COOLineGraph<XPU, IdType>(coo, backtracking);
  });
  return ret;
}

COOMatrix UnionCoo(const std::vector<COOMatrix>& coos) {
  COOMatrix ret;
  CHECK_GT(coos.size(), 1)
      << "UnionCoo creates a union of multiple COOMatrixes";
  // sanity check
  for (size_t i = 1; i < coos.size(); ++i) {
    CHECK_EQ(coos[0].num_rows, coos[i].num_rows)
        << "UnionCoo requires both COOMatrix have same number of rows";
    CHECK_EQ(coos[0].num_cols, coos[i].num_cols)
        << "UnionCoo requires both COOMatrix have same number of cols";
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
  }

  // we assume the number of coos is not large in common cases
  std::vector<IdArray> coo_row;
  std::vector<IdArray> coo_col;
  bool has_data = false;

  for (size_t i = 0; i < coos.size(); ++i) {
    coo_row.push_back(coos[i].row);
    coo_col.push_back(coos[i].col);
    has_data |= COOHasData(coos[i]);
  }

  IdArray row = Concat(coo_row);
  IdArray col = Concat(coo_col);
  IdArray data = NullArray();

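  // Renumber edge IDs into the union's edge space: each matrix contributes a
  // block of IDs offset by the running edge count, using its own data array
  // when present and a freshly generated Range otherwise.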
  if (has_data) {
    std::vector<IdArray> eid_data;
    eid_data.push_back(
        COOHasData(coos[0]) ? coos[0].data
                            : Range(
                                  0, coos[0].row->shape[0],
                                  coos[0].row->dtype.bits, coos[0].row->ctx));
    int64_t num_edges = coos[0].row->shape[0];
    for (size_t i = 1; i < coos.size(); ++i) {
      eid_data.push_back(
          COOHasData(coos[i])
              ? coos[i].data + num_edges
              : Range(
                    num_edges, num_edges + coos[i].row->shape[0],
                    coos[i].row->dtype.bits, coos[i].row->ctx));
      num_edges += coos[i].row->shape[0];
    }

    data = Concat(eid_data);
  }

  return COOMatrix(
      coos[0].num_rows, coos[0].num_cols, row, col, data, false, false);
}

std::tuple<COOMatrix, IdArray, IdArray> COOToSimple(const COOMatrix& coo) {
  // coo column sorted
  const COOMatrix sorted_coo = COOSort(coo, true);
  const IdArray eids_shuffled =
      COOHasData(sorted_coo)
          ? sorted_coo.data
          : Range(
                0, sorted_coo.row->shape[0], sorted_coo.row->dtype.bits,
                sorted_coo.row->ctx);
  const auto& coalesced_result = COOCoalesce(sorted_coo);
  const COOMatrix& coalesced_adj = coalesced_result.first;
  const IdArray& count = coalesced_result.second;

  /**
   * eids_shuffled actually already contains the mapping from old edge space to
   * the new one:
   *
   * * eids_shuffled[0:count[0]] indicates the original edge IDs that coalesced
   * into new edge #0.
   * * eids_shuffled[count[0]:count[0] + count[1]] indicates those that
   * coalesced into new edge #1.
   * * eids_shuffled[count[0] + count[1]:count[0] + count[1] + count[2]]
   * indicates those that coalesced into new edge #2.
   * * etc.
   *
   * Here, we need to translate eids_shuffled to an array "eids_remapped" such
   * that eids_remapped[i] indicates the new edge ID the old edge #i is mapped
   * to.  The translation can simply be achieved by (in numpy code):
   *
   *     new_eid_for_eids_shuffled = np.range(len(count)).repeat(count)
   *     eids_remapped = np.zeros_like(new_eid_for_eids_shuffled)
   *     eids_remapped[eids_shuffled] = new_eid_for_eids_shuffled
   */
  const IdArray new_eids = Range(
      0, coalesced_adj.row->shape[0], coalesced_adj.row->dtype.bits,
      coalesced_adj.row->ctx);
  const IdArray eids_remapped = Scatter(Repeat(new_eids, count), eids_shuffled);

  COOMatrix ret = COOMatrix(
      coalesced_adj.num_rows, coalesced_adj.num_cols, coalesced_adj.row,
      coalesced_adj.col, NullArray(), true, true);
  return std::make_tuple(ret, count, eids_remapped);
}

///////////////////////// Graph Traverse routines //////////////////////////
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSNodesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::BFSNodesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return ret;
}

Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "BFSEdgesFrontiers", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::BFSEdgesFrontiers<XPU, IdType>(csr, source);
    });
  });
  return ret;
}

Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  Frontiers ret;
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(
      csr.indptr->ctx.device_type, XPU, "TopologicalNodesFrontiers", {
        ATEN_ID_TYPE_SWITCH(csr.indices->dtype, IdType, {
          ret = impl::TopologicalNodesFrontiers<XPU, IdType>(csr);
        });
      });
  return ret;
}

Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::DGLDFSEdges<XPU, IdType>(csr, source);
    });
  });
  return ret;
}

Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels) {
  Frontiers ret;
  CHECK_EQ(csr.indptr->ctx.device_type, source->ctx.device_type)
      << "Graph and source should in the same device context";
  CHECK_EQ(csr.indices->dtype, source->dtype)
      << "Graph and source should in the same dtype";
  CHECK_EQ(csr.num_rows, csr.num_cols)
      << "Graph traversal can only work on square-shaped CSR.";
  ATEN_XPU_SWITCH(source->ctx.device_type, XPU, "DGLDFSLabeledEdges", {
    ATEN_ID_TYPE_SWITCH(source->dtype, IdType, {
      ret = impl::DGLDFSLabeledEdges<XPU, IdType>(
          csr, source, has_reverse_edge, has_nontree_edge, return_labels);
    });
  });
  return ret;
}

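// Generalized sparse-dense kernels. CSRSpMM/COOSpMM and CSRSDDMM/COOSDDMM
// compute broadcasting metadata with CalcBcastOff and then dispatch on the
// graph's ID type and the feature dtype (including 16-bit floats via
// ATEN_FLOAT_TYPE_SWITCH_16BITS) to the SpMMCsr/SpMMCoo and SDDMMCsr/SDDMMCoo
// kernels; the const char* overloads are thin convenience wrappers.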
void CSRSpMM(
    const std::string& op, const std::string& reduce, const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);

  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "SpMM", {
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
        SpMMCsr<XPU, IdType, Dtype>(
            op, reduce, bcast, csr, ufeat, efeat, out, out_aux);
      });
    });
  });
}

void CSRSpMM(
    const char* op, const char* reduce, const CSRMatrix& csr, NDArray ufeat,
    NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
  CSRSpMM(
      std::string(op), std::string(reduce), csr, ufeat, efeat, out, out_aux);
}

void CSRSDDMM(
    const std::string& op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, int lhs_target, int rhs_target) {
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);

  ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "SDDMM", {
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
        SDDMMCsr<XPU, IdType, Dtype>(
            op, bcast, csr, ufeat, efeat, out, lhs_target, rhs_target);
      });
    });
  });
}

void CSRSDDMM(
    const char* op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
    NDArray out, int lhs_target, int rhs_target) {
  return CSRSDDMM(
      std::string(op), csr, ufeat, efeat, out, lhs_target, rhs_target);
}

void COOSpMM(
    const std::string& op, const std::string& reduce, const COOMatrix& coo,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);

  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "SpMM", {
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
        SpMMCoo<XPU, IdType, Dtype>(
            op, reduce, bcast, coo, ufeat, efeat, out, out_aux);
      });
    });
  });
}

void COOSpMM(
    const char* op, const char* reduce, const COOMatrix& coo, NDArray ufeat,
    NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
  COOSpMM(
      std::string(op), std::string(reduce), coo, ufeat, efeat, out, out_aux);
}

void COOSDDMM(
    const std::string& op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out, int lhs_target, int rhs_target) {
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);

  ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "SDDMM", {
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
        SDDMMCoo<XPU, IdType, Dtype>(
            op, bcast, coo, ufeat, efeat, out, lhs_target, rhs_target);
      });
    });
  });
}

void COOSDDMM(
    const char* op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
    NDArray out, int lhs_target, int rhs_target) {
  COOSDDMM(std::string(op), coo, ufeat, efeat, out, lhs_target, rhs_target);
}

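// The registrations below expose the sparse-matrix helpers and a few array
// utilities to the DGL FFI as packed functions under the "ndarray." prefix.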
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef spmat = args[0];
      *rv = spmat->format;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumRows")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef spmat = args[0];
      *rv = spmat->num_rows;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetNumCols")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef spmat = args[0];
      *rv = spmat->num_cols;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetIndices")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef spmat = args[0];
      const int64_t i = args[1];
      *rv = spmat->indices[i];
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFlags")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      SparseMatrixRef spmat = args[0];
      List<Value> flags;
      for (bool flg : spmat->flags) {
        flags.push_back(Value(MakeValue(flg)));
      }
      *rv = flags;
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLCreateSparseMatrix")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int32_t format = args[0];
      const int64_t nrows = args[1];
      const int64_t ncols = args[2];
      const List<Value> indices = args[3];
      const List<Value> flags = args[4];
      std::shared_ptr<SparseMatrix> spmat(new SparseMatrix(
          format, nrows, ncols, ListValueToVector<IdArray>(indices),
          ListValueToVector<bool>(flags)));
      *rv = SparseMatrixRef(spmat);
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string name = args[0];
#ifndef _WIN32
      *rv = SharedMemory::Exist(name);
#else
      *rv = false;
#endif  // _WIN32
    });

DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray array = args[0];
      CHECK_EQ(array->dtype.code, kDGLUInt);
      std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
      DGLDataType dtype = array->dtype;
      dtype.code = kDGLInt;
      *rv = array.CreateView(shape, dtype, 0);
    });

}  // namespace aten
}  // namespace dgl

std::ostream& operator<<(std::ostream& os, dgl::runtime::NDArray array) {
  return os << dgl::aten::ToDebugString(array);
}