// test_csrmm.cc
#include <dgl/array.h>
#include <dgl/kernel.h>
#include <gtest/gtest.h>

#include <cmath>

#include "../../src/array/cpu/array_utils.h"  // PairHash
#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

namespace {

// Unit tests:
// CSRMM(A, B) == A_mm_B
// CSRSum({A, C}) == A_plus_C
// CSRMask(A, C) == A_mask_C (checked via CSRGetData at C's nonzero positions)

// Converts a weighted COO matrix into a (row, col) -> value lookup table.
//
// Entry i contributes the key (coo.row[i], coo.col[i]). Its value is
// weights[coo.data[i]] when the COO matrix carries a data array and
// weights[i] otherwise.
//
// If the same (row, col) pair occurs more than once only the first occurrence
// is kept, because unordered_map::insert never overwrites an existing key.
template <typename IdType, typename DType>
std::unordered_map<std::pair<IdType, IdType>, DType, aten::PairHash> COOToMap(
    aten::COOMatrix coo, NDArray weights) {
  std::unordered_map<std::pair<IdType, IdType>, DType, aten::PairHash> map;

  // Both are loop invariants; query them once instead of on every iteration.
  const int64_t num_entries = coo.row->shape[0];
  const bool has_data = aten::COOHasData(coo);

  for (int64_t i = 0; i < num_entries; ++i) {
    const IdType irow = aten::IndexSelect<IdType>(coo.row, i);
    const IdType icol = aten::IndexSelect<IdType>(coo.col, i);
    // Without a data array the position in the COO arrays is the entry ID.
    const IdType ieid = has_data ? aten::IndexSelect<IdType>(coo.data, i)
                                 : static_cast<IdType>(i);
    const DType idata = aten::IndexSelect<DType>(weights, ieid);
    map.insert({{irow, icol}, idata});
  }
  return map;
}

// Returns true when the weighted CSR matrices (A, A_weights) and
// (B, B_weights) store the same set of (row, col) positions and every pair of
// corresponding values satisfies
//
//     |a - b| <= rtol * |a| + atol
//
// where a is the value taken from A and b the value taken from B.
//
// Explicitly stored zeros count as entries, so a stored 0 in one matrix and a
// missing entry in the other makes the matrices differ.
template <typename IdType, typename DType>
bool CSRIsClose(
    aten::CSRMatrix A, aten::CSRMatrix B, NDArray A_weights, NDArray B_weights,
    DType rtol, DType atol) {
  const auto Amap = COOToMap<IdType, DType>(CSRToCOO(A, false), A_weights);
  const auto Bmap = COOToMap<IdType, DType>(CSRToCOO(B, false), B_weights);

  // Equal sizes plus "every key of A is in B" implies identical key sets.
  if (Amap.size() != Bmap.size()) return false;

  for (const auto& entryA : Amap) {
    const auto itB = Bmap.find(entryA.first);
    if (itB == Bmap.end()) return false;

    const DType diff = std::abs(entryA.second - itB->second);
    const DType tol = rtol * std::abs(entryA.second) + atol;
    // Written as a negated "<=" on purpose: identical values pass even when
    // both tolerances are zero, and a NaN difference (for which every ordered
    // comparison is false) is reported as a mismatch.
    if (!(diff <= tol)) return false;
  }

  return true;
}

template <typename IdType, typename DType>
54
std::pair<aten::CSRMatrix, NDArray> CSR_A(DGLContext ctx = CTX) {
55
56
57
58
59
60
  // matrix([[0. , 0. , 1. , 0.7, 0. ],
  //         [0. , 0. , 0.5, 0.+, 0. ],
  //         [0.4, 0.7, 0. , 0.2, 0. ],
  //         [0. , 0. , 0. , 0. , 0.2]])
  // (0.+ indicates that the entry exists but the value is 0.)
  auto csr = aten::CSRMatrix(
61
      4, 5, NDArray::FromVector(std::vector<IdType>({0, 2, 4, 7, 8}), ctx),
62
63
      NDArray::FromVector(std::vector<IdType>({2, 3, 2, 3, 0, 1, 3, 4}), ctx),
      NDArray::FromVector(std::vector<IdType>({1, 0, 2, 3, 4, 5, 6, 7}), ctx));
64
  auto weights = NDArray::FromVector(
65
      std::vector<DType>({0.7, 1.0, 0.5, 0.0, 0.4, 0.7, 0.2, 0.2}), ctx);
66
67
68
69
  return {csr, weights};
}

template <typename IdType, typename DType>
70
std::pair<aten::CSRMatrix, NDArray> CSR_B(DGLContext ctx = CTX) {
71
72
73
74
75
76
77
  // matrix([[0. , 0.9, 0. , 0.6, 0. , 0.3],
  //         [0. , 0. , 0. , 0. , 0. , 0.4],
  //         [0.+, 0. , 0. , 0. , 0. , 0.9],
  //         [0.8, 0.2, 0.3, 0.2, 0. , 0. ],
  //         [0.2, 0.4, 0. , 0. , 0. , 0. ]])
  // (0.+ indicates that the entry exists but the value is 0.)
  auto csr = aten::CSRMatrix(
78
79
80
      5, 6, NDArray::FromVector(std::vector<IdType>({0, 3, 4, 6, 10, 12}), ctx),
      NDArray::FromVector(
          std::vector<IdType>({1, 3, 5, 5, 0, 5, 0, 1, 2, 3, 0, 1}), ctx));
81
  auto weights = NDArray::FromVector(
82
83
84
      std::vector<DType>(
          {0.9, 0.6, 0.3, 0.4, 0.0, 0.9, 0.8, 0.2, 0.3, 0.2, 0.2, 0.4}),
      ctx);
85
86
87
88
  return {csr, weights};
}

template <typename IdType, typename DType>
89
std::pair<aten::CSRMatrix, NDArray> CSR_C(DGLContext ctx = CTX) {
90
91
92
93
94
  // matrix([[0. , 0. , 0. , 0.2, 0. ],
  //         [0. , 0. , 0. , 0.5, 0.4],
  //         [0. , 0.2, 0. , 0.9, 0.2],
  //         [0. , 1. , 0. , 0.7, 0. ]])
  auto csr = aten::CSRMatrix(
95
      4, 5, NDArray::FromVector(std::vector<IdType>({0, 1, 3, 6, 8}), ctx),
96
97
      NDArray::FromVector(std::vector<IdType>({3, 3, 4, 1, 3, 4, 1, 3}), ctx));
  auto weights = NDArray::FromVector(
98
      std::vector<DType>({0.2, 0.5, 0.4, 0.2, 0.9, 0.2, 1., 0.7}), ctx);
99
100
101
102
  return {csr, weights};
}

template <typename IdType, typename DType>
103
std::pair<aten::CSRMatrix, NDArray> CSR_A_mm_B(DGLContext ctx = CTX) {
104
105
106
107
108
109
  // matrix([[0.56, 0.14, 0.21, 0.14, 0.  , 0.9 ],
  //         [0.+ , 0.+ , 0.+ , 0.+ , 0.  , 0.45],
  //         [0.16, 0.4 , 0.06, 0.28, 0.  , 0.4 ],
  //         [0.04, 0.08, 0.  , 0.  , 0.  , 0.  ]])
  // (0.+ indicates that the entry exists but the value is 0.)
  auto csr = aten::CSRMatrix(
110
111
112
113
114
      4, 6, NDArray::FromVector(std::vector<IdType>({0, 5, 10, 15, 17}), ctx),
      NDArray::FromVector(
          std::vector<IdType>(
              {0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1}),
          ctx));
115
  auto weights = NDArray::FromVector(
116
117
118
119
      std::vector<DType>(
          {0.56, 0.14, 0.21, 0.14, 0.9, 0., 0., 0., 0., 0.45, 0.16, 0.4, 0.06,
           0.28, 0.4, 0.04, 0.08}),
      ctx);
120
121
122
123
  return {csr, weights};
}

template <typename IdType, typename DType>
124
std::pair<aten::CSRMatrix, NDArray> CSR_A_plus_C(DGLContext ctx = CTX) {
125
  auto csr = aten::CSRMatrix(
126
127
128
      4, 5, NDArray::FromVector(std::vector<IdType>({0, 2, 5, 9, 12}), ctx),
      NDArray::FromVector(
          std::vector<IdType>({2, 3, 2, 3, 4, 0, 1, 3, 4, 1, 3, 4}), ctx));
129
  auto weights = NDArray::FromVector(
130
131
132
      std::vector<DType>(
          {1., 0.9, 0.5, 0.5, 0.4, 0.4, 0.9, 1.1, 0.2, 1., 0.7, 0.2}),
      ctx);
133
134
135
136
  return {csr, weights};
}

template <typename DType>
137
NDArray CSR_A_mask_C(DGLContext ctx = CTX) {
138
139
  return NDArray::FromVector(
      std::vector<DType>({0.7, 0.0, 0.0, 0.7, 0.2, 0.0, 0.0, 0.0}), ctx);
140
141
142
}

template <typename IdType, typename DType>
143
void _TestCsrmm(DGLContext ctx = CTX) {
144
145
146
147
  auto A = CSR_A<IdType, DType>(ctx);
  auto B = CSR_B<IdType, DType>(ctx);
  auto A_mm_B = aten::CSRMM(A.first, A.second, B.first, B.second);
  auto A_mm_B2 = CSR_A_mm_B<IdType, DType>(ctx);
148
149
  bool result = CSRIsClose<IdType, DType>(
      A_mm_B.first, A_mm_B2.first, A_mm_B.second, A_mm_B2.second, 1e-4, 1e-4);
150
151
152
153
  ASSERT_TRUE(result);
}

template <typename IdType, typename DType>
154
void _TestCsrsum(DGLContext ctx = CTX) {
155
156
157
158
159
  auto A = CSR_A<IdType, DType>(ctx);
  auto C = CSR_C<IdType, DType>(ctx);
  auto A_plus_C = aten::CSRSum({A.first, C.first}, {A.second, C.second});
  auto A_plus_C2 = CSR_A_plus_C<IdType, DType>(ctx);
  bool result = CSRIsClose<IdType, DType>(
160
161
      A_plus_C.first, A_plus_C2.first, A_plus_C.second, A_plus_C2.second, 1e-4,
      1e-4);
162
163
164
165
  ASSERT_TRUE(result);
}

template <typename IdType, typename DType>
166
void _TestCsrmask(DGLContext ctx = CTX) {
167
168
  auto A = CSR_A<IdType, DType>(ctx);
  auto C = CSR_C<IdType, DType>(ctx);
169
  auto C_coo = CSRToCOO(C.first, false);
170
171
  auto A_mask_C =
      aten::CSRGetData<DType>(A.first, C_coo.row, C_coo.col, A.second, 0);
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
  auto A_mask_C2 = CSR_A_mask_C<DType>(ctx);
  ASSERT_TRUE(ArrayEQ<DType>(A_mask_C, A_mask_C2));
}

TEST(CsrmmTest, TestCsrmm) {
  // Runs every (index type, value type) combination on one device.
  const auto run_all_types = [](DGLContext device) {
    _TestCsrmm<int32_t, float>(device);
    _TestCsrmm<int32_t, double>(device);
    _TestCsrmm<int64_t, float>(device);
    _TestCsrmm<int64_t, double>(device);
  };

  run_all_types(CPU);
#ifdef DGL_USE_CUDA
  // The GPU pass is compiled in only for CUDA builds.
  run_all_types(GPU);
#endif
}

TEST(CsrmmTest, TestCsrsum) {
  // Runs every (index type, value type) combination on one device.
  const auto run_all_types = [](DGLContext device) {
    _TestCsrsum<int32_t, float>(device);
    _TestCsrsum<int32_t, double>(device);
    _TestCsrsum<int64_t, float>(device);
    _TestCsrsum<int64_t, double>(device);
  };

  run_all_types(CPU);
#ifdef DGL_USE_CUDA
  // The GPU pass is compiled in only for CUDA builds.
  run_all_types(GPU);
#endif
}

TEST(CsrmmTest, TestCsrmask) {
  // Runs every (index type, value type) combination on one device.
  const auto run_all_types = [](DGLContext device) {
    _TestCsrmask<int32_t, float>(device);
    _TestCsrmask<int32_t, double>(device);
    _TestCsrmask<int64_t, float>(device);
    _TestCsrmask<int64_t, double>(device);
  };

  run_all_types(CPU);
#ifdef DGL_USE_CUDA
  // The GPU pass is compiled in only for CUDA builds.
  run_all_types(GPU);
#endif
}

};  // namespace