test_spmat_csr.cc 16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
#include <gtest/gtest.h>
#include <dgl/array.h>
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

namespace {

template <typename IDX>
11
aten::CSRMatrix CSR1(DLContext ctx = CTX) {
12
13
14
15
16
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
17
18
  return aten::CSRMatrix(
      4, 5,
19
      aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, ctx),
20
21
      aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 3, 2}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 4, 1}), sizeof(IDX)*8, ctx),
22
      false);
23
24
25
}

template <typename IDX>
26
aten::CSRMatrix CSR2(DLContext ctx = CTX) {
27
28
29
30
31
32
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
33
34
  return aten::CSRMatrix(
      4, 5,
35
36
37
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
38
      false);
39
40
41
}

template <typename IDX>
42
aten::COOMatrix COO1(DLContext ctx = CTX) {
43
44
45
46
47
48
49
  // [[0, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 3, 1, 4]
  // row : [0, 2, 0, 1, 2]
  // col : [1, 2, 2, 0, 3]
50
51
  return aten::COOMatrix(
      4, 5,
52
53
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3}), sizeof(IDX)*8, ctx));
54
55
56
}

template <typename IDX>
57
aten::COOMatrix COO2(DLContext ctx = CTX) {
58
59
60
61
62
63
64
65
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3, 1, 4]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [1, 2, 2, 0, 3, 2]
66
67
  return aten::COOMatrix(
      4, 5,
68
69
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 3, 2}), sizeof(IDX)*8, ctx));
70
71
}

72
template <typename IDX>
73
aten::CSRMatrix SR_CSR3(DLContext ctx) {
74
75
76
77
78
79
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
80
81
82
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 1, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
83
84
85
86
      false);
}

template <typename IDX>
87
aten::CSRMatrix SRC_CSR3(DLContext ctx) {
88
89
90
91
92
93
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  return aten::CSRMatrix(
      4, 5,
94
95
96
      aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 0, 5, 3, 1, 4}), sizeof(IDX)*8, ctx),
97
98
99
100
      false);
}

template <typename IDX>
101
aten::COOMatrix COO3(DLContext ctx) {
102
103
104
105
106
107
108
109
110
  // has duplicate entries
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // row : [0, 2, 0, 1, 2, 0]
  // col : [2, 2, 1, 0, 3, 2]
  return aten::COOMatrix(
      4, 5,
111
112
      aten::VecToIdArray(std::vector<IDX>({0, 2, 0, 1, 2, 0}), sizeof(IDX)*8, ctx),
      aten::VecToIdArray(std::vector<IDX>({2, 2, 1, 0, 3, 2}), sizeof(IDX)*8, ctx));
113
114
}

115
}  // namespace
116
117

template <typename IDX>
118
119
void _TestCSRIsNonZero(DLContext ctx) {
  auto csr = CSR1<IDX>(ctx);
120
121
  ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));
  ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));
122
123
  IdArray r = aten::VecToIdArray(std::vector<IDX>({2, 2, 0, 0}), sizeof(IDX)*8, ctx);
  IdArray c = aten::VecToIdArray(std::vector<IDX>({1, 1, 1, 3}), sizeof(IDX)*8, ctx);
124
  IdArray x = aten::CSRIsNonZero(csr, r, c);
125
  IdArray tx = aten::VecToIdArray(std::vector<IDX>({0, 0, 1, 0}), sizeof(IDX)*8, ctx);
126
127
128
129
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRIsNonZero) {
130
131
132
133
134
135
  _TestCSRIsNonZero<int32_t>(CPU);
  _TestCSRIsNonZero<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRIsNonZero<int32_t>(GPU);
  _TestCSRIsNonZero<int64_t>(GPU);
#endif
136
137
138
}

template <typename IDX>
139
140
void _TestCSRGetRowNNZ(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
141
142
  ASSERT_EQ(aten::CSRGetRowNNZ(csr, 0), 3);
  ASSERT_EQ(aten::CSRGetRowNNZ(csr, 3), 0);
143
  IdArray r = aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX)*8, ctx);
144
  IdArray x = aten::CSRGetRowNNZ(csr, r);
145
  IdArray tx = aten::VecToIdArray(std::vector<IDX>({3, 0}), sizeof(IDX)*8, ctx);
146
147
148
149
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRGetRowNNZ) {
150
151
152
153
154
155
  _TestCSRGetRowNNZ<int32_t>(CPU);
  _TestCSRGetRowNNZ<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetRowNNZ<int32_t>(GPU);
  _TestCSRGetRowNNZ<int64_t>(GPU);
#endif
156
157
158
}

template <typename IDX>
159
160
void _TestCSRGetRowColumnIndices(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
161
  auto x = aten::CSRGetRowColumnIndices(csr, 0);
162
  auto tx = aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX)*8, ctx);
163
164
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowColumnIndices(csr, 1);
165
  tx = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX)*8, ctx);
166
167
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowColumnIndices(csr, 3);
168
  tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
169
170
171
172
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRGetRowColumnIndices) {
173
174
175
176
177
178
  _TestCSRGetRowColumnIndices<int32_t>(CPU);
  _TestCSRGetRowColumnIndices<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetRowColumnIndices<int32_t>(GPU);
  _TestCSRGetRowColumnIndices<int64_t>(GPU);
#endif
179
180
181
}

template <typename IDX>
182
183
void _TestCSRGetRowData(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
184
  auto x = aten::CSRGetRowData(csr, 0);
185
  auto tx = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, ctx);
186
187
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowData(csr, 1);
188
  tx = aten::VecToIdArray(std::vector<IDX>({3}), sizeof(IDX)*8, ctx);
189
190
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetRowData(csr, 3);
191
  tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
192
193
194
195
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
}

TEST(SpmatTest, TestCSRGetRowData) {
196
197
198
199
200
201
  _TestCSRGetRowData<int32_t>(CPU);
  _TestCSRGetRowData<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetRowData<int32_t>(GPU);
  _TestCSRGetRowData<int64_t>(GPU);
#endif
202
203
204
}

template <typename IDX>
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
void _TestCSRGetData(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
  // test get all data
  auto x = aten::CSRGetAllData(csr, 0, 0);
  auto tx = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
  x = aten::CSRGetAllData(csr, 0, 2);
  tx = aten::VecToIdArray(std::vector<IDX>({2, 5}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

  // test get data
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
  x = aten::CSRGetData(csr, r, c);
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX)*8, ctx);
220
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
221
222
223
224
225
226
227

  // test get data on sorted
  csr = aten::CSRSort(csr);
  r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
  x = aten::CSRGetData(csr, r, c);
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX)*8, ctx);
228
229
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));

230
231
232
  // test get data w/ broadcasting
  r = aten::VecToIdArray(std::vector<IDX>({0}), sizeof(IDX)*8, ctx);
  c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
233
  x = aten::CSRGetData(csr, r, c);
234
  tx = aten::VecToIdArray(std::vector<IDX>({-1, 0, 2}), sizeof(IDX)*8, ctx);
235
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
236

237
238
}

239
240
241
242
243
244
TEST(SpmatTest, CSRGetData) {
  _TestCSRGetData<int32_t>(CPU);
  _TestCSRGetData<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetData<int32_t>(GPU);
#endif
245
246
247
}

template <typename IDX>
248
249
250
251
void _TestCSRGetDataAndIndices(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
252
  auto x = aten::CSRGetDataAndIndices(csr, r, c);
253
254
255
  auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto tc = aten::VecToIdArray(std::vector<IDX>({1, 2, 2}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, ctx);
256
257
258
259
260
  ASSERT_TRUE(ArrayEQ<IDX>(x[0], tr));
  ASSERT_TRUE(ArrayEQ<IDX>(x[1], tc));
  ASSERT_TRUE(ArrayEQ<IDX>(x[2], td));
}

261
262
263
264
265
266
267
TEST(SpmatTest, CSRGetDataAndIndices) {
  _TestCSRGetDataAndIndices<int32_t>(CPU);
  _TestCSRGetDataAndIndices<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRGetDataAndIndices<int32_t>(GPU);
  _TestCSRGetDataAndIndices<int64_t>(GPU);
#endif
268
269
270
}

template <typename IDX>
271
272
void _TestCSRTranspose(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
273
274
275
276
277
278
279
280
281
  auto csr_t = aten::CSRTranspose(csr);
  // [[0, 1, 0, 0],
  //  [1, 0, 0, 0],
  //  [2, 0, 1, 0],
  //  [0, 0, 1, 0],
  //  [0, 0, 0, 0]]
  // data: [3, 0, 2, 5, 1, 4]
  ASSERT_EQ(csr_t.num_rows, 5);
  ASSERT_EQ(csr_t.num_cols, 4);
282
283
284
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 5, 6, 6}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0, 0, 0, 2, 2}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({3, 0, 2, 5, 1, 4}), sizeof(IDX)*8, ctx);
285
286
287
288
289
290
  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(csr_t.data, td));
}

TEST(SpmatTest, TestCSRTranspose) {
291
292
293
294
295
  _TestCSRTranspose<int32_t>(CPU);
  _TestCSRTranspose<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRTranspose<int32_t>(GPU);
#endif
296
297
298
}

template <typename IDX>
299
300
void _TestCSRToCOO(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
301
302
303
304
  {
  auto coo = CSRToCOO(csr, false);
  ASSERT_EQ(coo.num_rows, 4);
  ASSERT_EQ(coo.num_cols, 5);
305
  ASSERT_TRUE(coo.row_sorted);
306
  auto tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
307
  ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
308
309
310
311
312
313
314
315
316
317
318
319
320
321
  ASSERT_TRUE(ArrayEQ<IDX>(coo.col, csr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.data, csr.data));

  // convert from sorted csr
  auto s_csr = CSRSort(csr);
  coo = CSRToCOO(s_csr, false);
  ASSERT_EQ(coo.num_rows, 4);
  ASSERT_EQ(coo.num_cols, 5);
  ASSERT_TRUE(coo.row_sorted);
  ASSERT_TRUE(coo.col_sorted);
  tr = aten::VecToIdArray(std::vector<IDX>({0, 0, 0, 1, 2, 2}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tr));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.col, s_csr.indices));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.data, s_csr.data));
322
323
324
325
326
  }
  {
  auto coo = CSRToCOO(csr, true);
  ASSERT_EQ(coo.num_rows, 4);
  ASSERT_EQ(coo.num_cols, 5);
327
  auto tcoo = COO2<IDX>(ctx);
328
329
330
331
332
  ASSERT_TRUE(ArrayEQ<IDX>(coo.row, tcoo.row));
  ASSERT_TRUE(ArrayEQ<IDX>(coo.col, tcoo.col));
  }
}

333
TEST(SpmatTest, CSRToCOO) {
334
335
336
337
338
  _TestCSRToCOO<int32_t>(CPU);
  _TestCSRToCOO<int64_t>(CPU);
#if DGL_USE_CUDA
  _TestCSRToCOO<int32_t>(GPU);
#endif
339
340
341
}

template <typename IDX>
342
343
void _TestCSRSliceRows(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
344
345
346
347
348
349
350
  auto x = aten::CSRSliceRows(csr, 1, 4);
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [3, 1, 4]
  ASSERT_EQ(x.num_rows, 3);
  ASSERT_EQ(x.num_cols, 5);
351
352
353
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 3, 3}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({0, 2, 3}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({3, 1, 4}), sizeof(IDX)*8, ctx);
354
355
356
357
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));

358
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, ctx);
359
360
361
362
363
  x = aten::CSRSliceRows(csr, r);
  // [[0, 1, 2, 0, 0],
  //  [1, 0, 0, 0, 0],
  //  [0, 0, 0, 0, 0]]
  // data: [0, 2, 5, 3]
364
365
366
  tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 4}), sizeof(IDX)*8, ctx);
  ti = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0}), sizeof(IDX)*8, ctx);
  td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3}), sizeof(IDX)*8, ctx);
367
368
369
370
371
372
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
}

TEST(SpmatTest, TestCSRSliceRows) {
373
374
375
376
377
378
  _TestCSRSliceRows<int32_t>(CPU);
  _TestCSRSliceRows<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRSliceRows<int32_t>(GPU);
  _TestCSRSliceRows<int64_t>(GPU);
#endif
379
380
381
}

template <typename IDX>
382
383
384
385
386
387
void _TestCSRSliceMatrix(DLContext ctx) {
  auto csr = CSR2<IDX>(ctx);
  {
  // square
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 3}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({1, 2, 3}), sizeof(IDX)*8, ctx);
388
389
390
391
392
393
394
  auto x = aten::CSRSliceMatrix(csr, r, c);
  // [[1, 2, 0],
  //  [0, 0, 0],
  //  [0, 0, 0]]
  // data: [0, 2, 5]
  ASSERT_EQ(x.num_rows, 3);
  ASSERT_EQ(x.num_cols, 3);
395
396
397
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 3, 3, 3}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({0, 1, 1}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({0, 2, 5}), sizeof(IDX)*8, ctx);
398
399
400
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
  }
  {
  // non-square
  auto r = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX)*8, ctx);
  auto x = aten::CSRSliceMatrix(csr, r, c);
  // [[0, 1],
  //  [1, 0],
  //  [0, 0]]
  // data: [0, 3]
  ASSERT_EQ(x.num_rows, 3);
  ASSERT_EQ(x.num_cols, 2);
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 1, 2, 2}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({1, 0}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({0, 3}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }
  {
  // empty slice
  auto r = aten::VecToIdArray(std::vector<IDX>({2, 3}), sizeof(IDX)*8, ctx);
  auto c = aten::VecToIdArray(std::vector<IDX>({0, 1}), sizeof(IDX)*8, ctx);
  auto x = aten::CSRSliceMatrix(csr, r, c);
  // [[0, 0],
  //  [0, 0]]
  // data: []
  ASSERT_EQ(x.num_rows, 2);
  ASSERT_EQ(x.num_cols, 2);
  auto tp = aten::VecToIdArray(std::vector<IDX>({0, 0, 0}), sizeof(IDX)*8, ctx);
  auto ti = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
  auto td = aten::VecToIdArray(std::vector<IDX>({}), sizeof(IDX)*8, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
  ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
  ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }
437
438
}

439
440
441
442
443
444
TEST(SpmatTest, CSRSliceMatrix) {
  _TestCSRSliceMatrix<int32_t>(CPU);
  _TestCSRSliceMatrix<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRSliceMatrix<int32_t>(GPU);
#endif
445
446
447
}

template <typename IDX>
448
449
void _TestCSRHasDuplicate(DLContext ctx) {
  auto csr = CSR1<IDX>(ctx);
450
  ASSERT_FALSE(aten::CSRHasDuplicate(csr));
451
  csr = CSR2<IDX>(ctx);
452
453
454
  ASSERT_TRUE(aten::CSRHasDuplicate(csr));
}

455
456
457
458
459
460
TEST(SpmatTest, CSRHasDuplicate) {
  _TestCSRHasDuplicate<int32_t>(CPU);
  _TestCSRHasDuplicate<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRHasDuplicate<int32_t>(GPU);
#endif
461
462
}

463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
template <typename IDX>
void _TestCSRSort(DLContext ctx) {
  auto csr = CSR1<IDX>(ctx);
  ASSERT_FALSE(aten::CSRIsSorted(csr));
  auto csr1 = aten::CSRSort(csr);
  ASSERT_FALSE(aten::CSRIsSorted(csr));
  ASSERT_TRUE(aten::CSRIsSorted(csr1));
  ASSERT_TRUE(csr1.sorted);
  aten::CSRSort_(&csr);
  ASSERT_TRUE(aten::CSRIsSorted(csr));
  ASSERT_TRUE(csr.sorted);
  csr = CSR2<IDX>(ctx);
  ASSERT_TRUE(aten::CSRIsSorted(csr));
}

TEST(SpmatTest, CSRSort) {
  _TestCSRSort<int32_t>(CPU);
  _TestCSRSort<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  _TestCSRSort<int32_t>(GPU);
#endif
}

Da Zheng's avatar
Da Zheng committed
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
template <typename IDX>
void _TestCSRReorder() {
  auto csr = CSR2<IDX>();
  auto new_row = aten::VecToIdArray(
    std::vector<IDX>({2, 0, 3, 1}), sizeof(IDX)*8, CTX);
  auto new_col = aten::VecToIdArray(
    std::vector<IDX>({2, 0, 4, 3, 1}), sizeof(IDX)*8, CTX);
  auto new_csr = CSRReorder(csr, new_row, new_col);
  ASSERT_EQ(new_csr.num_rows, csr.num_rows);
  ASSERT_EQ(new_csr.num_cols, csr.num_cols);
}

TEST(SpmatTest, TestCSRReorder) {
  _TestCSRReorder<int32_t>();
  _TestCSRReorder<int64_t>();
}