cpu_ops_test.cpp 26.9 KB
Newer Older
1
2
3
4
5
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>
#include <rtg/literal.hpp>
#include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include "test.hpp"
#include "verify.hpp"

9
10
void exp_test()
{
11
12
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
13
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
14
15
16
17
    p.add_instruction(rtg::exp{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
18
19
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
Scott Thornton's avatar
Scott Thornton committed
20
    EXPECT(test::verify_range(results_vector, gold));
21
22
}

23
24
void sin_test()
{
25
26
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
27
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
28
29
30
31
    p.add_instruction(rtg::sin{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
32
33
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
Scott Thornton's avatar
Scott Thornton committed
34
    EXPECT(test::verify_range(results_vector, gold));
35
36
}

37
38
void cos_test()
{
39
40
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
41
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
42
43
44
45
    p.add_instruction(rtg::cos{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
46
47
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
Scott Thornton's avatar
Scott Thornton committed
48
    EXPECT(test::verify_range(results_vector, gold));
49
50
}

51
52
void tan_test()
{
53
54
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
55
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
56
57
58
59
    p.add_instruction(rtg::tan{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
60
61
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
Scott Thornton's avatar
Scott Thornton committed
62
    EXPECT(test::verify_range(results_vector, gold));
63
64
}

65
66
67
68
69
void add_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
Scott Thornton's avatar
Scott Thornton committed
70
    auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
71
72
73
74
75
76
77
78
79
    p.add_instruction(rtg::add{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 2, 4};
    EXPECT(test::verify_range(results_vector, gold));
}

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Builds a program that applies rtg::broadcast (axis 0) to a {2} int32 literal
// against a {2, 2} int32 literal, then compiles and evaluates it on the CPU target.
//
// NOTE: the element-wise checks below are disabled, so this test currently only
// exercises construction, compilation and evaluation; it asserts nothing about
// the broadcast values.
void broadcast_test()
{
    rtg::program p;
    rtg::shape a_shape{rtg::shape::int32_type, {2, 2}};
    std::vector<int32_t> a_data{0, 0, 0, 0};
    rtg::shape b_shape{rtg::shape::int32_type, {2}};
    std::vector<int32_t> b_data{-2, -3};
    uint64_t axis = 0;
    auto l1       = p.add_literal(rtg::literal{a_shape, a_data});
    auto l2       = p.add_literal(rtg::literal{b_shape, b_data});
    p.add_instruction(rtg::broadcast{axis}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    // Kept so the disabled checks below can be re-enabled without further changes.
    static_cast<void>(result);

    // TODO: re-enable once multi-index access on the visited output works here.
    // result.visit([&](auto output) {
    //     EXPECT(output(0,0) == -2);
    //     EXPECT(output(0,1) == -2);
    //     EXPECT(output(1,0) == -3);
    //     EXPECT(output(1,1) == -3);
    // });
}
void add_broadcast_test()
{
    rtg::program p;
    rtg::shape a_shape{rtg::shape::float_type, {2,2,3}};
    std::vector<float> a_data{0,1,2,3,4,5,6,7,8,9,10,11};
    rtg::shape b_shape{rtg::shape::float_type, {2,2}};
    std::vector<float> b_data{0,-1,-2,-3};
    uint64_t axis = 0;
    auto l1 = p.add_literal(rtg::literal{a_shape, a_data});
    auto l2 = p.add_literal(rtg::literal{b_shape, b_data});
    auto l3 = p.add_instruction(rtg::broadcast{axis}, l1, l2);
    p.add_instruction(rtg::add{}, l1, l3);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(12);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0,1,2,2,3,4,4,5,6,6,7,8};
    EXPECT(test::verify_range(results_vector, gold));
}

122
123
124
125
126
void sub_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
Scott Thornton's avatar
Scott Thornton committed
127
    auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
128
129
130
131
132
133
134
135
136
137
138
139
140
141
    p.add_instruction(rtg::sub{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-2, -2, -2};
    EXPECT(test::verify_range(results_vector, gold));
}

void mul_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
Scott Thornton's avatar
Scott Thornton committed
142
    auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
143
144
145
146
147
148
149
150
151
152
153
154
155
156
    p.add_instruction(rtg::mul{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1, 0, 3};
    EXPECT(test::verify_range(results_vector, gold));
}

void div_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1.0f, 0.5f, 1.0f}});
Scott Thornton's avatar
Scott Thornton committed
157
    auto l2 = p.add_literal(rtg::literal{s, {1.0f, 2.0f, 4.0f}});
158
159
160
161
162
163
164
165
166
    p.add_instruction(rtg::div{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1.f, 0.25f, 0.25f};
    EXPECT(test::verify_range(results_vector, gold));
}

167
168
169
void reshape_test()
{
    rtg::shape a_shape{rtg::shape::float_type, {24, 1, 1, 1}};
170
171
172
173
    std::vector<float> data(24);
    std::iota(data.begin(), data.end(), -3);
    {
        rtg::program p;
174
175
        auto l                         = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> new_shape = {8, 3, 1, 1};
176
177
178
179
        p.add_instruction(rtg::reshape{new_shape}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
        std::vector<float> results_vector(3);
180
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
181
        EXPECT(test::verify_range(results_vector, data));
182
183
184
    }
    {
        rtg::program p;
185
186
        auto l                         = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> new_shape = {1, 3, 4, 2};
187
188
189
190
        p.add_instruction(rtg::reshape{new_shape}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
        std::vector<float> results_vector(3);
191
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
192
        EXPECT(test::verify_range(results_vector, data));
193
194
195
    }
    {
        rtg::program p;
196
197
        auto l                         = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> new_shape = {1, 3, 4, 2};
198
199
200
201
        p.add_instruction(rtg::reshape{new_shape}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
        std::vector<float> results_vector(3);
202
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
203
        EXPECT(test::verify_range(results_vector, data));
204
205
206
    }
}

207
208
void gemm_test()
{
209
    rtg::program p;
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
    std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397,  0.75402161,  -0.505885,
                            1.33628943,  -0.11413,   -0.31270559, 1.59336732,  -0.19361027,
                            -0.91620867, 0.40108416, -0.06969921, 0.68483471,  -0.39906632,
                            -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
    std::vector<float> b = {6.09568541e-01,
                            -6.10527007e-01,
                            3.66646462e-01,
                            1.18951101e-01,
                            5.58777432e-01,
                            -3.21296298e-01,
                            -5.95997198e-01,
                            -5.01425721e-01,
                            -2.84606807e-01,
                            -5.73673557e-01,
                            -8.99430260e-01,
                            -4.25103093e-01,
                            1.53027987e+00,
                            -3.81407415e-04,
                            -3.29650255e-01};
    std::vector<float> c = {-1.56327541e+00,
                            -7.09570140e-01,
                            -5.37424982e-01,
                            -2.22994831e-01,
                            -2.15586437e+00,
                            2.09177941e-03,
                            -1.47279677e+00,
                            2.02627040e-01,
                            -6.04527691e-01,
                            -1.29885596e+00,
                            2.16294914e+00,
                            -1.48101497e-01};
    rtg::shape a_shape{rtg::shape::float_type, {4, 5}};
Scott Thornton's avatar
Scott Thornton committed
242
    auto al = p.add_literal(rtg::literal{a_shape, a});
243
    rtg::shape b_shape{rtg::shape::float_type, {5, 3}};
Scott Thornton's avatar
Scott Thornton committed
244
245
    auto bl = p.add_literal(rtg::literal{b_shape, b});
    p.add_instruction(rtg::gemm{}, al, bl);
246
247
248
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(12);
249
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
250
    float tol = 1e-6;
251
252
253
    for(int i = 0; i < results_vector.size(); i++)
    {
        EXPECT(std::abs(results_vector[i] - c[i]) < tol);
254
255
256
    }
}

257
258
void softmax_test()
{
259
    rtg::program p;
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    std::vector<float> a = {
        -5.61869681e-01, 9.07827199e-01,  1.29255986e+00,  3.18533443e-02,  -1.22183852e-03,
        -2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
        -9.20038462e-01, 8.47388089e-01,  2.51734018e-01,  1.50563884e+00,  2.23056650e+00,
        -6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00,  -2.51560897e-01,
        -8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01,
        3.23284641e-02,  -1.54700470e+00, 1.38096774e+00,  5.39869189e-01,  -7.56884992e-01,
        1.81503093e+00,  -2.11269641e+00, 1.92466557e+00,  1.77230799e+00,  2.21660900e+00,
        1.56777036e+00,  -2.08995026e-03, 3.50566894e-01,  -1.15042710e+00, -1.18577778e+00,
        8.90633047e-01,  -6.63949102e-02, 1.44661188e+00,  1.59215283e+00,  -2.56262213e-01,
        9.39079225e-01,  4.07298543e-02,  3.86590779e-01,  6.09607756e-01,  8.22331488e-01,
        -2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00,
        3.27092171e-01,  -1.33315325e+00, 3.62459183e-01,  3.74710828e-01,  -1.30302286e+00,
        1.79680198e-01,  -4.51832324e-01, 4.34282750e-01,  -7.09520102e-01, 6.20333970e-01,
        -1.28712380e+00, 2.04130828e-01,  -7.70607769e-01, 1.61889160e+00,  -1.50951004e+00,
        -4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
        -8.28408226e-02, 2.73412596e-02,  5.79780899e-03,  9.87900198e-02,  -7.95276761e-01,
        -1.38536084e+00, -6.63573861e-01, 3.89783204e-01,  -1.30670881e+00, -7.62425125e-01,
        -4.04883057e-01, 6.24344349e-01,  3.68128955e-01,  -1.01577950e+00, -3.06715906e-01,
        5.67961395e-01,  2.98198581e-01,  -1.63613629e+00, -3.75131965e-01, -6.75393403e-01,
        2.59172034e+00,  6.75538957e-01,  9.07939598e-02,  1.92257717e-01,  -1.21592450e+00,
        -2.73682117e-01, 1.25232983e+00,  -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
        3.10056299e-01,  1.41833842e+00,  -1.81386679e-01, 3.92868072e-01,  -8.14771175e-01,
        2.02392387e+00,  -9.42091495e-02, -3.77683818e-01, 2.05638766e+00,  2.93796062e-01,
        -6.02131486e-01, 2.70461679e-01,  -8.92358482e-01, 1.04388881e+00,  2.66154885e-01};

    std::vector<float> s = {
        0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
        0.18750034, 0.21905553, 0.27000085, 0.0547399,  0.56318235, 0.47422904, 0.78964758,
        0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684,  0.18766427, 0.15753111,
        0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287,
        0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
        0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915,
        0.3115395,  0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762,  0.4642328,
        0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609,
        0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865,
        0.17296699, 0.46923906, 0.06921105, 0.3570261,  0.4125829,  0.73165393, 0.15302512,
        0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355,
        0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659,  0.10663581,
        0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666,
        0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339,  0.49818122, 0.10656087,
        0.1813329,  0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
        0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
        0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
        0.42914796};

    rtg::shape a_shape{rtg::shape::float_type, {5, 3, 4, 2}};
Scott Thornton's avatar
Scott Thornton committed
307
308
    auto al = p.add_literal(rtg::literal{a_shape, a});
    p.add_instruction(rtg::softmax{}, al);
309
310
311
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(120);
312
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
313
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
314
315
}

316
317
void conv2d_test()
{
Scott Thornton's avatar
Scott Thornton committed
318
    rtg::program p;
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        2.82721668e-02,  6.44195229e-02,  1.53499246e-02,  1.72468081e-01,  -6.33238107e-02,
        9.49496776e-02,  1.40258059e-01,  -7.92879611e-02, -1.29301161e-01, 3.11307609e-03,
        -1.90624535e-01, 1.13238767e-01,  -2.80647576e-02, 3.12882811e-02,  -3.52091640e-02,
        3.33581865e-02,  6.43158704e-02,  7.40238279e-02,  -1.00106120e-01, -9.56912562e-02,
        1.44342467e-01,  9.40258950e-02,  6.36333972e-02,  1.66158378e-03,  -8.91554281e-02,
        2.58734226e-02,  1.70919895e-02,  1.78214177e-01,  8.84564668e-02,  8.98126513e-02,
        -1.63809001e-01, 1.37802169e-01,  1.66439757e-01,  -1.45631135e-02, 1.88469887e-04,
        4.76950556e-02,  -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01,
        1.76608220e-01,  -1.50728196e-01, 1.99946314e-02,  -5.88052124e-02, 1.31612435e-01,
        1.61106288e-02,  -1.35080189e-01, 1.49512306e-01,  3.86456847e-02,  1.29330024e-01,
        -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02};

    std::vector<float> s = {0.27039781,
                            0.19105849,
                            -0.06339942,
                            -0.65087199,
                            0.40867025,
                            0.05063812,
                            -0.14907975,
                            0.49018705,
                            -0.49197209,
                            0.33236548,
                            -0.39374301,
                            0.16012701,
                            0.06574871,
                            0.71606487,
                            -0.55201721,
Scott Thornton's avatar
Scott Thornton committed
363
                            -0.46427044};
364
    rtg::shape a_shape{rtg::shape::float_type, {2, 3, 4, 4}};
Scott Thornton's avatar
Scott Thornton committed
365
366
    auto al = p.add_literal(rtg::literal{a_shape, a});

367
    rtg::shape c_shape{rtg::shape::float_type, {2, 3, 3, 3}};
Scott Thornton's avatar
Scott Thornton committed
368
369
370
371
372
373
374
    auto cl = p.add_literal(rtg::literal{c_shape, c});

    p.add_instruction(rtg::convolution{}, al, cl);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(16);
375
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
376
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
377
378
}

379
380
void conv2d_padding_test()
{
Scott Thornton's avatar
Scott Thornton committed
381
    rtg::program p;
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        -0.16115488, -0.09800646, -0.05412646, 0.10475694,  0.00555485,  -0.12667653, 0.0458357,
        -0.02656217, -0.16338061, 0.15037455,  0.0102711,   0.01303349,  0.05242859,  0.02034754,
        0.04751867,  -0.17038961, -0.1434752,  -0.10770349, 0.05676742,  -0.15838449, 0.10128359,
        -0.18958683, 0.11954515,  0.10758857,  -0.01058291, -0.12797487, 0.08971019,  0.18793164,
        -0.00881396, -0.06588994, -0.13321903, -0.03300409, 0.01439607,  0.07618178,  -0.11556662,
        0.00764295,  0.12956454,  -0.08937147, -0.12763587, 0.04674943,  0.05765297,  0.11336918,
        0.14747436,  -0.06199479, -0.01166052, -0.12432006, -0.04494537, -0.17581205, 0.09475745,
        0.1149437,   -0.1014564,  0.0274073,   -0.01323579, -0.11092556};

    std::vector<float> s = {
        -0.0201216,  0.40407312,  -0.39005592, -0.0631946,  0.37963012,  -0.64611685, 0.1349397,
        -0.54113752, 0.28533003,  0.27667275,  -0.16442731, -0.181494,   0.30564839,  0.58744538,
        0.32015014,  0.24969585,  -0.27367792, -0.53308117, 0.41236052,  0.26136363,  -0.01489828,
        0.57652152,  -0.38506854, 0.119615,    0.0437076,   0.04779706,  0.57887721,  0.23126155,
        0.05695833,  -0.68200272, 0.02063358,  -0.10267162, 0.8062973,   -0.38149622, -0.40134856,
        -0.03353126, 0.38991132,  -0.3478111,  0.03661491,  0.25783631,  0.62772679,  -0.1961118,
        0.76423508,  -0.36241418, -0.20994355, -0.12368261, -0.9406727,  0.02340185,  -0.08793129,
        -0.02471633, -0.58163726, -0.02211772, -0.42014724, 0.77525634,  0.504951,    -0.20537445,
        -0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194,  0.49579027,
        0.46527559};

    rtg::shape a_shape{rtg::shape::float_type, {2, 3, 4, 4}};
Scott Thornton's avatar
Scott Thornton committed
421
422
    auto al = p.add_literal(rtg::literal{a_shape, a});

423
    rtg::shape c_shape{rtg::shape::float_type, {2, 3, 3, 3}};
Scott Thornton's avatar
Scott Thornton committed
424
425
    auto cl = p.add_literal(rtg::literal{c_shape, c});

426
    p.add_instruction(rtg::convolution{{{1, 1}}, {{1, 1}}}, al, cl);
Scott Thornton's avatar
Scott Thornton committed
427
428
429
430
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(64);
431
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
432
    EXPECT(test::verify_range(results_vector, s));
433
434
}

435
436
void conv2d_padding_stride_test()
{
Scott Thornton's avatar
Scott Thornton committed
437
    rtg::program p;
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        -0.14601797, -0.13000923, 0.06521662,  0.06178288,  -0.11083675, 0.10154136,  0.09990512,
        0.06030385,  -0.11374587, -0.17523311, -0.14344215, 0.17802463,  0.06300922,  -0.15325832,
        0.07066704,  0.05166031,  0.00615084,  -0.02606523, 0.08083995,  -0.17913306, 0.0624622,
        0.0735731,   -0.04198661, -0.0164391,  -0.06374192, 0.16569914,  0.10681538,  0.07370754,
        0.02802075,  0.00282027,  0.15104802,  -0.11084409, -0.00197773, 0.07924436,  0.03528272,
        0.04765259,  -0.15896152, 0.07917164,  0.12125669,  -0.1154705,  -0.11999125, 0.12749968,
        -0.06269585, 0.18658121,  -0.03944227, 0.0111798,   -0.17731084, 0.11789055,  -0.09982193,
        0.08142821,  0.0729029,   0.11303909,  0.12735154,  0.03885292};

    std::vector<float> s = {-0.20817225,
                            0.87965256,
                            0.14958936,
                            -1.24887264,
                            -0.06540672,
                            0.20778663,
                            0.40456355,
                            -0.99900877,
                            0.4917807,
                            0.1994698,
                            0.64205718,
                            0.37798831,
                            -0.25315839,
                            0.44276932,
                            -0.16138598,
                            0.79344082};

    rtg::shape a_shape{rtg::shape::float_type, {2, 3, 4, 4}};
Scott Thornton's avatar
Scott Thornton committed
482
483
    auto al = p.add_literal(rtg::literal{a_shape, a});

484
    rtg::shape c_shape{rtg::shape::float_type, {2, 3, 3, 3}};
Scott Thornton's avatar
Scott Thornton committed
485
486
    auto cl = p.add_literal(rtg::literal{c_shape, c});

487
    p.add_instruction(rtg::convolution{{{1, 1}}, {{2, 2}}}, al, cl);
Scott Thornton's avatar
Scott Thornton committed
488
489
490
491
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(16);
492
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
493
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
494
}
495

496
497
void transpose_test()
{
Paul's avatar
Paul committed
498
    rtg::shape a_shape{rtg::shape::float_type, {1, 2, 2, 3}};
499
500
501
    std::vector<float> data(12);
    std::iota(data.begin(), data.end(), 0);

502
503
    {
        rtg::program p;
Paul's avatar
Paul committed
504
505
        auto l                    = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> perm = {0, 3, 1, 2};
506
507
508
        p.add_instruction(rtg::transpose{perm}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
509

Paul's avatar
Paul committed
510
511
512
513
514
515
        result.visit([&](auto output) {
            std::vector<size_t> new_lens    = {1, 3, 2, 2};
            std::vector<size_t> new_strides = {12, 1, 6, 3};
            EXPECT(bool{output.get_shape().lens() == new_lens});
            EXPECT(bool{output.get_shape().strides() == new_strides});
        });
516
517
518
    }
    {
        rtg::program p;
Paul's avatar
Paul committed
519
520
521
        auto l                    = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> perm = {0, 3, 1, 2};
        auto result               = p.add_instruction(rtg::transpose{perm}, l);
522
523
524
525
526
527
528
529
530
        p.add_instruction(rtg::contiguous{}, result);
        p.compile(rtg::cpu::cpu_target{});
        auto result2 = p.eval({});

        std::vector<float> results_vector(12);
        result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
        std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
        EXPECT(test::verify_range(results_vector, gold));
    }
531
532
}

Paul's avatar
Paul committed
533
534
535
void contiguous_test()
{
    rtg::shape a_shape{rtg::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
536
537
538
539
540
541
542
543
544
545
    std::vector<float> data(12);
    std::iota(data.begin(), data.end(), 0);

    rtg::program p;
    auto l = p.add_literal(rtg::literal{a_shape, data});
    p.add_instruction(rtg::contiguous{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(12);
546
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Paul's avatar
Paul committed
547
548
549
    std::vector<size_t> new_lens    = {1, 3, 2, 2};
    std::vector<size_t> new_strides = {12, 1, 6, 3};
    std::vector<float> gold         = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
550
    EXPECT(test::verify_range(results_vector, gold));
551
552
}

553
554
555
int main()
{
    exp_test();
556
557
558
    sin_test();
    cos_test();
    tan_test();
559
    add_test();
560
561
    broadcast_test();
    add_broadcast_test();
562
563
    sub_test();
    mul_test();
564
    gemm_test();
565
    reshape_test();
566
    transpose_test();
567
    contiguous_test();
568
    softmax_test();
Scott Thornton's avatar
Scott Thornton committed
569
570
571
    conv2d_test();
    conv2d_padding_test();
    conv2d_padding_stride_test();
572
}