// test_spmm.cc
#if !defined(_WIN32)
#include <../src/array/cpu/spmm.h>

#include <dgl/array.h>
#include <gtest/gtest.h>
#include <time.h>

#include <random>
#include <vector>

#include "./common.h"

using namespace dgl;
using namespace dgl::runtime;

// Vector lengths exercised by every test below. Most sit on or next to a
// power of two (8, 32, 64, 256) -- presumably to hit both the vectorised body
// and the scalar tail of the kernels; confirm against spmm.h.
// constexpr makes the array read-only and gives it internal linkage, so it
// cannot collide with a same-named global in another test translation unit.
constexpr int sizes[] = {1, 7, 8, 9, 31, 32, 33, 54, 63, 64, 65, 256, 257};
namespace ns_op = dgl::aten::cpu::op;
namespace {

/*!
 * \brief Fills data[0..dim) with the sequence mul, 2*mul, 3*mul, ...
 * \param data Output buffer with at least `dim` elements.
 * \param dim Number of elements to write.
 * \param mul Step of the sequence.
 */
template <class T>
void GenerateData(T* data, int dim, T mul) {
  for (int i = 0; i < dim; i++) {
    data[i] = (i + 1) * mul;
  }
}

/*!
 * \brief Fills data[0..dim) with random integer values in [0, 100].
 *
 * Draws a uniform int in [0, 10000] and integer-divides it by 100. The engine
 * is seeded from std::random_device, so the values differ from run to run.
 */
template <class T>
void GenerateRandomData(T* data, int dim) {
  std::mt19937 rng(std::random_device{}());
  std::uniform_int_distribution<> dist(0, 10000);
  for (int i = 0; i < dim; i++) {
    data[i] = (dist(rng) / 100);
  }
}

/*! \brief Sets data[0..dim) to zero. */
template <class T>
void GenerateZeroData(T* data, int dim) {
  for (int i = 0; i < dim; i++) {
    data[i] = 0;
  }
}

/*! \brief Reference for the copy ops: exp[i] = out[i] + hs[i]. */
template <class T>
void Copy(T* exp, T* out, T* hs, int dim) {
  for (int i = 0; i < dim; i++) {
    exp[i] = out[i] + hs[i];
  }
}

/*! \brief Reference for Add: exp[i] = out[i] + lhs[i] + rhs[i]. */
template <class T>
void Add(T* exp, T* out, T* lhs, T* rhs, int dim) {
  for (int i = 0; i < dim; i++) {
    exp[i] = out[i] + lhs[i] + rhs[i];
  }
}

/*! \brief Reference for Sub: exp[i] = out[i] + lhs[i] - rhs[i]. */
template <class T>
void Sub(T* exp, T* out, T* lhs, T* rhs, int dim) {
  for (int i = 0; i < dim; i++) {
    exp[i] = out[i] + lhs[i] - rhs[i];
  }
}

/*! \brief Reference for Mul: exp[i] = out[i] + lhs[i] * rhs[i]. */
template <class T>
void Mul(T* exp, T* out, T* lhs, T* rhs, int dim) {
  for (int i = 0; i < dim; i++) {
    exp[i] = (out[i] + (lhs[i] * rhs[i]));
  }
}

/*! \brief Reference for Div: exp[i] = out[i] + lhs[i] / rhs[i]. */
template <class T>
void Div(T* exp, T* out, T* lhs, T* rhs, int dim) {
  for (int i = 0; i < dim; i++) {
    exp[i] = (out[i] + (lhs[i] / rhs[i]));
  }
}

/*!
 * \brief Asserts exp[i] == out[i] (exact equality) for every i in [0, dim).
 *
 * The failing index and the vector length are streamed into the assertion
 * message, so a failure says where it happened. A fatal ASSERT_* inside a
 * helper only returns from this helper, not from the calling test.
 */
template <class T>
void CheckResult(T* exp, T* out, int dim) {
  for (int i = 0; i < dim; i++) {
    ASSERT_TRUE(exp[i] == out[i])
        << "mismatch at index " << i << " (dim = " << dim << ")";
  }
}

}  // namespace

/*!
 * \brief Checks ns_op::CopyLhs against the reference Copy() helper.
 *
 * For every length in `sizes`, fills `lhs` with random integer-valued data,
 * accumulates CopyLhs::Call(lhs + k, nullptr) into a zeroed `out`, and
 * compares the result element-wise with out + lhs.
 *
 * \tparam DType Element type under test (float, double and BFloat16 here).
 */
template <typename DType>
void _TestSpmmCopyLhs() {
  for (const int dim : sizes) {
    // std::vector rather than `DType buf[dim]`: runtime-sized arrays are a
    // compiler extension, not standard C++.
    std::vector<DType> out(dim), exp(dim), lhs(dim);
    GenerateZeroData(out.data(), dim);
    GenerateRandomData(lhs.data(), dim);

    // Calculation of expected output - 'exp'
    Copy(exp.data(), out.data(), lhs.data(), dim);

    // Calculation of output using legacy path - 'out'
    for (int k = 0; k < dim; k++) {
      out[k] += ns_op::CopyLhs<DType>::Call(lhs.data() + k, nullptr);
    }

    CheckResult(exp.data(), out.data(), dim);
  }
}

// Runs the CopyLhs check for each element type covered by this file.
TEST(SpmmTest, TestSpmmCopyLhs) {
  _TestSpmmCopyLhs<float>();
  _TestSpmmCopyLhs<double>();
  _TestSpmmCopyLhs<BFloat16>();
}

/*!
 * \brief Checks ns_op::CopyRhs against the reference Copy() helper.
 *
 * For every length in `sizes`, fills `rhs` with random integer-valued data,
 * accumulates CopyRhs::Call(nullptr, rhs + k) into a zeroed `out`, and
 * compares the result element-wise with out + rhs.
 *
 * \tparam DType Element type under test (float, double and BFloat16 here).
 */
template <typename DType>
void _TestSpmmCopyRhs() {
  for (const int dim : sizes) {
    // std::vector rather than `DType buf[dim]`: runtime-sized arrays are a
    // compiler extension, not standard C++.
    std::vector<DType> out(dim), exp(dim), rhs(dim);
    GenerateZeroData(out.data(), dim);
    GenerateRandomData(rhs.data(), dim);

    // Calculation of expected output - 'exp'
    Copy(exp.data(), out.data(), rhs.data(), dim);

    // Calculation of output using legacy path - 'out'
    for (int k = 0; k < dim; k++) {
      out[k] += ns_op::CopyRhs<DType>::Call(nullptr, rhs.data() + k);
    }

    CheckResult(exp.data(), out.data(), dim);
  }
}

// Runs the CopyRhs check for each element type covered by this file.
TEST(SpmmTest, TestSpmmCopyRhs) {
  _TestSpmmCopyRhs<float>();
  _TestSpmmCopyRhs<double>();
  _TestSpmmCopyRhs<BFloat16>();
}

/*!
 * \brief Checks ns_op::Add against the reference Add() helper.
 *
 * For every length in `sizes`, accumulates Add::Call(lhs + k, rhs + k) into a
 * zeroed `out` and compares the result element-wise with out + lhs + rhs.
 *
 * \tparam DType Element type under test (float, double and BFloat16 here).
 */
template <typename DType>
void _TestSpmmAdd() {
  for (const int dim : sizes) {
    // std::vector rather than `DType buf[dim]`: runtime-sized arrays are a
    // compiler extension, not standard C++.
    std::vector<DType> out(dim), exp(dim), lhs(dim), rhs(dim);
    GenerateZeroData(out.data(), dim);
    GenerateRandomData(lhs.data(), dim);
    GenerateRandomData(rhs.data(), dim);

    // Calculation of expected output - 'exp'
    Add(exp.data(), out.data(), lhs.data(), rhs.data(), dim);

    // Calculation of output using legacy path - 'out'
    for (int k = 0; k < dim; k++) {
      out[k] += ns_op::Add<DType>::Call(lhs.data() + k, rhs.data() + k);
    }

    CheckResult(exp.data(), out.data(), dim);
  }
}

// Runs the Add check for each element type covered by this file.
TEST(SpmmTest, TestSpmmAdd) {
  _TestSpmmAdd<float>();
  _TestSpmmAdd<double>();
  _TestSpmmAdd<BFloat16>();
}

/*!
 * \brief Checks ns_op::Sub against the reference Sub() helper.
 *
 * For every length in `sizes`, accumulates Sub::Call(lhs + k, rhs + k) into a
 * zeroed `out` and compares the result element-wise with out + lhs - rhs.
 *
 * \tparam DType Element type under test (float, double and BFloat16 here).
 */
template <typename DType>
void _TestSpmmSub() {
  for (const int dim : sizes) {
    // std::vector rather than `DType buf[dim]`: runtime-sized arrays are a
    // compiler extension, not standard C++.
    std::vector<DType> out(dim), exp(dim), lhs(dim), rhs(dim);
    GenerateZeroData(out.data(), dim);
    GenerateRandomData(lhs.data(), dim);
    GenerateRandomData(rhs.data(), dim);

    // Calculation of expected output - 'exp'
    Sub(exp.data(), out.data(), lhs.data(), rhs.data(), dim);

    // Calculation of output using legacy path - 'out'
    for (int k = 0; k < dim; k++) {
      out[k] += ns_op::Sub<DType>::Call(lhs.data() + k, rhs.data() + k);
    }

    CheckResult(exp.data(), out.data(), dim);
  }
}

// Runs the Sub check for each element type covered by this file.
TEST(SpmmTest, TestSpmmSub) {
  _TestSpmmSub<float>();
  _TestSpmmSub<double>();
  _TestSpmmSub<BFloat16>();
}

/*!
 * \brief Checks ns_op::Mul against the reference Mul() helper.
 *
 * For every length in `sizes`, accumulates Mul::Call(lhs + k, rhs + k) into a
 * zeroed `out` and compares the result element-wise with out + lhs * rhs.
 *
 * \tparam DType Element type under test (float, double and BFloat16 here).
 */
template <typename DType>
void _TestSpmmMul() {
  for (const int dim : sizes) {
    // std::vector rather than `DType buf[dim]`: runtime-sized arrays are a
    // compiler extension, not standard C++.
    std::vector<DType> out(dim), exp(dim), lhs(dim), rhs(dim);
    GenerateZeroData(out.data(), dim);
    GenerateRandomData(lhs.data(), dim);
    GenerateRandomData(rhs.data(), dim);

    // Calculation of expected output - 'exp'
    Mul(exp.data(), out.data(), lhs.data(), rhs.data(), dim);

    // Calculation of output using legacy path - 'out'
    for (int k = 0; k < dim; k++) {
      out[k] += ns_op::Mul<DType>::Call(lhs.data() + k, rhs.data() + k);
    }

    CheckResult(exp.data(), out.data(), dim);
  }
}

// Runs the Mul check for each element type covered by this file.
TEST(SpmmTest, TestSpmmMul) {
  _TestSpmmMul<float>();
  _TestSpmmMul<double>();
  _TestSpmmMul<BFloat16>();
}

/*!
 * \brief Checks ns_op::Div against the reference Div() helper.
 *
 * Uses deterministic inputs lhs = 15, 30, 45, ... and rhs = 1, 2, 3, ...
 * rather than random data, so rhs never contains zero. Accumulates
 * Div::Call(lhs + k, rhs + k) into a zeroed `out` and compares the result
 * element-wise with out + lhs / rhs.
 *
 * \tparam DType Element type under test (float, double and BFloat16 here).
 */
template <typename DType>
void _TestSpmmDiv() {
  for (const int dim : sizes) {
    // std::vector rather than `DType buf[dim]`: runtime-sized arrays are a
    // compiler extension, not standard C++.
    std::vector<DType> out(dim), exp(dim), lhs(dim), rhs(dim);
    GenerateZeroData(out.data(), dim);
    GenerateData(lhs.data(), dim, static_cast<DType>(15));
    GenerateData(rhs.data(), dim, static_cast<DType>(1));

    // Calculation of expected output - 'exp'
    Div(exp.data(), out.data(), lhs.data(), rhs.data(), dim);

    // Calculation of output using legacy path - 'out'
    for (int k = 0; k < dim; k++) {
      out[k] += ns_op::Div<DType>::Call(lhs.data() + k, rhs.data() + k);
    }

    CheckResult(exp.data(), out.data(), dim);
  }
}

// Runs the Div check for each element type covered by this file.
TEST(SpmmTest, TestSpmmDiv) {
  _TestSpmmDiv<float>();
  _TestSpmmDiv<double>();
  _TestSpmmDiv<BFloat16>();
}
#endif  // _WIN32