cpu_ops_test.cpp 32.7 KB
Newer Older
1
2
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include "test.hpp"
#include "verify.hpp"

9
10
11
12
void batch_norm_inference_test()
{
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {4}};
13
14
15
16
17
18
    auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}});
    auto gamma = p.add_literal(migraph::literal{s, {1}});
    auto beta = p.add_literal(migraph::literal{s, {0}});
    auto mean = p.add_literal(migraph::literal{s, {0}});
    auto variance = p.add_literal(migraph::literal{s, {1}});
    p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, gamma, beta);
19
20
    p.compile(migraph::cpu::cpu_target{});
    auto result = p.eval({});
21
22
23
24
    std::vector<float> result_vector(4);
    result.visit([&](auto output) {result_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = { 1 / (1 + 1.0e-6), 2 / (1 + 1.0e-6), 3 / (1 + 1.0e-6), 4 / (1 + 1.0e-6)};
    EXPECT(test::verify_range(result_vector, gold));
25
26
}

27
28
void exp_test()
{
Paul's avatar
Paul committed
29
30
31
32
33
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    p.add_instruction(migraph::exp{}, l);
    p.compile(migraph::cpu::cpu_target{});
34
35
    auto result = p.eval({});
    std::vector<float> results_vector(3);
36
37
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
Scott Thornton's avatar
Scott Thornton committed
38
    EXPECT(test::verify_range(results_vector, gold));
39
40
}

41
42
void sin_test()
{
Paul's avatar
Paul committed
43
44
45
46
47
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    p.add_instruction(migraph::sin{}, l);
    p.compile(migraph::cpu::cpu_target{});
48
49
    auto result = p.eval({});
    std::vector<float> results_vector(3);
50
51
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
Scott Thornton's avatar
Scott Thornton committed
52
    EXPECT(test::verify_range(results_vector, gold));
53
54
}

55
56
void cos_test()
{
Paul's avatar
Paul committed
57
58
59
60
61
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    p.add_instruction(migraph::cos{}, l);
    p.compile(migraph::cpu::cpu_target{});
62
63
    auto result = p.eval({});
    std::vector<float> results_vector(3);
64
65
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
Scott Thornton's avatar
Scott Thornton committed
66
    EXPECT(test::verify_range(results_vector, gold));
67
68
}

69
70
void tan_test()
{
Paul's avatar
Paul committed
71
72
73
74
75
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    p.add_instruction(migraph::tan{}, l);
    p.compile(migraph::cpu::cpu_target{});
76
77
    auto result = p.eval({});
    std::vector<float> results_vector(3);
78
79
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
Scott Thornton's avatar
Scott Thornton committed
80
    EXPECT(test::verify_range(results_vector, gold));
81
82
}

83
84
void add_test()
{
Paul's avatar
Paul committed
85
86
87
88
89
90
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}});
    p.add_instruction(migraph::add{}, l1, l2);
    p.compile(migraph::cpu::cpu_target{});
91
92
93
94
95
96
97
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 2, 4};
    EXPECT(test::verify_range(results_vector, gold));
}

98
99
void broadcast_test()
{
Paul's avatar
Paul committed
100
101
    migraph::program p;
    migraph::shape a_shape{migraph::shape::int32_type, {2, 2}};
102
    std::vector<int32_t> a_data{0, 0, 0, 0};
Paul's avatar
Paul committed
103
    migraph::shape b_shape{migraph::shape::int32_type, {2}};
104
    std::vector<int32_t> b_data{-2, -3};
105
    uint64_t axis = 0;
Paul's avatar
Paul committed
106
107
108
109
    auto l1       = p.add_literal(migraph::literal{a_shape, a_data});
    auto l2       = p.add_literal(migraph::literal{b_shape, b_data});
    p.add_instruction(migraph::broadcast{axis}, l1, l2);
    p.compile(migraph::cpu::cpu_target{});
110
    auto result = p.eval({});
Paul's avatar
Paul committed
111
    auto output = result.get<int32_t>();
Paul's avatar
Paul committed
112
113
114
115
    EXPECT(output(0, 0) == -2);
    EXPECT(output(0, 1) == -2);
    EXPECT(output(1, 0) == -3);
    EXPECT(output(1, 1) == -3);
116
117
118
}
void add_broadcast_test()
{
Paul's avatar
Paul committed
119
120
    migraph::program p;
    migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
121
    std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
Paul's avatar
Paul committed
122
    migraph::shape b_shape{migraph::shape::float_type, {2, 2}};
123
    std::vector<float> b_data{0, -1, -2, -3};
124
    uint64_t axis = 0;
Paul's avatar
Paul committed
125
126
127
128
129
    auto l1       = p.add_literal(migraph::literal{a_shape, a_data});
    auto l2       = p.add_literal(migraph::literal{b_shape, b_data});
    auto l3       = p.add_instruction(migraph::broadcast{axis}, l1, l2);
    p.add_instruction(migraph::add{}, l1, l3);
    p.compile(migraph::cpu::cpu_target{});
130
    auto result = p.eval({});
Paul's avatar
Paul committed
131
    EXPECT(result.get_shape().packed());
132
133
    std::vector<float> results_vector(12);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
134
    std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
135
136
137
    EXPECT(test::verify_range(results_vector, gold));
}

138
139
void sub_test()
{
Paul's avatar
Paul committed
140
141
142
143
144
145
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}});
    p.add_instruction(migraph::sub{}, l1, l2);
    p.compile(migraph::cpu::cpu_target{});
146
147
148
149
150
151
152
153
154
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-2, -2, -2};
    EXPECT(test::verify_range(results_vector, gold));
}

void mul_test()
{
Paul's avatar
Paul committed
155
156
157
158
159
160
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l1 = p.add_literal(migraph::literal{s, {-1, 0, 1}});
    auto l2 = p.add_literal(migraph::literal{s, {1, 2, 3}});
    p.add_instruction(migraph::mul{}, l1, l2);
    p.compile(migraph::cpu::cpu_target{});
161
162
163
164
165
166
167
168
169
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1, 0, 3};
    EXPECT(test::verify_range(results_vector, gold));
}

void div_test()
{
Paul's avatar
Paul committed
170
171
172
173
174
175
    migraph::program p;
    migraph::shape s{migraph::shape::float_type, {3}};
    auto l1 = p.add_literal(migraph::literal{s, {-1.0f, 0.5f, 1.0f}});
    auto l2 = p.add_literal(migraph::literal{s, {1.0f, 2.0f, 4.0f}});
    p.add_instruction(migraph::div{}, l1, l2);
    p.compile(migraph::cpu::cpu_target{});
176
177
178
179
180
181
182
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1.f, 0.25f, 0.25f};
    EXPECT(test::verify_range(results_vector, gold));
}

183
184
void reshape_test()
{
Paul's avatar
Paul committed
185
    migraph::shape a_shape{migraph::shape::float_type, {24, 1, 1, 1}};
186
187
188
    std::vector<float> data(24);
    std::iota(data.begin(), data.end(), -3);
    {
Paul's avatar
Paul committed
189
190
        migraph::program p;
        auto l                         = p.add_literal(migraph::literal{a_shape, data});
191
        std::vector<int64_t> new_shape = {8, 3, 1, 1};
Paul's avatar
Paul committed
192
193
        p.add_instruction(migraph::reshape{new_shape}, l);
        p.compile(migraph::cpu::cpu_target{});
194
195
        auto result = p.eval({});
        std::vector<float> results_vector(3);
196
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
197
        EXPECT(test::verify_range(results_vector, data));
198
199
    }
    {
Paul's avatar
Paul committed
200
201
        migraph::program p;
        auto l                         = p.add_literal(migraph::literal{a_shape, data});
202
        std::vector<int64_t> new_shape = {1, 3, 4, 2};
Paul's avatar
Paul committed
203
204
        p.add_instruction(migraph::reshape{new_shape}, l);
        p.compile(migraph::cpu::cpu_target{});
205
206
        auto result = p.eval({});
        std::vector<float> results_vector(3);
207
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
208
        EXPECT(test::verify_range(results_vector, data));
209
210
    }
    {
Paul's avatar
Paul committed
211
212
        migraph::program p;
        auto l                         = p.add_literal(migraph::literal{a_shape, data});
213
        std::vector<int64_t> new_shape = {1, 3, 4, 2};
Paul's avatar
Paul committed
214
215
        p.add_instruction(migraph::reshape{new_shape}, l);
        p.compile(migraph::cpu::cpu_target{});
216
217
        auto result = p.eval({});
        std::vector<float> results_vector(3);
218
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
219
        EXPECT(test::verify_range(results_vector, data));
220
221
222
    }
}

223
224
void gemm_test()
{
Paul's avatar
Paul committed
225
    migraph::program p;
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
    std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397,  0.75402161,  -0.505885,
                            1.33628943,  -0.11413,   -0.31270559, 1.59336732,  -0.19361027,
                            -0.91620867, 0.40108416, -0.06969921, 0.68483471,  -0.39906632,
                            -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
    std::vector<float> b = {6.09568541e-01,
                            -6.10527007e-01,
                            3.66646462e-01,
                            1.18951101e-01,
                            5.58777432e-01,
                            -3.21296298e-01,
                            -5.95997198e-01,
                            -5.01425721e-01,
                            -2.84606807e-01,
                            -5.73673557e-01,
                            -8.99430260e-01,
                            -4.25103093e-01,
                            1.53027987e+00,
                            -3.81407415e-04,
                            -3.29650255e-01};
    std::vector<float> c = {-1.56327541e+00,
                            -7.09570140e-01,
                            -5.37424982e-01,
                            -2.22994831e-01,
                            -2.15586437e+00,
                            2.09177941e-03,
                            -1.47279677e+00,
                            2.02627040e-01,
                            -6.04527691e-01,
                            -1.29885596e+00,
                            2.16294914e+00,
                            -1.48101497e-01};
Paul's avatar
Paul committed
257
258
259
260
261
262
    migraph::shape a_shape{migraph::shape::float_type, {4, 5}};
    auto al = p.add_literal(migraph::literal{a_shape, a});
    migraph::shape b_shape{migraph::shape::float_type, {5, 3}};
    auto bl = p.add_literal(migraph::literal{b_shape, b});
    p.add_instruction(migraph::gemm{}, al, bl);
    p.compile(migraph::cpu::cpu_target{});
263
264
    auto result = p.eval({});
    std::vector<float> results_vector(12);
265
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
266
    float tol = 1e-6;
267
268
269
    for(int i = 0; i < results_vector.size(); i++)
    {
        EXPECT(std::abs(results_vector[i] - c[i]) < tol);
270
271
272
    }
}

273
274
// Runs a "max" migraph::pooling instruction over a {2, 3, 6, 6} float literal
// on the CPU target and compares each output element with the reference
// values in `c` using an absolute tolerance of 1e-6.
void maxpool_test()
{
    migraph::program p;
    // 216 input values for the {2, 3, 6, 6} literal.
    std::vector<float> a = {
        -2.1314404,  -1.63041711, 1.54562736,  1.04625261,  -1.42931843, -0.48703974, 0.4065806,
        -0.1524526,  1.30775225,  0.45538983,  -0.06631992, -1.75332725, 1.33493888,  0.47327688,
        0.36873096,  1.18358743,  -0.34640595, 1.22098756,  0.01946825,  -0.20238149, 0.43348005,
        -0.67991608, -0.83041084, 0.93537551,  0.70241445,  -0.5654031,  -1.30899191, -0.26735824,
        -0.52444768, 1.99097753,  1.86504853,  -0.26506025, 0.26236168,  0.43763575,  0.95300823,
        -1.02733946, -0.74655169, -0.5374338,  -0.28901565, -0.59789604, 0.5310151,   0.99125904,
        0.40609556,  -1.57175648, 0.22031412,  1.45862222,  0.53217483,  1.39087725,  1.00170159,
        -0.87175864, -1.7204628,  -1.72008383, -0.38656762, -0.01443311, 1.46645272,  -1.39995027,
        0.22505587,  -0.43461126, -0.05511411, -0.79950953, -0.01439556, 0.08795211,  1.18943918,
        -0.84079367, -1.73383629, -0.55662078, -0.30626822, -0.67339015, 0.44179603,  0.54316711,
        0.40899998,  -0.27831686, -1.11900508, -0.0881724,  0.35483059,  2.36277103,  -0.04765317,
        -0.36865309, 0.73814237,  1.47151589,  1.36546791,  -0.32649881, -1.0517807,  2.24768877,
        0.68883753,  0.58646208,  -0.91017133, -0.50462508, -0.4013325,  -0.72348958, -0.47368807,
        0.35285577,  -1.01817429, -0.5152272,  0.60321307,  0.43521205,  -0.23733577, 0.66427642,
        0.82949388,  0.82443929,  0.71550399,  0.34561086,  0.68570769,  -0.40718508, -1.20350206,
        0.15793853,  -2.31013632, -0.07934658, -0.09348056, 0.36576006,  2.46601582,  0.11090943,
        0.9144392,   0.56759721,  -0.22112127, -0.21955389, 0.72474903,  -1.28448462, 1.53285873,
        0.37437943,  0.31409341,  1.95433736,  0.91620457,  0.86205518,  1.24365854,  0.19248386,
        0.22526583,  0.13462132,  -0.27561715, -2.06446075, -0.02306402, -1.38278747, 1.1411345,
        1.31293464,  -1.86041689, 1.06763375,  -0.26541466, 1.4545635,   1.11430049,  -0.66491818,
        0.87101674,  0.67768967,  -1.02062869, -1.05031872, -2.2764678,  -2.0200038,  0.37592548,
        -0.26701379, -0.83388507, 0.19403623,  1.00968623,  0.11020003,  1.16736257,  -1.1160326,
        0.47346735,  0.6126079,   -0.19135755, 1.33624589,  -0.29802522, -0.57873946, -1.06555879,
        -0.20686582, 1.36892557,  -0.19937795, 0.8649236,   -1.40126073, 1.53441942,  0.34682792,
        -1.31724346, -1.32898355, 2.40126371,  0.07845283,  1.35732043,  -0.63678312, 0.39429256,
        -1.36487007, -0.31026676, -0.44981545, -0.28994772, -0.14657612, -1.75206447, -0.70612341,
        1.20071781,  -1.64647579, -0.7133292,  0.88494766,  0.52119428,  -2.77387547, 2.07681108,
        -0.90133125, 0.2847338,   0.6174528,   -0.20616426, -0.64263535, -1.08496261, 0.54275119,
        -0.88503587, 0.6629802,   1.47319221,  -1.05829155, -0.97027361, -0.93187737, -1.39954746,
        -0.52359426, -0.14743951, 1.51522756,  0.2078452,   -1.28156149, -1.19363916, -0.78680223,
        -0.89094824, 1.30212069,  -0.77974445, -0.58411664, 0.48764706,  -0.67132682};
    // 36 reference output values.
    std::vector<float> c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
                            1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
                            1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
                            1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
                            1.33624589, 1.16736257, 0.6126079,  1.36892557, 2.40126371, 1.53441942,
                            0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};

    migraph::shape a_shape{migraph::shape::float_type, {2, 3, 6, 6}};
    auto al = p.add_literal(migraph::literal{a_shape, a});
    // NOTE(review): the three brace groups are presumably padding {0, 0},
    // stride {2, 2} and window lengths {3, 2} -- confirm against
    // migraph::pooling's member order.
    p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
    p.compile(migraph::cpu::cpu_target{});

    auto result = p.eval({});
    std::cout << result.get_shape() << std::endl;
    std::vector<float> results_vector(36);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });

    // A result of unexpected length is itself a failure; the loop bound below
    // additionally guarantees that `c` is never indexed out of range.
    EXPECT(results_vector.size() == c.size());

    const float tol = 1e-6;
    // Unsigned index: avoids the signed/unsigned comparison against size().
    for(std::size_t i = 0; i < results_vector.size() && i < c.size(); i++)
    {
        EXPECT(std::abs(results_vector[i] - c[i]) < tol);
    }
}

330
331
void softmax_test()
{
Paul's avatar
Paul committed
332
    migraph::program p;
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
    std::vector<float> a = {
        -5.61869681e-01, 9.07827199e-01,  1.29255986e+00,  3.18533443e-02,  -1.22183852e-03,
        -2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
        -9.20038462e-01, 8.47388089e-01,  2.51734018e-01,  1.50563884e+00,  2.23056650e+00,
        -6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00,  -2.51560897e-01,
        -8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01,
        3.23284641e-02,  -1.54700470e+00, 1.38096774e+00,  5.39869189e-01,  -7.56884992e-01,
        1.81503093e+00,  -2.11269641e+00, 1.92466557e+00,  1.77230799e+00,  2.21660900e+00,
        1.56777036e+00,  -2.08995026e-03, 3.50566894e-01,  -1.15042710e+00, -1.18577778e+00,
        8.90633047e-01,  -6.63949102e-02, 1.44661188e+00,  1.59215283e+00,  -2.56262213e-01,
        9.39079225e-01,  4.07298543e-02,  3.86590779e-01,  6.09607756e-01,  8.22331488e-01,
        -2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00,
        3.27092171e-01,  -1.33315325e+00, 3.62459183e-01,  3.74710828e-01,  -1.30302286e+00,
        1.79680198e-01,  -4.51832324e-01, 4.34282750e-01,  -7.09520102e-01, 6.20333970e-01,
        -1.28712380e+00, 2.04130828e-01,  -7.70607769e-01, 1.61889160e+00,  -1.50951004e+00,
        -4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
        -8.28408226e-02, 2.73412596e-02,  5.79780899e-03,  9.87900198e-02,  -7.95276761e-01,
        -1.38536084e+00, -6.63573861e-01, 3.89783204e-01,  -1.30670881e+00, -7.62425125e-01,
        -4.04883057e-01, 6.24344349e-01,  3.68128955e-01,  -1.01577950e+00, -3.06715906e-01,
        5.67961395e-01,  2.98198581e-01,  -1.63613629e+00, -3.75131965e-01, -6.75393403e-01,
        2.59172034e+00,  6.75538957e-01,  9.07939598e-02,  1.92257717e-01,  -1.21592450e+00,
        -2.73682117e-01, 1.25232983e+00,  -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
        3.10056299e-01,  1.41833842e+00,  -1.81386679e-01, 3.92868072e-01,  -8.14771175e-01,
        2.02392387e+00,  -9.42091495e-02, -3.77683818e-01, 2.05638766e+00,  2.93796062e-01,
        -6.02131486e-01, 2.70461679e-01,  -8.92358482e-01, 1.04388881e+00,  2.66154885e-01};

    std::vector<float> s = {
        0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
        0.18750034, 0.21905553, 0.27000085, 0.0547399,  0.56318235, 0.47422904, 0.78964758,
        0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684,  0.18766427, 0.15753111,
        0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287,
        0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
        0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915,
        0.3115395,  0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762,  0.4642328,
        0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609,
        0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865,
        0.17296699, 0.46923906, 0.06921105, 0.3570261,  0.4125829,  0.73165393, 0.15302512,
        0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355,
        0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659,  0.10663581,
        0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666,
        0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339,  0.49818122, 0.10656087,
        0.1813329,  0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
        0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
        0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
        0.42914796};

Paul's avatar
Paul committed
379
380
381
382
    migraph::shape a_shape{migraph::shape::float_type, {5, 3, 4, 2}};
    auto al = p.add_literal(migraph::literal{a_shape, a});
    p.add_instruction(migraph::softmax{}, al);
    p.compile(migraph::cpu::cpu_target{});
383
384
    auto result = p.eval({});
    std::vector<float> results_vector(120);
385
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
386
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
387
388
}

389
390
void conv2d_test()
{
Paul's avatar
Paul committed
391
    migraph::program p;
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        2.82721668e-02,  6.44195229e-02,  1.53499246e-02,  1.72468081e-01,  -6.33238107e-02,
        9.49496776e-02,  1.40258059e-01,  -7.92879611e-02, -1.29301161e-01, 3.11307609e-03,
        -1.90624535e-01, 1.13238767e-01,  -2.80647576e-02, 3.12882811e-02,  -3.52091640e-02,
        3.33581865e-02,  6.43158704e-02,  7.40238279e-02,  -1.00106120e-01, -9.56912562e-02,
        1.44342467e-01,  9.40258950e-02,  6.36333972e-02,  1.66158378e-03,  -8.91554281e-02,
        2.58734226e-02,  1.70919895e-02,  1.78214177e-01,  8.84564668e-02,  8.98126513e-02,
        -1.63809001e-01, 1.37802169e-01,  1.66439757e-01,  -1.45631135e-02, 1.88469887e-04,
        4.76950556e-02,  -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01,
        1.76608220e-01,  -1.50728196e-01, 1.99946314e-02,  -5.88052124e-02, 1.31612435e-01,
        1.61106288e-02,  -1.35080189e-01, 1.49512306e-01,  3.86456847e-02,  1.29330024e-01,
        -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02};

    std::vector<float> s = {0.27039781,
                            0.19105849,
                            -0.06339942,
                            -0.65087199,
                            0.40867025,
                            0.05063812,
                            -0.14907975,
                            0.49018705,
                            -0.49197209,
                            0.33236548,
                            -0.39374301,
                            0.16012701,
                            0.06574871,
                            0.71606487,
                            -0.55201721,
Scott Thornton's avatar
Scott Thornton committed
436
                            -0.46427044};
Paul's avatar
Paul committed
437
438
    migraph::shape a_shape{migraph::shape::float_type, {2, 3, 4, 4}};
    auto al = p.add_literal(migraph::literal{a_shape, a});
Scott Thornton's avatar
Scott Thornton committed
439

Paul's avatar
Paul committed
440
441
    migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}};
    auto cl = p.add_literal(migraph::literal{c_shape, c});
Scott Thornton's avatar
Scott Thornton committed
442

Paul's avatar
Paul committed
443
444
    p.add_instruction(migraph::convolution{}, al, cl);
    p.compile(migraph::cpu::cpu_target{});
Scott Thornton's avatar
Scott Thornton committed
445
446
447
    auto result = p.eval({});

    std::vector<float> results_vector(16);
448
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
449
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
450
451
}

452
453
void conv2d_padding_test()
{
Paul's avatar
Paul committed
454
    migraph::program p;
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        -0.16115488, -0.09800646, -0.05412646, 0.10475694,  0.00555485,  -0.12667653, 0.0458357,
        -0.02656217, -0.16338061, 0.15037455,  0.0102711,   0.01303349,  0.05242859,  0.02034754,
        0.04751867,  -0.17038961, -0.1434752,  -0.10770349, 0.05676742,  -0.15838449, 0.10128359,
        -0.18958683, 0.11954515,  0.10758857,  -0.01058291, -0.12797487, 0.08971019,  0.18793164,
        -0.00881396, -0.06588994, -0.13321903, -0.03300409, 0.01439607,  0.07618178,  -0.11556662,
        0.00764295,  0.12956454,  -0.08937147, -0.12763587, 0.04674943,  0.05765297,  0.11336918,
        0.14747436,  -0.06199479, -0.01166052, -0.12432006, -0.04494537, -0.17581205, 0.09475745,
        0.1149437,   -0.1014564,  0.0274073,   -0.01323579, -0.11092556};

    std::vector<float> s = {
        -0.0201216,  0.40407312,  -0.39005592, -0.0631946,  0.37963012,  -0.64611685, 0.1349397,
        -0.54113752, 0.28533003,  0.27667275,  -0.16442731, -0.181494,   0.30564839,  0.58744538,
        0.32015014,  0.24969585,  -0.27367792, -0.53308117, 0.41236052,  0.26136363,  -0.01489828,
        0.57652152,  -0.38506854, 0.119615,    0.0437076,   0.04779706,  0.57887721,  0.23126155,
        0.05695833,  -0.68200272, 0.02063358,  -0.10267162, 0.8062973,   -0.38149622, -0.40134856,
        -0.03353126, 0.38991132,  -0.3478111,  0.03661491,  0.25783631,  0.62772679,  -0.1961118,
        0.76423508,  -0.36241418, -0.20994355, -0.12368261, -0.9406727,  0.02340185,  -0.08793129,
        -0.02471633, -0.58163726, -0.02211772, -0.42014724, 0.77525634,  0.504951,    -0.20537445,
        -0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194,  0.49579027,
        0.46527559};

Paul's avatar
Paul committed
493
494
    migraph::shape a_shape{migraph::shape::float_type, {2, 3, 4, 4}};
    auto al = p.add_literal(migraph::literal{a_shape, a});
Scott Thornton's avatar
Scott Thornton committed
495

Paul's avatar
Paul committed
496
497
    migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}};
    auto cl = p.add_literal(migraph::literal{c_shape, c});
Scott Thornton's avatar
Scott Thornton committed
498

Paul's avatar
Paul committed
499
500
    p.add_instruction(migraph::convolution{{{1, 1}}, {{1, 1}}}, al, cl);
    p.compile(migraph::cpu::cpu_target{});
Scott Thornton's avatar
Scott Thornton committed
501
502
503
    auto result = p.eval({});

    std::vector<float> results_vector(64);
504
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
505
    EXPECT(test::verify_range(results_vector, s));
506
507
}

508
509
void conv2d_padding_stride_test()
{
Paul's avatar
Paul committed
510
    migraph::program p;
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        -0.14601797, -0.13000923, 0.06521662,  0.06178288,  -0.11083675, 0.10154136,  0.09990512,
        0.06030385,  -0.11374587, -0.17523311, -0.14344215, 0.17802463,  0.06300922,  -0.15325832,
        0.07066704,  0.05166031,  0.00615084,  -0.02606523, 0.08083995,  -0.17913306, 0.0624622,
        0.0735731,   -0.04198661, -0.0164391,  -0.06374192, 0.16569914,  0.10681538,  0.07370754,
        0.02802075,  0.00282027,  0.15104802,  -0.11084409, -0.00197773, 0.07924436,  0.03528272,
        0.04765259,  -0.15896152, 0.07917164,  0.12125669,  -0.1154705,  -0.11999125, 0.12749968,
        -0.06269585, 0.18658121,  -0.03944227, 0.0111798,   -0.17731084, 0.11789055,  -0.09982193,
        0.08142821,  0.0729029,   0.11303909,  0.12735154,  0.03885292};

    std::vector<float> s = {-0.20817225,
                            0.87965256,
                            0.14958936,
                            -1.24887264,
                            -0.06540672,
                            0.20778663,
                            0.40456355,
                            -0.99900877,
                            0.4917807,
                            0.1994698,
                            0.64205718,
                            0.37798831,
                            -0.25315839,
                            0.44276932,
                            -0.16138598,
                            0.79344082};

Paul's avatar
Paul committed
554
555
    migraph::shape a_shape{migraph::shape::float_type, {2, 3, 4, 4}};
    auto al = p.add_literal(migraph::literal{a_shape, a});
Scott Thornton's avatar
Scott Thornton committed
556

Paul's avatar
Paul committed
557
558
    migraph::shape c_shape{migraph::shape::float_type, {2, 3, 3, 3}};
    auto cl = p.add_literal(migraph::literal{c_shape, c});
Scott Thornton's avatar
Scott Thornton committed
559

Paul's avatar
Paul committed
560
561
    p.add_instruction(migraph::convolution{{{1, 1}}, {{2, 2}}}, al, cl);
    p.compile(migraph::cpu::cpu_target{});
Scott Thornton's avatar
Scott Thornton committed
562
563
564
    auto result = p.eval({});

    std::vector<float> results_vector(16);
565
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
566
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
567
}
568

569
570
void transpose_test()
{
Paul's avatar
Paul committed
571
    migraph::shape a_shape{migraph::shape::float_type, {1, 2, 2, 3}};
572
573
574
    std::vector<float> data(12);
    std::iota(data.begin(), data.end(), 0);

575
    {
Paul's avatar
Paul committed
576
577
        migraph::program p;
        auto l                    = p.add_literal(migraph::literal{a_shape, data});
Paul's avatar
Paul committed
578
        std::vector<int64_t> perm = {0, 3, 1, 2};
Paul's avatar
Paul committed
579
580
        p.add_instruction(migraph::transpose{perm}, l);
        p.compile(migraph::cpu::cpu_target{});
581
        auto result = p.eval({});
582

Paul's avatar
Paul committed
583
584
585
586
587
588
        result.visit([&](auto output) {
            std::vector<size_t> new_lens    = {1, 3, 2, 2};
            std::vector<size_t> new_strides = {12, 1, 6, 3};
            EXPECT(bool{output.get_shape().lens() == new_lens});
            EXPECT(bool{output.get_shape().strides() == new_strides});
        });
589
590
    }
    {
Paul's avatar
Paul committed
591
592
        migraph::program p;
        auto l                    = p.add_literal(migraph::literal{a_shape, data});
Paul's avatar
Paul committed
593
        std::vector<int64_t> perm = {0, 3, 1, 2};
Paul's avatar
Paul committed
594
595
596
        auto result               = p.add_instruction(migraph::transpose{perm}, l);
        p.add_instruction(migraph::contiguous{}, result);
        p.compile(migraph::cpu::cpu_target{});
597
598
599
600
601
602
603
        auto result2 = p.eval({});

        std::vector<float> results_vector(12);
        result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
        std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
        EXPECT(test::verify_range(results_vector, gold));
    }
604
605
}

Paul's avatar
Paul committed
606
607
void contiguous_test()
{
Paul's avatar
Paul committed
608
    migraph::shape a_shape{migraph::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
609
610
611
    std::vector<float> data(12);
    std::iota(data.begin(), data.end(), 0);

Paul's avatar
Paul committed
612
613
614
615
    migraph::program p;
    auto l = p.add_literal(migraph::literal{a_shape, data});
    p.add_instruction(migraph::contiguous{}, l);
    p.compile(migraph::cpu::cpu_target{});
616
617
618
    auto result = p.eval({});

    std::vector<float> results_vector(12);
619
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Paul's avatar
Paul committed
620
621
622
    std::vector<size_t> new_lens    = {1, 3, 2, 2};
    std::vector<size_t> new_strides = {12, 1, 6, 3};
    std::vector<float> gold         = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
623
    EXPECT(test::verify_range(results_vector, gold));
624
625
}

626
627
628
int main()
{
    exp_test();
629
630
631
    sin_test();
    cos_test();
    tan_test();
632
    add_test();
633
634
    broadcast_test();
    add_broadcast_test();
635
636
    sub_test();
    mul_test();
637
    gemm_test();
638
    reshape_test();
639
    transpose_test();
640
    contiguous_test();
641
    softmax_test();
Scott Thornton's avatar
Scott Thornton committed
642
    // maxpool_test();
Scott Thornton's avatar
Scott Thornton committed
643
644
645
    conv2d_test();
    conv2d_padding_test();
    conv2d_padding_stride_test();
646
    batch_norm_inference_test();
647
}