#include <dgl/array.h>
#include <dmlc/omp.h>
#include <gtest/gtest.h>
#include <omp.h>

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

namespace {

template <typename IDX>
16
aten::CSRMatrix CSR1(DGLContext ctx = CTX) {
17
18
19
20
21
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
22
23
  return aten::CSRMatrix(
      4, 5,
24
25
26
27
28
29
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX) * 8, ctx),
30
      false);
31
32
33
}

template <typename IDX>
34
aten::CSRMatrix CSR2(DGLContext ctx = CTX) {
35
36
37
38
39
40
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
41
42
  return aten::CSRMatrix(
      4, 5,
43
44
45
46
47
48
      aten::VecToIdArray(
          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),
49
      false);
50
51
52
}

template <typename IDX>
53
aten::COOMatrix COO1(DGLContext ctx = CTX) {
54
55
56
57
58
59
60
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
  // row : [0, 2, 0, 1, 2]
  // col : [1, 2, 2, 0, 3]
61
62
  return aten::COOMatrix(
      4, 5,
63
64
65
66
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX) * 8, ctx));
67
68
69
}

template <typename IDX>
70
aten::COOMatrix COO2(DGLContext ctx = CTX) {
71
72
73
74
75
76
77
78
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [1, 2, 2, 0, 3, 2]
79
80
  return aten::COOMatrix(
      4, 5,
81
82
83
84
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX) * 8, ctx));
85
86
}

87
template <typename IDX>
88
aten::CSRMatrix SR_CSR3(DGLContext ctx) {
89
90
91
92
93
94
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
95
96
97
98
99
100
      aten::VecToIdArray(
          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),
101
102
103
104
      false);
}

template <typename IDX>
105
aten::CSRMatrix SRC_CSR3(DGLContext ctx) {
106
107
108
109
110
111
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
112
113
114
115
116
117
      aten::VecToIdArray(
          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),
118
119
120
121
      false);
}

template <typename IDX>
122
aten::COOMatrix COO3(DGLContext ctx) {
123
124
125
126
127
128
129
130
131
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  return aten::COOMatrix(
      4, 5,
132
133
134
135
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX) * 8, ctx));
136
137
}

138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
template <typename IDX>
aten::COOMatrix COORandomized(IDX rows_and_cols, int64_t nnz, int seed) {
  std::vector<IDX> vec_rows(nnz);
  std::vector<IDX> vec_cols(nnz);
  std::vector<IDX> vec_data(nnz);

#pragma omp parallel
  {
    const int64_t num_threads = omp_get_num_threads();
    const int64_t thread_id = omp_get_thread_num();
    const int64_t chunk = nnz / num_threads;
    const int64_t size = (thread_id == num_threads - 1)
                             ? nnz - chunk * (num_threads - 1)
                             : chunk;
    auto rows = vec_rows.data() + thread_id * chunk;
    auto cols = vec_cols.data() + thread_id * chunk;
    auto data = vec_data.data() + thread_id * chunk;

    std::mt19937_64 gen64(seed + thread_id);
    std::mt19937 gen32(seed + thread_id);

    for (int64_t i = 0; i < size; ++i) {
      rows[i] = gen64() % rows_and_cols;
      cols[i] = gen64() % rows_and_cols;
      data[i] = gen32() % 90 + 1;
    }
  }

  return aten::COOMatrix(
      rows_and_cols, rows_and_cols,
      aten::VecToIdArray(vec_rows, sizeof(IDX) * 8, CTX),
      aten::VecToIdArray(vec_cols, sizeof(IDX) * 8, CTX),
      aten::VecToIdArray(vec_data, sizeof(IDX) * 8, CTX), false, false);
}

173
174
175
176
177
struct SparseCOOCSR {
  static constexpr uint64_t NUM_ROWS = 100;
  static constexpr uint64_t NUM_COLS = 150;
  static constexpr uint64_t NUM_NZ = 5;
  template <typename IDX>
178
  static aten::COOMatrix COOSparse(const DGLContext &ctx = CTX) {
179
180
181
182
183
184
    return aten::COOMatrix(
        NUM_ROWS, NUM_COLS,
        aten::VecToIdArray(
            std::vector<IDX>({0, 1, 2, 3, 4}), sizeof(IDX) * 8, ctx),
        aten::VecToIdArray(
            std::vector<IDX>({1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx));
185
186
187
  }

  template <typename IDX>
188
  static aten::CSRMatrix CSRSparse(const DGLContext &ctx = CTX) {
189
190
191
192
193
    auto &&indptr = std::vector<IDX>(NUM_ROWS + 1, NUM_NZ);
    for (size_t i = 0; i < NUM_NZ; ++i) {
      indptr[i + 1] = static_cast<IDX>(i + 1);
    }
    indptr[0] = 0;
194
195
196
197
198
199
200
    return aten::CSRMatrix(
        NUM_ROWS, NUM_COLS, aten::VecToIdArray(indptr, sizeof(IDX) * 8, ctx),
        aten::VecToIdArray(
            std::vector<IDX>({1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx),
        aten::VecToIdArray(
            std::vector<IDX>({1, 1, 1, 1, 1}), sizeof(IDX) * 8, ctx),
        false);
201
202
203
  }
};

204
template <typename IDX>
205
aten::COOMatrix RowSorted_NullData_COO(DGLContext ctx = CTX) {
206
207
208
209
210
211
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 0, 1, 2, 2]
  // col : [1, 2, 0, 2, 3]
212
213
214
215
216
217
218
  return aten::COOMatrix(
      4, 5,
      aten::VecToIdArray(
          std::vector<IDX>({0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::NullArray(), true, false);
219
220
221
}

template <typename IDX>
222
aten::CSRMatrix RowSorted_NullData_CSR(DGLContext ctx = CTX) {
223
224
225
226
227
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 1, 2, 3, 4]
228
229
230
231
232
233
234
235
236
  return aten::CSRMatrix(
      4, 5,
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 1, 2, 3, 4}), sizeof(IDX) * 8, ctx),
      false);
237
}
238
}  // namespace
239
240

template <typename IDX>
241
void _TestCOOToCSR(DGLContext ctx) {
242
243
  auto coo = COO1<IDX>(ctx);
  auto csr = CSR1<IDX>(ctx);
244
  auto tcsr = aten::COOToCSR(coo);
245
246
247
  ASSERT_FALSE(coo.row_sorted);
  ASSERT_EQ(csr.num_rows, tcsr.num_rows);
  ASSERT_EQ(csr.num_cols, tcsr.num_cols);
248
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
249
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
250

251
252
  coo = COO2<IDX>(ctx);
  csr = CSR2<IDX>(ctx);
253
254
255
256
  tcsr = aten::COOToCSR(coo);
  ASSERT_EQ(coo.num_rows, csr.num_rows);
  ASSERT_EQ(coo.num_cols, csr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
257

258
  // Convert from row sorted coo
259
  coo = COO1<IDX>(ctx);
260
  auto rs_coo = aten::COOSort(coo, false);
261
  auto rs_csr = CSR1<IDX>(ctx);
262
  auto rs_tcsr = aten::COOToCSR(rs_coo);
263
  ASSERT_TRUE(rs_coo.row_sorted);
264
265
266
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
267
268
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
269

270
  coo = COO3<IDX>(ctx);
271
  rs_coo = aten::COOSort(coo, false);
272
  rs_csr = SR_CSR3<IDX>(ctx);
273
274
275
276
  rs_tcsr = aten::COOToCSR(rs_coo);
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
277
278
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
279

280
281
282
283
284
285
286
287
288
289
290
291
292
293
  rs_coo = RowSorted_NullData_COO<IDX>(ctx);
  ASSERT_TRUE(rs_coo.row_sorted);
  rs_csr = RowSorted_NullData_CSR<IDX>(ctx);
  rs_tcsr = aten::COOToCSR(rs_coo);
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(rs_csr.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_EQ(rs_csr.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_coo.col, rs_tcsr.indices));
  ASSERT_FALSE(ArrayEQ<IDX>(rs_coo.data, rs_tcsr.data));

294
  // Convert from col sorted coo
295
  coo = COO1<IDX>(ctx);
296
  auto src_coo = aten::COOSort(coo, true);
297
  auto src_csr = CSR1<IDX>(ctx);
298
299
300
  auto src_tcsr = aten::COOToCSR(src_coo);
  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
301
302
303
304
  ASSERT_TRUE(src_tcsr.sorted);
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
305

306
  coo = COO3<IDX>(ctx);
307
  src_coo = aten::COOSort(coo, true);
308
  src_csr = SRC_CSR3<IDX>(ctx);
309
310
311
  src_tcsr = aten::COOToCSR(src_coo);
  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
312
313
314
315
  ASSERT_TRUE(src_tcsr.sorted);
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
316
317
318
319
320
321
322
323
324

  coo = SparseCOOCSR::COOSparse<IDX>(ctx);
  csr = SparseCOOCSR::CSRSparse<IDX>(ctx);
  tcsr = aten::COOToCSR(coo);
  ASSERT_FALSE(coo.row_sorted);
  ASSERT_EQ(csr.num_rows, tcsr.num_rows);
  ASSERT_EQ(csr.num_cols, tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
325
326
}

327
TEST(SpmatTest, COOToCSR) {
328
329
330
331
  _TestCOOToCSR<int32_t>(CPU);
  _TestCOOToCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCOOToCSR<int32_t>(GPU);
332
  _TestCOOToCSR<int64_t>(GPU);
333
#endif
334
335
336
337
338
339
340
341
342
343
344
345
346
347
}

// Checks aten::COOHasDuplicate: false for COO1 (all entries distinct), true
// for COO2 (entry (0, 2) appears twice). Both fixtures use the default
// context.
template <typename IDX>
void _TestCOOHasDuplicate() {
  const auto unique_coo = COO1<IDX>();
  ASSERT_FALSE(aten::COOHasDuplicate(unique_coo));

  const auto duplicated_coo = COO2<IDX>();
  ASSERT_TRUE(aten::COOHasDuplicate(duplicated_coo));
}

// Runs the COOHasDuplicate checks for 32- and 64-bit index types.
TEST(SpmatTest, TestCOOHasDuplicate) {
  _TestCOOHasDuplicate<int32_t>();
  _TestCOOHasDuplicate<int64_t>();
}

template <typename IDX>
350
void _TestCOOSort(DGLContext ctx) {
351
  auto coo = COO3<IDX>(ctx);
352

353
354
355
  auto sr_coo = COOSort(coo, false);
  ASSERT_EQ(coo.num_rows, sr_coo.num_rows);
  ASSERT_EQ(coo.num_cols, sr_coo.num_cols);
356
357
358
359
360
361
362
  ASSERT_TRUE(sr_coo.row_sorted);
  auto flags = COOIsSorted(sr_coo);
  ASSERT_TRUE(flags.first);
  flags = COOIsSorted(coo);  // original coo should stay the same
  ASSERT_FALSE(flags.first);
  ASSERT_FALSE(flags.second);

363
364
365
  auto src_coo = COOSort(coo, true);
  ASSERT_EQ(coo.num_rows, src_coo.num_rows);
  ASSERT_EQ(coo.num_cols, src_coo.num_cols);
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
  ASSERT_TRUE(src_coo.row_sorted);
  ASSERT_TRUE(src_coo.col_sorted);
  flags = COOIsSorted(src_coo);
  ASSERT_TRUE(flags.first);
  ASSERT_TRUE(flags.second);

  // sort inplace
  COOSort_(&coo);
  ASSERT_TRUE(coo.row_sorted);
  flags = COOIsSorted(coo);
  ASSERT_TRUE(flags.first);
  COOSort_(&coo, true);
  ASSERT_TRUE(coo.row_sorted);
  ASSERT_TRUE(coo.col_sorted);
  flags = COOIsSorted(coo);
  ASSERT_TRUE(flags.first);
  ASSERT_TRUE(flags.second);
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400

  // COO3
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 1, 2, 3, 4, 5]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  // Row Sorted
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 0, 0, 1, 2, 2]
  // col : [2, 1, 2, 0, 2, 3]
  // Row Col Sorted
  // data: [2, 0, 5, 3, 1, 4]
  // row : [0, 0, 0, 1, 2, 2]
  // col : [1, 2, 2, 0, 2, 3]
  auto sort_row = aten::VecToIdArray(
401
      std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx);
402
  auto sort_col = aten::VecToIdArray(
403
      std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx);
404
  auto sort_col_data = aten::VecToIdArray(
405
      std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx);
406
407
408
409
410
411
412

  ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.row, sort_row));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.row, sort_row));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.col, sort_col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));
}

413
TEST(SpmatTest, COOSort) {
414
415
416
417
  _TestCOOSort<int32_t>(CPU);
  _TestCOOSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCOOSort<int32_t>(GPU);
418
  _TestCOOSort<int64_t>(GPU);
419
#endif
420
}
Da Zheng's avatar
Da Zheng committed
421
422
423
424

template <typename IDX>
void _TestCOOReorder() {
  auto coo = COO2<IDX>();
425
426
  auto new_row =
      aten::VecToIdArray(std::vector<IDX>({2, 0, 3, 1}), sizeof(IDX) * 8, CTX);
Da Zheng's avatar
Da Zheng committed
427
  auto new_col = aten::VecToIdArray(
428
      std::vector<IDX>({2, 0, 4, 3, 1}), sizeof(IDX) * 8, CTX);
Da Zheng's avatar
Da Zheng committed
429
430
431
432
433
434
435
436
437
  auto new_coo = COOReorder(coo, new_row, new_col);
  ASSERT_EQ(new_coo.num_rows, coo.num_rows);
  ASSERT_EQ(new_coo.num_cols, coo.num_cols);
}

// Runs the COOReorder checks for 32- and 64-bit index types.
TEST(SpmatTest, TestCOOReorder) {
  _TestCOOReorder<int32_t>();
  _TestCOOReorder<int64_t>();
}

template <typename IDX>
440
void _TestCOOGetData(DGLContext ctx) {
441
442
443
  auto coo = COO2<IDX>(ctx);
  // test get all data
  auto x = aten::COOGetAllData(coo, 0, 0);
444
  auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);
445
446
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::COOGetAllData(coo, 0, 2);
447
  tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX) * 8, ctx);
448
449
450
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data
451
452
453
454
  auto r =
      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);
  auto c =
      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);
455
  x = aten::COOGetData(coo, r, c);
456
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);
457
458
459
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data on sorted
460
461
462
  coo = aten::COOSort(coo);
  r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);
463
  x = aten::COOGetData(coo, r, c);
464
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);
465
466
467
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data w/ broadcasting
468
469
  r = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX) * 8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);
470
  x = aten::COOGetData(coo, r, c);
471
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);
472
473
474
475
476
477
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

// Runs the COOGetData checks for 32- and 64-bit index types on CPU. The GPU
// variants below are commented out and therefore never run.
TEST(SpmatTest, COOGetData) {
  _TestCOOGetData<int32_t>(CPU);
  _TestCOOGetData<int64_t>(CPU);
  // #ifdef DGL_USE_CUDA
  //_TestCOOGetData<int32_t>(GPU);
  //_TestCOOGetData<int64_t>(GPU);
  // #endif
}

// Checks aten::COOGetDataAndIndices on COO2 for the queries (0, 0), (0, 1)
// and (0, 2). The result is indexed as x[0] = rows, x[1] = cols, x[2] = data;
// the expected arrays list entry (0, 2) twice (data 2 and 5) and contain
// nothing for (0, 0).
template <typename IDX>
void _TestCOOGetDataAndIndices() {
  auto coo = COO2<IDX>();
  auto r =
      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, CTX);
  auto c =
      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, CTX);
  auto x = aten::COOGetDataAndIndices(coo, r, c);
  auto tr =
      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, CTX);
  auto tc =
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX) * 8, CTX);
  auto td =
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, CTX);
  ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));
  ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));
  ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));
}

// Runs the COOGetDataAndIndices checks for 32- and 64-bit index types.
TEST(SpmatTest, COOGetDataAndIndices) {
  _TestCOOGetDataAndIndices<int32_t>();
  _TestCOOGetDataAndIndices<int64_t>();
}

// Compares the results of aten::COOToCSR on the same random COO matrix when
// run with 1, num_threads - 1 and num_threads OpenMP threads. The matrix
// size is chosen so that, per the comments below, a different CPU
// implementation is selected in each case. Skips when fewer than 3 OpenMP
// threads are available. The OpenMP thread count in effect on entry is
// restored on every exit path.
template <typename IDX>
void _TestCOOToCSRAlgs() {
  // Compare results between different CPU COOToCSR implementations.
  // NNZ is chosen to be bigger than the limit for the "small" matrix algorithm.
  // N is set to lay on border between "sparse" and "dense" algorithm choice.

  const int initial_max_threads = omp_get_max_threads();
  const int64_t num_threads = std::min(256, initial_max_threads);
  const int64_t min_num_threads = 3;

  if (num_threads < min_num_threads) {
    std::cerr << "[          ] [ INFO ]"
              << "This test requires at least 3 OMP threads to work properly"
              << std::endl;
    GTEST_SKIP();
    return;
  }

  // ASSERT_* returns from this function on failure. Without this guard a
  // failed assertion would leave the reduced thread count (1 or
  // num_threads - 1) in effect, and the next instantiation of this test
  // would see too few threads and skip.
  struct OmpThreadCountGuard {
    explicit OmpThreadCountGuard(int count) : saved(count) {}
    ~OmpThreadCountGuard() { omp_set_num_threads(saved); }
    OmpThreadCountGuard(const OmpThreadCountGuard &) = delete;
    OmpThreadCountGuard &operator=(const OmpThreadCountGuard &) = delete;
    const int saved;
  };
  const OmpThreadCountGuard restore_threads(initial_max_threads);

  // Select N and NNZ for the COO matrix in a way that, depending on the number
  // of threads, a different algorithm will be used.
  // See WhichCOOToCSR in src/array/cpu/spmat_op_impl_coo.cc for details
  const int64_t type_scale = sizeof(IDX) >> 1;
  const int64_t small = 50 * num_threads * type_scale * type_scale;
  // NNZ should be bigger than limit for small matrix algorithm
  const int64_t nnz = small + 1234;
  // N is chosen to lay on sparse/dense border
  const int64_t n = type_scale * nnz / num_threads;
  const IDX rows_and_cols = n + 1;  // should be bigger than sparse/dense border

  // Note that it will be better to set the seed to a random value when gtest
  // allows to use --gtest_random_seed without --gtest_shuffle and report this
  // value for reproduction. This way we can find unforeseen situations and
  // potential bugs.
  const auto seed = 123321;
  auto coo = COORandomized<IDX>(rows_and_cols, nnz, seed);

  omp_set_num_threads(1);
  // UnSortedSmallCOOToCSR will be used
  auto tcsr_small = aten::COOToCSR(coo);
  ASSERT_EQ(coo.num_rows, tcsr_small.num_rows);
  ASSERT_EQ(coo.num_cols, tcsr_small.num_cols);

  omp_set_num_threads(num_threads - 1);
  // UnSortedDenseCOOToCSR will be used
  auto tcsr_dense = aten::COOToCSR(coo);
  ASSERT_EQ(tcsr_small.num_rows, tcsr_dense.num_rows);
  ASSERT_EQ(tcsr_small.num_cols, tcsr_dense.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indptr, tcsr_dense.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indices, tcsr_dense.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.data, tcsr_dense.data));

  omp_set_num_threads(num_threads);
  // UnSortedSparseCOOToCSR will be used
  auto tcsr_sparse = aten::COOToCSR(coo);
  ASSERT_EQ(tcsr_small.num_rows, tcsr_sparse.num_rows);
  ASSERT_EQ(tcsr_small.num_cols, tcsr_sparse.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indptr, tcsr_sparse.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.indices, tcsr_sparse.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(tcsr_small.data, tcsr_sparse.data));
}

// Runs the COOToCSR algorithm cross-check for 32- and 64-bit index types.
TEST(SpmatTest, COOToCSRAlgs) {
  _TestCOOToCSRAlgs<int32_t>();
  _TestCOOToCSRAlgs<int64_t>();
}