// cpu_dot_op_test.cpp
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
#include <migraphx/half.hpp>

// Multiplies a 4x5 matrix A by a 5x3 matrix B with op::dot on the CPU target
// and compares the 4x3 product with precomputed reference values.
//
// T is the element type of both operands; the test is registered below for
// float and double.
template <class T>
void matmul_test()
{
    migraphx::program p;
    // A: 4x5, one matrix row per source line.
    std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397,  0.75402161,  -0.505885,
                        1.33628943,  -0.11413,   -0.31270559, 1.59336732,  -0.19361027,
                        -0.91620867, 0.40108416, -0.06969921, 0.68483471,  -0.39906632,
                        -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
    // B: 5x3. Declared with T (like A) so the double instantiation receives
    // double-precision data rather than float values silently widened to double.
    std::vector<T> b = {6.09568541e-01,  -6.10527007e-01, 3.66646462e-01,
                        1.18951101e-01,  5.58777432e-01,  -3.21296298e-01,
                        -5.95997198e-01, -5.01425721e-01, -2.84606807e-01,
                        -5.73673557e-01, -8.99430260e-01, -4.25103093e-01,
                        1.53027987e+00,  -3.81407415e-04, -3.29650255e-01};
    // Expected 4x3 product. Intentionally kept as float.
    // NOTE(review): the verify_range tolerance presumably follows the element
    // type of its first argument. These references carry only ~9 significant
    // digits, so confirm before switching this to T.
    std::vector<float> c = {-1.56327541e+00, -7.09570140e-01, -5.37424982e-01,
                            -2.22994831e-01, -2.15586437e+00, 2.09177941e-03,
                            -1.47279677e+00, 2.02627040e-01,  -6.04527691e-01,
                            -1.29885596e+00, 2.16294914e+00,  -1.48101497e-01};
    migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
    auto al = p.add_literal(migraphx::literal{a_shape, a});
    migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
    auto bl = p.add_literal(migraphx::literal{b_shape, b});
    p.add_instruction(migraphx::op::dot{}, al, bl);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    // assign() sizes the vector itself, so it is not pre-sized.
    std::vector<T> results_vector;
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test<float>)
TEST_CASE_REGISTER(matmul_test<double>)

// Same 4x5 by 5x3 product as matmul_test, but with singleton leading
// dimensions: A is 1x1x4x5 and B is 1x1x5x3. The 12 output values are
// checked against the same reference data.
//
// T is the element type of both operands; the test is registered below for
// float and double.
template <class T>
void matmul_test_ex()
{
    migraphx::program p;
    // A: 4x5 payload, one matrix row per source line.
    std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397,  0.75402161,  -0.505885,
                        1.33628943,  -0.11413,   -0.31270559, 1.59336732,  -0.19361027,
                        -0.91620867, 0.40108416, -0.06969921, 0.68483471,  -0.39906632,
                        -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
    // B: 5x3 payload. Declared with T (like A) so the double instantiation
    // receives double-precision data rather than widened float values.
    std::vector<T> b = {6.09568541e-01,  -6.10527007e-01, 3.66646462e-01,
                        1.18951101e-01,  5.58777432e-01,  -3.21296298e-01,
                        -5.95997198e-01, -5.01425721e-01, -2.84606807e-01,
                        -5.73673557e-01, -8.99430260e-01, -4.25103093e-01,
                        1.53027987e+00,  -3.81407415e-04, -3.29650255e-01};
    // Expected product. Intentionally kept as float.
    // NOTE(review): the verify_range tolerance presumably follows the element
    // type of its first argument. These references carry only ~9 significant
    // digits, so confirm before switching this to T.
    std::vector<float> c = {-1.56327541e+00, -7.09570140e-01, -5.37424982e-01,
                            -2.22994831e-01, -2.15586437e+00, 2.09177941e-03,
                            -1.47279677e+00, 2.02627040e-01,  -6.04527691e-01,
                            -1.29885596e+00, 2.16294914e+00,  -1.48101497e-01};
    migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
    auto al = p.add_literal(migraphx::literal{a_shape, a});
    migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
    auto bl = p.add_literal(migraphx::literal{b_shape, b});
    p.add_instruction(migraphx::op::dot{}, al, bl);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    // assign() sizes the vector itself, so it is not pre-sized.
    std::vector<T> results_vector;
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test_ex<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>)

// Batched matrix product on 3-D operands. A is 2x2x3 and B is 2x3x4, so
// op::dot yields a 2x2x4 result, which is compared with precomputed
// reference values.
TEST_CASE(matmul_mutli_dim_2)
{
    // Operand shapes: a batch of two matrices each.
    migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 2, 3}};
    migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 3, 4}};

    // One matrix row per source line.
    std::vector<float> lhs_values = {-0.76234141, 0.01368910, -0.86343423,
                                     -0.99465282, 0.76133268, 0.96507140,
                                     -0.55893585, 0.02625652, 0.75171776,
                                     0.23112578,  0.25624787, -1.50442161};
    std::vector<float> rhs_values = {-0.15933632, -0.69594712, -0.06198966, -1.23905184,
                                     -0.83672704, -1.06971832, -0.12272917, 1.07094116,
                                     -0.08346820, 1.16820693,  -0.95700874, 0.24059691,
                                     0.43326023,  0.78305235,  -0.53506601, -0.69359678,
                                     -0.26334436, 1.56292796,  -0.33629175, -1.72693469,
                                     0.41435494,  1.52136843,  -0.40699791, -1.59839430};

    migraphx::program prog;
    auto lhs = prog.add_literal(migraphx::literal{lhs_shape, lhs_values});
    auto rhs = prog.add_literal(migraphx::literal{rhs_shape, rhs_values});
    prog.add_instruction(migraphx::op::dot{}, lhs, rhs);
    prog.compile(migraphx::cpu::target{});

    // Copy the evaluated output into a plain vector for comparison.
    auto evaluated = prog.eval({});
    std::vector<float> actual;
    evaluated.visit([&](auto view) { actual.assign(view.begin(), view.end()); });

    std::vector<float> expected = {0.18208394,  -0.49276402, 0.87189133,  0.75150114,
                                   -0.55909610, 1.00521735,  -0.95536130, 2.27996211,
                                   0.06239879,  0.74700068,  -0.01570983, -0.85920856,
                                   -0.59070835, -1.70729902, 0.40245487,  1.80182751};
    EXPECT(migraphx::verify_range(actual, expected));
}

// Batched GEMM with a third operand and beta = 0: dot{1, 0}(A, B, C), where
// A is 2x2x3, B is 2x3x4 and C is 2x2x4. With beta = 0 the result must be
// exactly A*B (the same reference values used by matmul_mutli_dim_2).
TEST_CASE(gemm_mutli_dim_2_beta0)
{
    migraphx::program p;
    std::vector<float> m1 = {-0.76234141,
                             0.01368910,
                             -0.86343423,
                             -0.99465282,
                             0.76133268,
                             0.96507140,
                             -0.55893585,
                             0.02625652,
                             0.75171776,
                             0.23112578,
                             0.25624787,
                             -1.50442161};
    migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
    std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
                             -1.06971832, -0.12272917, 1.07094116,  -0.08346820, 1.16820693,
                             -0.95700874, 0.24059691,  0.43326023,  0.78305235,  -0.53506601,
                             -0.69359678, -0.26334436, 1.56292796,  -0.33629175, -1.72693469,
                             0.41435494,  1.52136843,  -0.40699791, -1.59839430};
    migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
    // C (2*2*4 = 16 elements) is deliberately unrelated to the expected product.
    // If C held A*B itself, an implementation that ignored beta or returned C
    // unchanged would still pass this test.
    std::vector<float> m3(16, 3.5f);
    migraphx::shape m3_shape{migraphx::shape::float_type, {2, 2, 4}};
    auto l1     = p.add_literal(migraphx::literal{m1_shape, m1});
    auto l2     = p.add_literal(migraphx::literal{m2_shape, m2});
    auto l3     = p.add_literal(migraphx::literal{m3_shape, m3});
    float alpha = 1.0f;
    float beta  = 0.0f;
    p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    std::vector<float> m;
    result.visit([&](auto output) { m.assign(output.begin(), output.end()); });

    // A*B only: C must not contribute when beta == 0.
    std::vector<float> m_res = {0.18208394,
                                -0.49276402,
                                0.87189133,
                                0.75150114,
                                -0.55909610,
                                1.00521735,
                                -0.95536130,
                                2.27996211,
                                0.06239879,
                                0.74700068,
                                -0.01570983,
                                -0.85920856,
                                -0.59070835,
                                -1.70729902,
                                0.40245487,
                                1.80182751};

    EXPECT(migraphx::verify_range(m, m_res));
}

// Single-batch GEMM with a third operand and beta = 0: dot{1, 0}(A, B, C),
// where A is 1x2x3, B is 1x3x4 and C is 1x2x4. The result must be exactly A*B.
TEST_CASE(gemm_beta_0)
{
    migraphx::program p;
    std::vector<float> m1 = {
        -0.76234141, 0.01368910, -0.86343423, -0.99465282, 0.76133268, 0.96507140};
    migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
    std::vector<float> m2 = {-0.15933632,
                             -0.69594712,
                             -0.06198966,
                             -1.23905184,
                             -0.83672704,
                             -1.06971832,
                             -0.12272917,
                             1.07094116,
                             -0.08346820,
                             1.16820693,
                             -0.95700874,
                             0.24059691};
    migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};

    migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
    // C (1*2*4 = 8 elements) is deliberately unrelated to the expected product.
    // If C held A*B itself, an implementation that ignored beta or returned C
    // unchanged would still pass this test.
    std::vector<float> m3(8, 3.5f);
    auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
    auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
    auto l3 = p.add_literal(migraphx::literal{m3_shape, m3});

    float alpha = 1.0f;
    float beta  = 0.0f;
    p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    std::vector<float> m;
    result.visit([&](auto output) { m.assign(output.begin(), output.end()); });

    // A*B only: C must not contribute when beta == 0.
    std::vector<float> m_res = {0.18208394,
                                -0.49276402,
                                0.87189133,
                                0.75150114,
                                -0.55909610,
                                1.00521735,
                                -0.95536130,
                                2.27996211};

    EXPECT(migraphx::verify_range(m, m_res));
}

// Matrix product on 4-D operands with two batch dimensions. A is 2x3x2x3 and
// B is 2x3x3x2, so op::dot yields six 2x2 products (a 2x3x2x2 result), which
// are compared with reference values.
TEST_CASE(matmul_mutli_dim_2_3)
{
    migraphx::shape lhs_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
    migraphx::shape rhs_shape{migraphx::shape::float_type, {2, 3, 3, 2}};

    // Each source line holds one 2x3 (lhs) or 3x2 (rhs) matrix.
    std::vector<float> lhs_values = {
        -1.93300070, 0.33902698,  -0.45173527, -0.72283069, -0.17177134, 1.62199882,
        0.87052847,  0.14989811,  -0.88969184, -0.18131398, 0.72654339,  -0.57123693,
        0.03852506,  -0.72332085, -1.81844083, -0.33465167, -0.71400352, 0.36883161,
        0.08698452,  0.94974586,  0.40087323,  -0.05448534, 0.03220677,  -1.22494296,
        0.97938472,  -1.43714454, -0.80430904, -0.08098728, 0.31520301,  0.49642169,
        -1.63471091, 0.34390096,  2.81292176,  -0.22666528, 1.54559556,  -1.51075762};
    std::vector<float> rhs_values = {
        -0.33170529, 2.26325120,  -0.50639461, 0.64802947,  0.44748888,  0.33768068,
        -0.53621075, 0.34341460,  0.58742520,  -1.13995790, -0.99322535, 0.35447353,
        0.01977110,  -0.10155016, -1.02288245, -0.16575791, -1.47870374, 0.29300008,
        -0.39112198, 1.42303608,  -0.02853060, 1.52610164,  0.53540909,  0.75618998,
        -0.26877787, -1.90886366, 0.30622790,  0.59794535,  1.29795331,  -0.37805803,
        -1.58167176, -1.26966832, 0.27435891,  0.89430347,  0.22854926,  -0.50317658};

    migraphx::program prog;
    auto lhs = prog.add_literal(migraphx::literal{lhs_shape, lhs_values});
    auto rhs = prog.add_literal(migraphx::literal{rhs_shape, rhs_values});
    prog.add_instruction(migraphx::op::dot{}, lhs, rhs);
    prog.compile(migraphx::cpu::target{});

    auto evaluated = prog.eval({});
    std::vector<float> actual;
    evaluated.visit([&](auto view) { actual.assign(view.begin(), view.end()); });

    // One 2x2 result matrix per source line.
    std::vector<float> expected = {0.26735861,  -4.30770895, 1.05257728,  -1.19954265,
                                   0.50493170,  -0.18729756, 1.09137941,  -1.09298312,
                                   3.42956915,  -0.41681939, 0.17833257,  0.26040336,
                                   0.15351280,  1.87632715,  -0.63545406, -0.95467340,
                                   -1.74728628, -2.42477030, 0.76262372,  0.15539164,
                                   3.32281958,  0.96769613,  0.43727545,  2.43019906};
    EXPECT(migraphx::verify_range(actual, expected));
}

// Hand-composed GEMM on 4-D operands (alpha = 0.35, beta = 0.41):
//   1. alpha*A*B from a two-input dot;
//   2. beta*C from a scalar beta literal expanded by op::scalar to the
//      product's shape, multiplied by C;
//   3. the two results added.
// A is 2x3x2x3, B is 2x3x3x2 and C is 2x3x2x2.
TEST_CASE(gemm_mutli_dim1_2_3)
{
    migraphx::program p;
    std::vector<float> m1 = {
        1.23636469,  -0.47041261, -0.14375651, -0.48371852, 1.16479301,  -0.89361055,
        -0.18569086, 1.10700457,  -1.02632638, 0.82277012,  0.33525769,  0.52825145,
        -1.00141689, 0.45510090,  -0.02675039, -0.60454439, 0.38551153,  -0.01658514,
        0.93059292,  -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
        -1.36200452, 2.25822236,  -1.23416970, 0.12312496,  0.76232760,  -0.83594234,
        1.67418145,  -0.19412936, 1.05261378,  0.66246074,  -1.15233398, 0.16429736};
    migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
    std::vector<float> m2 = {
        -0.87300530, -0.07112838, 0.19196860,  -1.04986840, 1.20348200,  0.31966893,
        1.04805440,  -2.04777729, -0.67906052, -1.17250760, 0.34305044,  -1.01957785,
        -1.12694862, 0.18431338,  -1.63712290, 0.27566931,  -1.11282021, 1.41738919,
        0.47871283,  -1.01980420, 1.00212436,  -0.78740444, -1.65636133, 1.51466547,
        -0.12470397, 0.70404393,  -0.15244797, 0.74288871,  0.07339926,  -1.45811623,
        0.27185845,  0.08804596,  0.99061977,  -1.61752428, 0.29191159,  0.87271953};
    migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
    std::vector<float> m3 = {-1.07692443, 0.85223457,  -0.37266530, 2.31511577,  0.04227017,
                             1.13229428,  -0.52769242, 0.27307182,  -0.47779843, -0.08023168,
                             -0.22862823, 0.81489871,  1.13139581,  1.13860467,  0.24309065,
                             0.26533729,  0.49106772,  -1.18860493, 0.27842449,  1.03568141,
                             0.49759611,  0.10021662,  0.00592602,  0.90862000};
    migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};

    auto l1        = p.add_literal(migraphx::literal{m1_shape, m1});
    auto l2        = p.add_literal(migraphx::literal{m2_shape, m2});
    auto l3        = p.add_literal(migraphx::literal{m3_shape, m3});
    float alpha    = 0.35;
    float beta     = 0.41;
    // Two-input dot: only alpha should apply. Pass beta = 0 explicitly so this op
    // cannot be read as also scaling by beta; beta is applied to C separately
    // below.
    auto m12_alpha = p.add_instruction(migraphx::op::dot{alpha, 0.0f}, l1, l2);
    auto l_beta    = p.add_literal(beta);
    auto b_beta    = p.add_instruction(migraphx::op::scalar{m12_alpha->get_shape()}, l_beta);
    auto m3_beta   = p.add_instruction(migraphx::op::mul{}, b_beta, l3);
    p.add_instruction(migraphx::op::add{}, m3_beta, m12_alpha);
    p.compile(migraphx::cpu::target{});
    auto result = p.eval({});
    std::vector<float> m;
    result.visit([&](auto output) { m.assign(output.begin(), output.end()); });

    // Same reference values as gemm_mutli_3args, which uses a single
    // three-input dot.
    std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099,  -0.43711586,
                                0.50928632,  0.06919868, -0.80382802, -0.05125718, -0.06685650,
                                -0.06972163, 0.32407764, 0.45677396,  0.25909489,  0.56911252,
                                -0.17183724, 0.10858734, 0.39406289,  0.04662959,  1.07979824,
                                0.40355016,  0.52410648, -0.31728447, 1.09550845};

    EXPECT(migraphx::verify_range(m, m_res));
}

// GEMM as a single three-input op: dot{alpha, beta}(A, B, C) with alpha = 0.35
// and beta = 0.41. A is 2x3x2x3, B is 2x3x3x2 and C is 2x3x2x2.
// The expected output matches gemm_mutli_dim1_2_3, which assembles the same
// result from separate dot, scalar, mul and add instructions.
TEST_CASE(gemm_mutli_3args)
{
    const float alpha = 0.35;
    const float beta  = 0.41;

    migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
    migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
    migraphx::shape c_shape{migraphx::shape::float_type, {2, 3, 2, 2}};

    std::vector<float> a_values = {
        1.23636469,  -0.47041261, -0.14375651, -0.48371852, 1.16479301,  -0.89361055,
        -0.18569086, 1.10700457,  -1.02632638, 0.82277012,  0.33525769,  0.52825145,
        -1.00141689, 0.45510090,  -0.02675039, -0.60454439, 0.38551153,  -0.01658514,
        0.93059292,  -0.54595188, -0.04911005, -0.91397221, -0.83127477, -1.57685603,
        -1.36200452, 2.25822236,  -1.23416970, 0.12312496,  0.76232760,  -0.83594234,
        1.67418145,  -0.19412936, 1.05261378,  0.66246074,  -1.15233398, 0.16429736};
    std::vector<float> b_values = {
        -0.87300530, -0.07112838, 0.19196860,  -1.04986840, 1.20348200,  0.31966893,
        1.04805440,  -2.04777729, -0.67906052, -1.17250760, 0.34305044,  -1.01957785,
        -1.12694862, 0.18431338,  -1.63712290, 0.27566931,  -1.11282021, 1.41738919,
        0.47871283,  -1.01980420, 1.00212436,  -0.78740444, -1.65636133, 1.51466547,
        -0.12470397, 0.70404393,  -0.15244797, 0.74288871,  0.07339926,  -1.45811623,
        0.27185845,  0.08804596,  0.99061977,  -1.61752428, 0.29191159,  0.87271953};
    // One 2x2 matrix of C per source line.
    std::vector<float> c_values = {-1.07692443, 0.85223457,  -0.37266530, 2.31511577,
                                   0.04227017,  1.13229428,  -0.52769242, 0.27307182,
                                   -0.47779843, -0.08023168, -0.22862823, 0.81489871,
                                   1.13139581,  1.13860467,  0.24309065,  0.26533729,
                                   0.49106772,  -1.18860493, 0.27842449,  1.03568141,
                                   0.49759611,  0.10021662,  0.00592602,  0.90862000};

    migraphx::program prog;
    auto a_lit = prog.add_literal(migraphx::literal{a_shape, a_values});
    auto b_lit = prog.add_literal(migraphx::literal{b_shape, b_values});
    auto c_lit = prog.add_literal(migraphx::literal{c_shape, c_values});
    prog.add_instruction(migraphx::op::dot{alpha, beta}, a_lit, b_lit, c_lit);
    prog.compile(migraphx::cpu::target{});

    auto evaluated = prog.eval({});
    std::vector<float> actual;
    evaluated.visit([&](auto view) { actual.assign(view.begin(), view.end()); });

    // One 2x2 output matrix per source line.
    std::vector<float> expected = {-0.91147203, 0.47540785,  -0.30313587, 0.43325099,
                                   -0.43711586, 0.50928632,  0.06919868,  -0.80382802,
                                   -0.05125718, -0.06685650, -0.06972163, 0.32407764,
                                   0.45677396,  0.25909489,  0.56911252,  -0.17183724,
                                   0.10858734,  0.39406289,  0.04662959,  1.07979824,
                                   0.40355016,  0.52410648,  -0.31728447, 1.09550845};
    EXPECT(migraphx::verify_range(actual, expected));
}

// Test-binary entry point. Forwards the command-line arguments to test::run
// (from test.hpp), which presumably selects and executes the registered test
// cases.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}