test_spmat_coo.cc 15.1 KB
Newer Older
1
#include <gtest/gtest.h>
2
#include <dmlc/omp.h>
3
4
5
6
7
8
9
10
11
#include <dgl/array.h>
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

namespace {

template <typename IDX>
12
aten::CSRMatrix CSR1(DLContext ctx = CTX) {
13
14
15
16
17
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
18
19
  return aten::CSRMatrix(
      4, 5,
20
      aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx),
21
      aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
22
      aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX)*8, ctx),
23
      false);
24
25
26
}

template <typename IDX>
27
aten::CSRMatrix CSR2(DLContext ctx = CTX) {
28
29
30
31
32
33
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
34
35
  return aten::CSRMatrix(
      4, 5,
36
37
38
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
39
      false);
40
41
42
}

template <typename IDX>
43
aten::COOMatrix COO1(DLContext ctx = CTX) {
44
45
46
47
48
49
50
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
  // row : [0, 2, 0, 1, 2]
  // col : [1, 2, 2, 0, 3]
51
52
  return aten::COOMatrix(
      4, 5,
53
54
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX)*8, ctx));
55
56
57
}

template <typename IDX>
58
aten::COOMatrix COO2(DLContext ctx = CTX) {
59
60
61
62
63
64
65
66
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [1, 2, 2, 0, 3, 2]
67
68
  return aten::COOMatrix(
      4, 5,
69
70
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX)*8, ctx));
71
72
}

73
template <typename IDX>
74
aten::CSRMatrix SR_CSR3(DLContext ctx) {
75
76
77
78
79
80
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
81
82
83
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
84
85
86
87
      false);
}

template <typename IDX>
88
aten::CSRMatrix SRC_CSR3(DLContext ctx) {
89
90
91
92
93
94
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
95
96
97
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
98
99
100
101
      false);
}

template <typename IDX>
102
aten::COOMatrix COO3(DLContext ctx) {
103
104
105
106
107
108
109
110
111
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  return aten::COOMatrix(
      4, 5,
112
113
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX)*8, ctx));
114
115
}

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
// Very sparse fixture: NUM_NZ nonzeros on a NUM_ROWS x NUM_COLS grid.
// _TestCOOToCSR expects isSparseCOO() to classify its COO form as sparse.
struct SparseCOOCSR {
  static constexpr uint64_t NUM_ROWS = 100;
  static constexpr uint64_t NUM_COLS = 150;
  static constexpr uint64_t NUM_NZ = 5;

  // Entries (i, i + 1) for i in [0, NUM_NZ); no data array is passed.
  template <typename IDX>
  static aten::COOMatrix COOSparse(const DLContext &ctx = CTX) {
    return aten::COOMatrix(NUM_ROWS, NUM_COLS,
                           aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 3, 4}),
                                              sizeof(IDX) * 8, ctx),
                           aten::VecToIdArray(std::vector<IDX>({1, 2, 3, 4, 5}),
                                              sizeof(IDX) * 8, ctx));
  }

  // CSR form of COOSparse(). Rows [0, NUM_NZ) hold one entry each and the
  // remaining rows are empty, so indptr = [0, 1, ..., NUM_NZ, NUM_NZ, ...].
  // data holds the COO entry positions [0, NUM_NZ). It previously read all
  // ones, which does not correspond to any entry order; only indptr/indices
  // are compared in this file.
  template <typename IDX>
  static aten::CSRMatrix CSRSparse(const DLContext &ctx = CTX) {
    std::vector<IDX> indptr(NUM_ROWS + 1, static_cast<IDX>(NUM_NZ));
    for (uint64_t i = 0; i <= NUM_NZ; ++i) {
      indptr[i] = static_cast<IDX>(i);
    }
    return aten::CSRMatrix(NUM_ROWS, NUM_COLS,
                           aten::VecToIdArray(indptr, sizeof(IDX) * 8, ctx),
                           aten::VecToIdArray(std::vector<IDX>({1, 2, 3, 4, 5}),
                                              sizeof(IDX) * 8, ctx),
                           aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 3, 4}),
                                              sizeof(IDX) * 8, ctx),
                           false);
  }
};

// Sparse/dense classification heuristic for unsorted COO -> CSR conversion.
// Intended to match COOToCSR<>() in ~dgl/src/array/cpu/spmat_op_impl_coo.cc;
// refer there for details.
//
// Returns true when num_threads * num_nodes > 4 * num_edges.
// Scalars are taken by value: they are cheap to copy, and taking them by
// const reference bought nothing.
//
// NOTE(review): callers in this file pass omp_get_num_threads(). That returns
// 1 outside an OpenMP parallel region, so the check reflects a single-thread
// setting. Confirm that this matches the thread count COOToCSR actually uses.
bool isSparseCOO(int64_t num_threads, int64_t num_nodes, int64_t num_edges) {
  return num_threads * num_nodes > 4 * num_edges;
}
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

template <typename IDX>
aten::COOMatrix RowSorted_NullData_COO(DLContext ctx = CTX) {
  // Same dense matrix as COO1, but the entries are already grouped by row
  // and no data array is supplied (aten::NullArray()). The trailing flags are
  // (true, false), presumably (row_sorted, col_sorted); _TestCOOToCSR asserts
  // row_sorted on the result.
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 0, 1, 2, 2]
  // col : [1, 2, 0, 2, 3]
  const auto nbits = sizeof(IDX) * 8;
  const std::vector<IDX> row_ids{0, 0, 1, 2, 2};
  const std::vector<IDX> col_ids{1, 2, 0, 2, 3};
  return aten::COOMatrix(4, 5, aten::VecToIdArray(row_ids, nbits, ctx),
                         aten::VecToIdArray(col_ids, nbits, ctx),
                         aten::NullArray(), true, false);
}

template <typename IDX>
aten::CSRMatrix RowSorted_NullData_CSR(DLContext ctx = CTX) {
  // Expected CSR for RowSorted_NullData_COO. The COO carries no data, so the
  // data here is simply the entry positions in COO order.
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 1, 2, 3, 4]
  const auto nbits = sizeof(IDX) * 8;
  const std::vector<IDX> row_offsets{0, 2, 3, 5, 5};
  const std::vector<IDX> col_ids{1, 2, 0, 2, 3};
  const std::vector<IDX> entry_ids{0, 1, 2, 3, 4};
  return aten::CSRMatrix(4, 5, aten::VecToIdArray(row_offsets, nbits, ctx),
                         aten::VecToIdArray(col_ids, nbits, ctx),
                         aten::VecToIdArray(entry_ids, nbits, ctx), false);
}
184
}  // namespace
185
186

template <typename IDX>
187
188
189
void _TestCOOToCSR(DLContext ctx) {
  auto coo = COO1<IDX>(ctx);
  auto csr = CSR1<IDX>(ctx);
190
  auto tcsr = aten::COOToCSR(coo);
191
192
193
194
195
  ASSERT_FALSE(coo.row_sorted);
  ASSERT_FALSE(
      isSparseCOO(omp_get_num_threads(), coo.num_rows, coo.row->shape[0]));
  ASSERT_EQ(csr.num_rows, tcsr.num_rows);
  ASSERT_EQ(csr.num_cols, tcsr.num_cols);
196
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
197
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
198

199
200
  coo = COO2<IDX>(ctx);
  csr = CSR2<IDX>(ctx);
201
202
203
204
  tcsr = aten::COOToCSR(coo);
  ASSERT_EQ(coo.num_rows, csr.num_rows);
  ASSERT_EQ(coo.num_cols, csr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
205

206
  // Convert from row sorted coo
207
  coo = COO1<IDX>(ctx);
208
  auto rs_coo = aten::COOSort(coo, false);
209
  auto rs_csr = CSR1<IDX>(ctx);
210
  auto rs_tcsr = aten::COOToCSR(rs_coo);
211
  ASSERT_TRUE(rs_coo.row_sorted);
212
213
214
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
215
216
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
217

218
  coo = COO3<IDX>(ctx);
219
  rs_coo = aten::COOSort(coo, false);
220
  rs_csr = SR_CSR3<IDX>(ctx);
221
222
223
224
  rs_tcsr = aten::COOToCSR(rs_coo);
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
225
226
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.indices, rs_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_tcsr.data, rs_coo.data));
227

228
229
230
231
232
233
234
235
236
237
238
239
240
241
  rs_coo = RowSorted_NullData_COO<IDX>(ctx);
  ASSERT_TRUE(rs_coo.row_sorted);
  rs_csr = RowSorted_NullData_CSR<IDX>(ctx);
  rs_tcsr = aten::COOToCSR(rs_coo);
  ASSERT_EQ(coo.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(rs_csr.num_rows, rs_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, rs_tcsr.num_cols);
  ASSERT_EQ(rs_csr.num_cols, rs_tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indptr, rs_tcsr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.indices, rs_tcsr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_csr.data, rs_tcsr.data));
  ASSERT_TRUE(ArrayEQ<IDX>(rs_coo.col, rs_tcsr.indices));
  ASSERT_FALSE(ArrayEQ<IDX>(rs_coo.data, rs_tcsr.data));

242
  // Convert from col sorted coo
243
  coo = COO1<IDX>(ctx);
244
  auto src_coo = aten::COOSort(coo, true);
245
  auto src_csr = CSR1<IDX>(ctx);
246
247
248
  auto src_tcsr = aten::COOToCSR(src_coo);
  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
249
250
251
252
  ASSERT_TRUE(src_tcsr.sorted);
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
253

254
  coo = COO3<IDX>(ctx);
255
  src_coo = aten::COOSort(coo, true);
256
  src_csr = SRC_CSR3<IDX>(ctx);
257
258
259
  src_tcsr = aten::COOToCSR(src_coo);
  ASSERT_EQ(coo.num_rows, src_tcsr.num_rows);
  ASSERT_EQ(coo.num_cols, src_tcsr.num_cols);
260
261
262
263
  ASSERT_TRUE(src_tcsr.sorted);
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indptr, src_csr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.indices, src_coo.col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_tcsr.data, src_coo.data));
264
265
266
267
268
269
270
271
272
273
274

  coo = SparseCOOCSR::COOSparse<IDX>(ctx);
  csr = SparseCOOCSR::CSRSparse<IDX>(ctx);
  tcsr = aten::COOToCSR(coo);
  ASSERT_FALSE(coo.row_sorted);
  ASSERT_TRUE(
      isSparseCOO(omp_get_num_threads(), coo.num_rows, coo.row->shape[0]));
  ASSERT_EQ(csr.num_rows, tcsr.num_rows);
  ASSERT_EQ(csr.num_cols, tcsr.num_cols);
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indptr, tcsr.indptr));
  ASSERT_TRUE(ArrayEQ<IDX>(csr.indices, tcsr.indices));
275
276
}

277
// Runs the COO -> CSR conversion checks for 32- and 64-bit ids on CPU, and
// also on GPU when built with DGL_USE_CUDA.
TEST(SpmatTest, COOToCSR) {
  _TestCOOToCSR<int32_t>(CPU);
  _TestCOOToCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCOOToCSR<int32_t>(GPU);
  _TestCOOToCSR<int64_t>(GPU);
#endif
}

template <typename IDX>
void _TestCOOHasDuplicate() {
  // COO1 has distinct (row, col) pairs; COO2 repeats (0, 2).
  auto unique_coo = COO1<IDX>();
  ASSERT_FALSE(aten::COOHasDuplicate(unique_coo));

  auto dup_coo = COO2<IDX>();
  ASSERT_TRUE(aten::COOHasDuplicate(dup_coo));
}

// Duplicate detection for 32- and 64-bit ids; the fixtures use the default CTX.
TEST(SpmatTest, TestCOOHasDuplicate) {
  _TestCOOHasDuplicate<int32_t>();  // 32-bit ids
  _TestCOOHasDuplicate<int64_t>();  // 64-bit ids
}
298
299

template <typename IDX>
300
301
void _TestCOOSort(DLContext ctx) {
  auto coo = COO3<IDX>(ctx);
302
  
303
304
305
  auto sr_coo = COOSort(coo, false);
  ASSERT_EQ(coo.num_rows, sr_coo.num_rows);
  ASSERT_EQ(coo.num_cols, sr_coo.num_cols);
306
307
308
309
310
311
312
  ASSERT_TRUE(sr_coo.row_sorted);
  auto flags = COOIsSorted(sr_coo);
  ASSERT_TRUE(flags.first);
  flags = COOIsSorted(coo);  // original coo should stay the same
  ASSERT_FALSE(flags.first);
  ASSERT_FALSE(flags.second);

313
314
315
  auto src_coo = COOSort(coo, true);
  ASSERT_EQ(coo.num_rows, src_coo.num_rows);
  ASSERT_EQ(coo.num_cols, src_coo.num_cols);
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
  ASSERT_TRUE(src_coo.row_sorted);
  ASSERT_TRUE(src_coo.col_sorted);
  flags = COOIsSorted(src_coo);
  ASSERT_TRUE(flags.first);
  ASSERT_TRUE(flags.second);

  // sort inplace
  COOSort_(&coo);
  ASSERT_TRUE(coo.row_sorted);
  flags = COOIsSorted(coo);
  ASSERT_TRUE(flags.first);
  COOSort_(&coo, true);
  ASSERT_TRUE(coo.row_sorted);
  ASSERT_TRUE(coo.col_sorted);
  flags = COOIsSorted(coo);
  ASSERT_TRUE(flags.first);
  ASSERT_TRUE(flags.second);
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350

  // COO3
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 1, 2, 3, 4, 5]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  // Row Sorted
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 0, 0, 1, 2, 2]
  // col : [2, 1, 2, 0, 2, 3]
  // Row Col Sorted
  // data: [2, 0, 5, 3, 1, 4]
  // row : [0, 0, 0, 1, 2, 2]
  // col : [1, 2, 2, 0, 2, 3]
  auto sort_row = aten::VecToIdArray(
351
    std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
352
  auto sort_col = aten::VecToIdArray(
353
    std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx);
354
  auto sort_col_data = aten::VecToIdArray(
355
    std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, ctx);
356
357
358
359
360
361
362

  ASSERT_TRUE(ArrayEQ<IDX>(sr_coo.row, sort_row));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.row, sort_row));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.col, sort_col));
  ASSERT_TRUE(ArrayEQ<IDX>(src_coo.data, sort_col_data));
}

363
// Runs the COO sorting checks for 32- and 64-bit ids on CPU, and also on GPU
// when built with DGL_USE_CUDA.
TEST(SpmatTest, COOSort) {
  _TestCOOSort<int32_t>(CPU);
  _TestCOOSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCOOSort<int32_t>(GPU);
  _TestCOOSort<int64_t>(GPU);
#endif
}
Da Zheng's avatar
Da Zheng committed
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387

template <typename IDX>
void _TestCOOReorder() {
  // new_row / new_col are permutations of the row ids [0, 4) and column ids
  // [0, 5) of COO2. Only the shape of the reordered matrix is checked here.
  const auto nbits = sizeof(IDX) * 8;
  auto coo = COO2<IDX>();
  auto new_row =
      aten::VecToIdArray(std::vector<IDX>({2, 0, 3, 1}), nbits, CTX);
  auto new_col =
      aten::VecToIdArray(std::vector<IDX>({2, 0, 4, 3, 1}), nbits, CTX);

  auto reordered = COOReorder(coo, new_row, new_col);
  ASSERT_EQ(reordered.num_rows, coo.num_rows);
  ASSERT_EQ(reordered.num_cols, coo.num_cols);
}

// Reorder shape checks for 32- and 64-bit ids; the fixtures use the default CTX.
TEST(SpmatTest, TestCOOReorder) {
  _TestCOOReorder<int32_t>();  // 32-bit ids
  _TestCOOReorder<int64_t>();  // 64-bit ids
}
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450

template <typename IDX>
void _TestCOOGetData(DLContext ctx) {
  // Shorthand: wrap a list of ids as an IdArray of IDX on `ctx`.
  const auto to_ids = [&ctx](const std::vector<IDX> &vals) {
    return aten::VecToIdArray(vals, sizeof(IDX) * 8, ctx);
  };

  auto coo = COO2<IDX>(ctx);

  // COOGetAllData: (0, 0) has no entry; (0, 2) is duplicated with ids 2 and 5.
  auto got = aten::COOGetAllData(coo, 0, 0);
  auto want = to_ids({});
  ASSERT_TRUE(ArrayEQ<IDX>(got, want));
  got = aten::COOGetAllData(coo, 0, 2);
  want = to_ids({2, 5});
  ASSERT_TRUE(ArrayEQ<IDX>(got, want));

  // COOGetData: one id per (row, col) query; -1 marks a missing entry.
  auto rows = to_ids({0, 0, 0});
  auto cols = to_ids({0, 1, 2});
  got = aten::COOGetData(coo, rows, cols);
  want = to_ids({-1, 0, 2});
  ASSERT_TRUE(ArrayEQ<IDX>(got, want));

  // The same queries against a sorted copy (COOSort with default arguments)
  // must give the same answer.
  coo = aten::COOSort(coo);
  rows = to_ids({0, 0, 0});
  cols = to_ids({0, 1, 2});
  got = aten::COOGetData(coo, rows, cols);
  want = to_ids({-1, 0, 2});
  ASSERT_TRUE(ArrayEQ<IDX>(got, want));

  // A single row id is broadcast against every column id.
  rows = to_ids({0});
  cols = to_ids({0, 1, 2});
  got = aten::COOGetData(coo, rows, cols);
  want = to_ids({-1, 0, 2});
  ASSERT_TRUE(ArrayEQ<IDX>(got, want));
}

// COOGetData checks for 32- and 64-bit ids. Only the CPU path runs; the GPU
// calls are commented out.
// TODO(review): confirm why the GPU variant is disabled before re-enabling it.
TEST(SpmatTest, COOGetData) {
  _TestCOOGetData<int32_t>(CPU);  // 32-bit ids
  _TestCOOGetData<int64_t>(CPU);  // 64-bit ids
  // #ifdef DGL_USE_CUDA
  //   _TestCOOGetData<int32_t>(GPU);
  //   _TestCOOGetData<int64_t>(GPU);
  // #endif
}

template <typename IDX>
void _TestCOOGetDataAndIndices() {
  // Shorthand: wrap a list of ids as an IdArray of IDX on the default CTX.
  const auto to_ids = [](const std::vector<IDX> &vals) {
    return aten::VecToIdArray(vals, sizeof(IDX) * 8, CTX);
  };

  // Query row 0 against columns 0, 1, 2 of COO2. (0, 0) has no entry and
  // (0, 2) is duplicated, so three matches are expected:
  // (0,1)->0, (0,2)->2, (0,2)->5.
  auto coo = COO2<IDX>();
  auto query_rows = to_ids({0, 0, 0});
  auto query_cols = to_ids({0, 1, 2});
  auto result = aten::COOGetDataAndIndices(coo, query_rows, query_cols);

  // result[0] / [1] / [2] = matched rows / matched cols / matched data ids.
  auto want_rows = to_ids({0, 0, 0});
  auto want_cols = to_ids({1, 2, 2});
  auto want_data = to_ids({0, 2, 5});
  ASSERT_TRUE(ArrayEQ<IDX>(result[0], want_rows));
  ASSERT_TRUE(ArrayEQ<IDX>(result[1], want_cols));
  ASSERT_TRUE(ArrayEQ<IDX>(result[2], want_data));
}

// COOGetDataAndIndices checks for 32- and 64-bit ids; the fixtures use the
// default CTX.
TEST(SpmatTest, COOGetDataAndIndices) {
  _TestCOOGetDataAndIndices<int32_t>();  // 32-bit ids
  _TestCOOGetDataAndIndices<int64_t>();  // 64-bit ids
}