test_spmat_coo.cc 14.9 KB
Newer Older
1
#include <dgl/array.h>
2
3
4
#include <dmlc/omp.h>
#include <gtest/gtest.h>

5
6
7
8
9
10
11
12
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

namespace {

template <typename IDX>
13
aten::CSRMatrix CSR1(DGLContext ctx = CTX) {
14
15
16
17
18
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
19
20
  return aten::CSRMatrix(
      4, 5,
21
22
23
24
25
26
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX) * 8, ctx),
27
      false);
28
29
30
}

template <typename IDX>
31
aten::CSRMatrix CSR2(DGLContext ctx = CTX) {
32
33
34
35
36
37
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
38
39
  return aten::CSRMatrix(
      4, 5,
40
41
42
43
44
45
      aten::VecToIdArray(
          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),
46
      false);
47
48
49
}

template <typename IDX>
50
aten::COOMatrix COO1(DGLContext ctx = CTX) {
51
52
53
54
55
56
57
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
  // row : [0, 2, 0, 1, 2]
  // col : [1, 2, 2, 0, 3]
58
59
  return aten::COOMatrix(
      4, 5,
60
61
62
63
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX) * 8, ctx));
64
65
66
}

template <typename IDX>
67
aten::COOMatrix COO2(DGLContext ctx = CTX) {
68
69
70
71
72
73
74
75
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [1, 2, 2, 0, 3, 2]
76
77
  return aten::COOMatrix(
      4, 5,
78
79
80
81
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX) * 8, ctx));
82
83
}

84
template <typename IDX>
85
aten::CSRMatrix SR_CSR3(DGLContext ctx) {
86
87
88
89
90
91
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
92
93
94
95
96
97
      aten::VecToIdArray(
          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),
98
99
100
101
      false);
}

template <typename IDX>
102
aten::CSRMatrix SRC_CSR3(DGLContext ctx) {
103
104
105
106
107
108
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
109
110
111
112
113
114
      aten::VecToIdArray(
          std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx),
115
116
117
118
      false);
}

template <typename IDX>
119
aten::COOMatrix COO3(DGLContext ctx) {
120
121
122
123
124
125
126
127
128
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  return aten::COOMatrix(
      4, 5,
129
130
131
132
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX) * 8, ctx));
133
134
}

135
136
137
138
139
struct SparseCOOCSR {
  static constexpr uint64_t NUM_ROWS = 100;
  static constexpr uint64_t NUM_COLS = 150;
  static constexpr uint64_t NUM_NZ = 5;
  template <typename IDX>
140
  static aten::COOMatrix COOSparse(const DGLContext &ctx = CTX) {
141
142
143
144
145
146
    return aten::COOMatrix(
        NUM_ROWS, NUM_COLS,
        aten::VecToIdArray(
            std::vector<IDX>({0, 1, 2, 3, 4}), sizeof(IDX) * 8, ctx),
        aten::VecToIdArray(
            std::vector<IDX>({1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx));
147
148
149
  }

  template <typename IDX>
150
  static aten::CSRMatrix CSRSparse(const DGLContext &ctx = CTX) {
151
152
153
154
155
    auto &&indptr = std::vector<IDX>(NUM_ROWS + 1, NUM_NZ);
    for (size_t i = 0; i < NUM_NZ; ++i) {
      indptr[i + 1] = static_cast<IDX>(i + 1);
    }
    indptr[0] = 0;
156
157
158
159
160
161
162
    return aten::CSRMatrix(
        NUM_ROWS, NUM_COLS, aten::VecToIdArray(indptr, sizeof(IDX) * 8, ctx),
        aten::VecToIdArray(
            std::vector<IDX>({1, 2, 3, 4, 5}), sizeof(IDX) * 8, ctx),
        aten::VecToIdArray(
            std::vector<IDX>({1, 1, 1, 1, 1}), sizeof(IDX) * 8, ctx),
        false);
163
164
165
  }
};

// Reproduces the density heuristic COOToCSR<>() uses to pick its conversion
// kernel: returns true when num_threads * num_nodes exceeds 4 * num_edges.
// See COOToCSR<>() in dgl/src/array/cpu/spmat_op_impl_coo.cc for details.
bool isSparseCOO(
    const int64_t num_threads, const int64_t num_nodes,
    const int64_t num_edges) {
  return num_threads * num_nodes > 4 * num_edges;
}
172
173

template <typename IDX>
174
aten::COOMatrix RowSorted_NullData_COO(DGLContext ctx = CTX) {
175
176
177
178
179
180
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 0, 1, 2, 2]
  // col : [1, 2, 0, 2, 3]
181
182
183
184
185
186
187
  return aten::COOMatrix(
      4, 5,
      aten::VecToIdArray(
          std::vector<IDX>({0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::NullArray(), true, false);
188
189
190
}

template <typename IDX>
191
aten::CSRMatrix RowSorted_NullData_CSR(DGLContext ctx = CTX) {
192
193
194
195
196
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 1, 2, 3, 4]
197
198
199
200
201
202
203
204
205
  return aten::CSRMatrix(
      4, 5,
      aten::VecToIdArray(
          std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx),
      aten::VecToIdArray(
          std::vector<IDX>({0, 1, 2, 3, 4}), sizeof(IDX) * 8, ctx),
      false);
206
}
207
}  // namespace
208
209

template <typename IDX>
210
void _TestCOOToCSR(DGLContext ctx) {
211
212
  auto coo = COO1<IDX>(ctx);
  auto csr = CSR1<IDX>(ctx);
213
  auto tcsr = aten::COOToCSR(coo);
214
215
216
217
218
  ASSERT_FALSE(coo.row_sorted);
  ASSERT_FALSE(
      isSparseCOO(omp_get_num_threads(), coo.num_rows, coo.row->shape[0]));
  ASSERT_EQ(csr.num_rows, tcsr.num_rows);
  ASSERT_EQ(csr.num_cols, tcsr.num_cols);
219
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
220
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
221

222
223
  coo = COO2<IDX>(ctx);
  csr = CSR2<IDX>(ctx);
224
225
226
227
  tcsr = aten::COOToCSR(coo);
  ASSERT_EQ(coo.num_rows, csr.num_rows);
  ASSERT_EQ(coo.num_cols, csr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
228

229
  // Convert from row sorted coo
230
  coo = COO1<IDX>(ctx);
231
  auto rs_coo = aten::COOSort(coo, false);
232
  auto rs_csr = CSR1<IDX>(ctx);
233
  auto rs_tcsr = aten::COOToCSR(rs_coo);
234
  ASSERT_TRUE(rs_coo.row_sorted);
235
236
237
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
238
239
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
240

241
  coo = COO3<IDX>(ctx);
242
  rs_coo = aten::COOSort(coo, false);
243
  rs_csr = SR_CSR3<IDX>(ctx);
244
245
246
247
  rs_tcsr = aten::COOToCSR(rs_coo);
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
248
249
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
250

251
252
253
254
255
256
257
258
259
260
261
262
263
264
  rs_coo = RowSorted_NullData_COO<IDX>(ctx);
  ASSERT_TRUE(rs_coo.row_sorted);
  rs_csr = RowSorted_NullData_CSR<IDX>(ctx);
  rs_tcsr = aten::COOToCSR(rs_coo);
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(rs_csr.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_EQ(rs_csr.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_coo.col, rs_tcsr.indices));
  ASSERT_FALSE(ArrayEQ<IDX>(rs_coo.data, rs_tcsr.data));

265
  // Convert from col sorted coo
266
  coo = COO1<IDX>(ctx);
267
  auto src_coo = aten::COOSort(coo, true);
268
  auto src_csr = CSR1<IDX>(ctx);
269
270
271
  auto src_tcsr = aten::COOToCSR(src_coo);
  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
272
273
274
275
  ASSERT_TRUE(src_tcsr.sorted);
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
276

277
  coo = COO3<IDX>(ctx);
278
  src_coo = aten::COOSort(coo, true);
279
  src_csr = SRC_CSR3<IDX>(ctx);
280
281
282
  src_tcsr = aten::COOToCSR(src_coo);
  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
283
284
285
286
  ASSERT_TRUE(src_tcsr.sorted);
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
287
288
289
290
291
292
293
294
295
296
297

  coo = SparseCOOCSR::COOSparse<IDX>(ctx);
  csr = SparseCOOCSR::CSRSparse<IDX>(ctx);
  tcsr = aten::COOToCSR(coo);
  ASSERT_FALSE(coo.row_sorted);
  ASSERT_TRUE(
      isSparseCOO(omp_get_num_threads(), coo.num_rows, coo.row->shape[0]));
  ASSERT_EQ(csr.num_rows, tcsr.num_rows);
  ASSERT_EQ(csr.num_cols, tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
298
299
}

300
TEST(SpmatTest, COOToCSR) {
301
302
303
304
  _TestCOOToCSR<int32_t>(CPU);
  _TestCOOToCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCOOToCSR<int32_t>(GPU);
305
  _TestCOOToCSR<int64_t>(GPU);
306
#endif
307
308
309
310
311
312
313
314
315
316
317
318
319
320
}

template <typename IDX>
void _TestCOOHasDuplicate() {
  // COO1 stores every (row, col) pair at most once.
  const auto unique_coo = COO1<IDX>();
  ASSERT_FALSE(aten::COOHasDuplicate(unique_coo));

  // COO2 stores the (0, 2) pair twice.
  const auto dup_coo = COO2<IDX>();
  ASSERT_TRUE(aten::COOHasDuplicate(dup_coo));
}

// Duplicate detection for both supported index widths (default context).
TEST(SpmatTest, TestCOOHasDuplicate) {
  _TestCOOHasDuplicate<int32_t>();  // 32-bit ids
  _TestCOOHasDuplicate<int64_t>();  // 64-bit ids
}
321
322

template <typename IDX>
323
void _TestCOOSort(DGLContext ctx) {
324
  auto coo = COO3<IDX>(ctx);
325

326
327
328
  auto sr_coo = COOSort(coo, false);
  ASSERT_EQ(coo.num_rows, sr_coo.num_rows);
  ASSERT_EQ(coo.num_cols, sr_coo.num_cols);
329
330
331
332
333
334
335
  ASSERT_TRUE(sr_coo.row_sorted);
  auto flags = COOIsSorted(sr_coo);
  ASSERT_TRUE(flags.first);
  flags = COOIsSorted(coo);  // original coo should stay the same
  ASSERT_FALSE(flags.first);
  ASSERT_FALSE(flags.second);

336
337
338
  auto src_coo = COOSort(coo, true);
  ASSERT_EQ(coo.num_rows, src_coo.num_rows);
  ASSERT_EQ(coo.num_cols, src_coo.num_cols);
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
  ASSERT_TRUE(src_coo.row_sorted);
  ASSERT_TRUE(src_coo.col_sorted);
  flags = COOIsSorted(src_coo);
  ASSERT_TRUE(flags.first);
  ASSERT_TRUE(flags.second);

  // sort inplace
  COOSort_(&coo);
  ASSERT_TRUE(coo.row_sorted);
  flags = COOIsSorted(coo);
  ASSERT_TRUE(flags.first);
  COOSort_(&coo, true);
  ASSERT_TRUE(coo.row_sorted);
  ASSERT_TRUE(coo.col_sorted);
  flags = COOIsSorted(coo);
  ASSERT_TRUE(flags.first);
  ASSERT_TRUE(flags.second);
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373

  // COO3
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 1, 2, 3, 4, 5]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  // Row Sorted
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 0, 0, 1, 2, 2]
  // col : [2, 1, 2, 0, 2, 3]
  // Row Col Sorted
  // data: [2, 0, 5, 3, 1, 4]
  // row : [0, 0, 0, 1, 2, 2]
  // col : [1, 2, 2, 0, 2, 3]
  auto sort_row = aten::VecToIdArray(
374
      std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX) * 8, ctx);
375
  auto sort_col = aten::VecToIdArray(
376
      std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX) * 8, ctx);
377
  auto sort_col_data = aten::VecToIdArray(
378
      std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX) * 8, ctx);
379
380
381
382
383
384
385

  ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.row, sort_row));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.row, sort_row));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.col, sort_col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));
}

386
TEST(SpmatTest, COOSort) {
387
388
389
390
  _TestCOOSort<int32_t>(CPU);
  _TestCOOSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCOOSort<int32_t>(GPU);
391
  _TestCOOSort<int64_t>(GPU);
392
#endif
393
}
Da Zheng's avatar
Da Zheng committed
394
395
396
397

template <typename IDX>
void _TestCOOReorder() {
  auto coo = COO2<IDX>();
398
399
  auto new_row =
      aten::VecToIdArray(std::vector<IDX>({2, 0, 3, 1}), sizeof(IDX) * 8, CTX);
Da Zheng's avatar
Da Zheng committed
400
  auto new_col = aten::VecToIdArray(
401
      std::vector<IDX>({2, 0, 4, 3, 1}), sizeof(IDX) * 8, CTX);
Da Zheng's avatar
Da Zheng committed
402
403
404
405
406
407
408
409
410
  auto new_coo = COOReorder(coo, new_row, new_col);
  ASSERT_EQ(new_coo.num_rows, coo.num_rows);
  ASSERT_EQ(new_coo.num_cols, coo.num_cols);
}

// Row/column reordering for both supported index widths (default context).
TEST(SpmatTest, TestCOOReorder) {
  _TestCOOReorder<int32_t>();  // 32-bit ids
  _TestCOOReorder<int64_t>();  // 64-bit ids
}
411
412

template <typename IDX>
413
void _TestCOOGetData(DGLContext ctx) {
414
415
416
  auto coo = COO2<IDX>(ctx);
  // test get all data
  auto x = aten::COOGetAllData(coo, 0, 0);
417
  auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX) * 8, ctx);
418
419
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::COOGetAllData(coo, 0, 2);
420
  tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX) * 8, ctx);
421
422
423
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data
424
425
426
427
  auto r =
      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);
  auto c =
      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);
428
  x = aten::COOGetData(coo, r, c);
429
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);
430
431
432
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data on sorted
433
434
435
  coo = aten::COOSort(coo);
  r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);
436
  x = aten::COOGetData(coo, r, c);
437
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);
438
439
440
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data w/ broadcasting
441
442
  r = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX) * 8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, ctx);
443
  x = aten::COOGetData(coo, r, c);
444
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX) * 8, ctx);
445
446
447
448
449
450
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

// COO data lookup for both index widths on CPU.
TEST(SpmatTest, COOGetData) {
  _TestCOOGetData<int32_t>(CPU);
  _TestCOOGetData<int64_t>(CPU);
  // GPU coverage is currently disabled:
  // #ifdef DGL_USE_CUDA
  // _TestCOOGetData<int32_t>(GPU);
  // _TestCOOGetData<int64_t>(GPU);
  // #endif
}

template <typename IDX>
void _TestCOOGetDataAndIndices() {
  auto coo = COO2<IDX>();
  // Query (0, 0), (0, 1) and (0, 2) in one call.
  auto r =
      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, CTX);
  auto c =
      aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX) * 8, CTX);
  auto x = aten::COOGetDataAndIndices(coo, r, c);

  // Expected matches as (row, col, data) triples: the empty (0, 0) yields
  // nothing, (0, 1) yields id 0, and the duplicated (0, 2) yields ids 2 and 5.
  auto tr =
      aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX) * 8, CTX);
  auto tc =
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX) * 8, CTX);
  auto td =
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX) * 8, CTX);
  ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));
  ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));
  ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));
}

// Combined data/index lookup for both supported index widths (default context).
TEST(SpmatTest, COOGetDataAndIndices) {
  _TestCOOGetDataAndIndices<int32_t>();  // 32-bit ids
  _TestCOOGetDataAndIndices<int64_t>();  // 64-bit ids
}