test_spmat_csr.cc 16.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
#include <gtest/gtest.h>
#include <dgl/array.h>
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

namespace {

template <typename IDX>
11
aten::CSRMatrix CSR1(DLContext ctx = CTX) {
12
13
14
15
16
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
17
18
  return aten::CSRMatrix(
      4, 5,
19
      aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx),
20
21
      aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 3, 2}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX)*8, ctx),
22
      false);
23
24
25
}

template <typename IDX>
26
aten::CSRMatrix CSR2(DLContext ctx = CTX) {
27
28
29
30
31
32
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
33
34
  return aten::CSRMatrix(
      4, 5,
35
36
37
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
38
      false);
39
40
41
}

template <typename IDX>
42
aten::COOMatrix COO1(DLContext ctx = CTX) {
43
44
45
46
47
48
49
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
  // row : [0, 2, 0, 1, 2]
  // col : [1, 2, 2, 0, 3]
50
51
  return aten::COOMatrix(
      4, 5,
52
53
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX)*8, ctx));
54
55
56
}

template <typename IDX>
57
aten::COOMatrix COO2(DLContext ctx = CTX) {
58
59
60
61
62
63
64
65
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [1, 2, 2, 0, 3, 2]
66
67
  return aten::COOMatrix(
      4, 5,
68
69
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX)*8, ctx));
70
71
}

72
template <typename IDX>
73
aten::CSRMatrix SR_CSR3(DLContext ctx) {
74
75
76
77
78
79
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
80
81
82
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
83
84
85
86
      false);
}

template <typename IDX>
87
aten::CSRMatrix SRC_CSR3(DLContext ctx) {
88
89
90
91
92
93
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
94
95
96
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
97
98
99
100
      false);
}

template <typename IDX>
101
aten::COOMatrix COO3(DLContext ctx) {
102
103
104
105
106
107
108
109
110
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  return aten::COOMatrix(
      4, 5,
111
112
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX)*8, ctx));
113
114
}

115
}  // namespace
116
117

template <typename IDX>
118
119
void _TestCSRIsNonZero(DLContext ctx) {
  auto csr = CSR1<IDX>(ctx);
120
121
  ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));
  ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));
122
123
  IdArray r = aten::VecToIdArray(std::vector<IDX>({2, 2, 0, 0}), sizeof(IDX)*8, ctx);
  IdArray c = aten::VecToIdArray(std::vector<IDX>({1, 1, 1, 3}), sizeof(IDX)*8, ctx);
124
  IdArray x = aten::CSRIsNonZero(csr, r, c);
125
  IdArray tx = aten::VecToIdArray(std::vector<IDX>({0, 0, 1, 0}), sizeof(IDX)*8, ctx);
126
127
128
129
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRIsNonZero) {
130
131
132
133
134
135
  _TestCSRIsNonZero<int32_t>(CPU);
  _TestCSRIsNonZero<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRIsNonZero<int32_t>(GPU);
  _TestCSRIsNonZero<int64_t>(GPU);
#endif
136
137
138
}

template <typename IDX>
139
140
void _TestCSRGetRowNNZ(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
141
142
  ASSERT_EQ(aten::CSRGetRowNNZ(csr, 0), 3);
  ASSERT_EQ(aten::CSRGetRowNNZ(csr, 3), 0);
143
  IdArray r = aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX)*8, ctx);
144
  IdArray x = aten::CSRGetRowNNZ(csr, r);
145
  IdArray tx = aten::VecToIdArray(std::vector<IDX>({3, 0}), sizeof(IDX)*8, ctx);
146
147
148
149
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRGetRowNNZ) {
150
151
152
153
154
155
  _TestCSRGetRowNNZ<int32_t>(CPU);
  _TestCSRGetRowNNZ<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetRowNNZ<int32_t>(GPU);
  _TestCSRGetRowNNZ<int64_t>(GPU);
#endif
156
157
158
}

template <typename IDX>
159
160
void _TestCSRGetRowColumnIndices(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
161
  auto x = aten::CSRGetRowColumnIndices(csr, 0);
162
  auto tx = aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX)*8, ctx);
163
164
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowColumnIndices(csr, 1);
165
  tx = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX)*8, ctx);
166
167
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowColumnIndices(csr, 3);
168
  tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
169
170
171
172
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRGetRowColumnIndices) {
173
174
175
176
177
178
  _TestCSRGetRowColumnIndices<int32_t>(CPU);
  _TestCSRGetRowColumnIndices<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetRowColumnIndices<int32_t>(GPU);
  _TestCSRGetRowColumnIndices<int64_t>(GPU);
#endif
179
180
181
}

template <typename IDX>
182
183
void _TestCSRGetRowData(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
184
  auto x = aten::CSRGetRowData(csr, 0);
185
  auto tx = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, ctx);
186
187
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowData(csr, 1);
188
  tx = aten::VecToIdArray(std::vector<IDX>({3}), sizeof(IDX)*8, ctx);
189
190
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowData(csr, 3);
191
  tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
192
193
194
195
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRGetRowData) {
196
197
198
199
200
201
  _TestCSRGetRowData<int32_t>(CPU);
  _TestCSRGetRowData<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetRowData<int32_t>(GPU);
  _TestCSRGetRowData<int64_t>(GPU);
#endif
202
203
204
}

template <typename IDX>
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
void _TestCSRGetData(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
  // test get all data
  auto x = aten::CSRGetAllData(csr, 0, 0);
  auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetAllData(csr, 0, 2);
  tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
  x = aten::CSRGetData(csr, r, c);
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX)*8, ctx);
220
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
221
222
223
224
225
226
227

  // test get data on sorted
  csr = aten::CSRSort(csr);
  r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
  x = aten::CSRGetData(csr, r, c);
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX)*8, ctx);
228
229
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

230
231
232
  // test get data w/ broadcasting
  r = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX)*8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
233
  x = aten::CSRGetData(csr, r, c);
234
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX)*8, ctx);
235
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
236

237
238
}

239
240
241
242
243
TEST(SpmatTest, CSRGetData) {
  _TestCSRGetData<int32_t>(CPU);
  _TestCSRGetData<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetData<int32_t>(GPU);
244
  _TestCSRGetData<int64_t>(GPU);
245
#endif
246
247
248
}

template <typename IDX>
249
250
251
252
void _TestCSRGetDataAndIndices(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
253
  auto x = aten::CSRGetDataAndIndices(csr, r, c);
254
255
256
  auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, ctx);
257
258
259
260
261
  ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));
  ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));
  ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));
}

262
263
264
265
266
267
268
TEST(SpmatTest, CSRGetDataAndIndices) {
  _TestCSRGetDataAndIndices<int32_t>(CPU);
  _TestCSRGetDataAndIndices<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetDataAndIndices<int32_t>(GPU);
  _TestCSRGetDataAndIndices<int64_t>(GPU);
#endif
269
270
271
}

template <typename IDX>
272
273
void _TestCSRTranspose(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
274
275
276
277
278
279
280
281
282
  auto csr_t = aten::CSRTranspose(csr);
  // [[0, 1, 0, 0],
  //  [1, 0, 0, 0],
  //  [2, 0, 1, 0],
  //  [0, 0, 1, 0],
  //  [0, 0, 0, 0]]
  // data: [3, 0, 2, 5, 1, 4]
  ASSERT_EQ(csr_t.num_rows, 5);
  ASSERT_EQ(csr_t.num_cols, 4);
283
284
285
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 5, 6, 6}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0, 0, 0, 2, 2}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX)*8, ctx);
286
287
288
289
290
  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.data, td));
}

291
TEST(SpmatTest, CSRTranspose) {
292
293
294
295
  _TestCSRTranspose<int32_t>(CPU);
  _TestCSRTranspose<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRTranspose<int32_t>(GPU);
296
  _TestCSRTranspose<int64_t>(GPU);
297
#endif
298
299
300
}

template <typename IDX>
301
302
void _TestCSRToCOO(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
303
304
305
306
  {
  auto coo = CSRToCOO(csr, false);
  ASSERT_EQ(coo.num_rows, 4);
  ASSERT_EQ(coo.num_cols, 5);
307
  ASSERT_TRUE(coo.row_sorted);
308
  auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
309
  ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
310
311
312
313
314
315
316
317
318
319
320
321
322
323
  ASSERT_TRUE(ArrayEQ<IDX>(coo.col, csr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.data, csr.data));

  // convert from sorted csr
  auto s_csr = CSRSort(csr);
  coo = CSRToCOO(s_csr, false);
  ASSERT_EQ(coo.num_rows, 4);
  ASSERT_EQ(coo.num_cols, 5);
  ASSERT_TRUE(coo.row_sorted);
  ASSERT_TRUE(coo.col_sorted);
  tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.col, s_csr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.data, s_csr.data));
324
325
326
327
328
  }
  {
  auto coo = CSRToCOO(csr, true);
  ASSERT_EQ(coo.num_rows, 4);
  ASSERT_EQ(coo.num_cols, 5);
329
  auto tcoo = COO2<IDX>(ctx);
330
331
332
333
334
  ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tcoo.row));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tcoo.col));
  }
}

335
TEST(SpmatTest, CSRToCOO) {
336
337
338
339
  _TestCSRToCOO<int32_t>(CPU);
  _TestCSRToCOO<int64_t>(CPU);
#if DGL_USE_CUDA
  _TestCSRToCOO<int32_t>(GPU);
340
  _TestCSRToCOO<int64_t>(GPU);
341
#endif
342
343
344
}

template <typename IDX>
345
346
void _TestCSRSliceRows(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
347
348
349
350
351
352
353
  auto x = aten::CSRSliceRows(csr, 1, 4);
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [3, 1, 4]
  ASSERT_EQ(x.num_rows, 3);
  ASSERT_EQ(x.num_cols, 5);
354
355
356
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, ctx);
357
358
359
360
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));

361
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, ctx);
362
363
364
365
366
  x = aten::CSRSliceRows(csr, r);
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3]
367
368
369
  tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, ctx);
  ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, ctx);
  td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, ctx);
370
371
372
373
374
375
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
}

TEST(SpmatTest, TestCSRSliceRows) {
376
377
378
379
380
381
  _TestCSRSliceRows<int32_t>(CPU);
  _TestCSRSliceRows<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRSliceRows<int32_t>(GPU);
  _TestCSRSliceRows<int64_t>(GPU);
#endif
382
383
384
}

template <typename IDX>
385
386
387
388
389
390
void _TestCSRSliceMatrix(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
  {
  // square
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({1, 2, 3}), sizeof(IDX)*8, ctx);
391
392
393
394
395
396
397
  auto x = aten::CSRSliceMatrix(csr, r, c);
  // [[1, 2, 0],
  //  [0, 0, 0],
  //  [0, 0, 0]]
  // data: [0, 2, 5]
  ASSERT_EQ(x.num_rows, 3);
  ASSERT_EQ(x.num_cols, 3);
398
399
400
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 3, 3}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({0, 1, 1}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, ctx);
401
402
403
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
  }
  {
  // non-square
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX)*8, ctx);
  auto x = aten::CSRSliceMatrix(csr, r, c);
  // [[0, 1],
  //  [1, 0],
  //  [0, 0]]
  // data: [0, 3]
  ASSERT_EQ(x.num_rows, 3);
  ASSERT_EQ(x.num_cols, 2);
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 2}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }
  {
  // empty slice
  auto r = aten::VecToIdArray(std::vector<IDX>({2, 3}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX)*8, ctx);
  auto x = aten::CSRSliceMatrix(csr, r, c);
  // [[0, 0],
  //  [0, 0]]
  // data: []
  ASSERT_EQ(x.num_rows, 2);
  ASSERT_EQ(x.num_cols, 2);
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }
440
441
}

442
443
444
445
446
TEST(SpmatTest, CSRSliceMatrix) {
  _TestCSRSliceMatrix<int32_t>(CPU);
  _TestCSRSliceMatrix<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRSliceMatrix<int32_t>(GPU);
447
  _TestCSRSliceMatrix<int64_t>(GPU);
448
#endif
449
450
451
}

template <typename IDX>
452
453
void _TestCSRHasDuplicate(DLContext ctx) {
  auto csr = CSR1<IDX>(ctx);
454
  ASSERT_FALSE(aten::CSRHasDuplicate(csr));
455
  csr = CSR2<IDX>(ctx);
456
457
458
  ASSERT_TRUE(aten::CSRHasDuplicate(csr));
}

459
460
461
462
463
TEST(SpmatTest, CSRHasDuplicate) {
  _TestCSRHasDuplicate<int32_t>(CPU);
  _TestCSRHasDuplicate<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRHasDuplicate<int32_t>(GPU);
464
  _TestCSRHasDuplicate<int64_t>(GPU);
465
#endif
466
467
}

468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
template <typename IDX>
void _TestCSRSort(DLContext ctx) {
  auto csr = CSR1<IDX>(ctx);
  ASSERT_FALSE(aten::CSRIsSorted(csr));
  auto csr1 = aten::CSRSort(csr);
  ASSERT_FALSE(aten::CSRIsSorted(csr));
  ASSERT_TRUE(aten::CSRIsSorted(csr1));
  ASSERT_TRUE(csr1.sorted);
  aten::CSRSort_(&csr);
  ASSERT_TRUE(aten::CSRIsSorted(csr));
  ASSERT_TRUE(csr.sorted);
  csr = CSR2<IDX>(ctx);
  ASSERT_TRUE(aten::CSRIsSorted(csr));
}

TEST(SpmatTest, CSRSort) {
  _TestCSRSort<int32_t>(CPU);
  _TestCSRSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRSort<int32_t>(GPU);
488
  _TestCSRSort<int64_t>(GPU);
489
490
491
#endif
}

Da Zheng's avatar
Da Zheng committed
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
template <typename IDX>
void _TestCSRReorder() {
  auto csr = CSR2<IDX>();
  auto new_row = aten::VecToIdArray(
    std::vector<IDX>({2, 0, 3, 1}), sizeof(IDX)*8, CTX);
  auto new_col = aten::VecToIdArray(
    std::vector<IDX>({2, 0, 4, 3, 1}), sizeof(IDX)*8, CTX);
  auto new_csr = CSRReorder(csr, new_row, new_col);
  ASSERT_EQ(new_csr.num_rows, csr.num_rows);
  ASSERT_EQ(new_csr.num_cols, csr.num_cols);
}

TEST(SpmatTest, TestCSRReorder) {
  _TestCSRReorder<int32_t>();
  _TestCSRReorder<int64_t>();
}