// cpu_ops_test.cpp: numerical tests for rtg operators compiled for the CPU target.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>
#include <rtg/literal.hpp>
#include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include "test.hpp"
#include "verify.hpp"

9
10
void exp_test()
{
11
12
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
13
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
14
15
16
17
    p.add_instruction(rtg::exp{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
18
19
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
Scott Thornton's avatar
Scott Thornton committed
20
    EXPECT(test::verify_range(results_vector, gold));
21
22
}

23
24
void sin_test()
{
25
26
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
27
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
28
29
30
31
    p.add_instruction(rtg::sin{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
32
33
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-0.84147098f, 0.f, 0.84147098f};
Scott Thornton's avatar
Scott Thornton committed
34
    EXPECT(test::verify_range(results_vector, gold));
35
36
}

37
38
void cos_test()
{
39
40
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
41
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
42
43
44
45
    p.add_instruction(rtg::cos{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
46
47
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0.54030231f, 1.f, 0.54030231f};
Scott Thornton's avatar
Scott Thornton committed
48
    EXPECT(test::verify_range(results_vector, gold));
49
50
}

51
52
void tan_test()
{
53
54
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
55
    auto l = p.add_literal(rtg::literal{s, {-1, 0, 1}});
56
57
58
59
    p.add_instruction(rtg::tan{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
60
61
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1.55740772f, 0.0f, 1.55740772f};
Scott Thornton's avatar
Scott Thornton committed
62
    EXPECT(test::verify_range(results_vector, gold));
63
64
}

65
66
67
68
69
void add_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
Scott Thornton's avatar
Scott Thornton committed
70
    auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
71
72
73
74
75
76
77
78
79
    p.add_instruction(rtg::add{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {0, 2, 4};
    EXPECT(test::verify_range(results_vector, gold));
}

80
81
82
void broadcast_test()
{
    rtg::program p;
83
84
    rtg::shape a_shape{rtg::shape::int32_type, {2, 2}};
    std::vector<int32_t> a_data{0, 0, 0, 0};
85
    rtg::shape b_shape{rtg::shape::int32_type, {2}};
86
    std::vector<int32_t> b_data{-2, -3};
87
    uint64_t axis = 0;
88
89
    auto l1       = p.add_literal(rtg::literal{a_shape, a_data});
    auto l2       = p.add_literal(rtg::literal{b_shape, b_data});
90
91
92
93
94
95
96
97
    p.add_instruction(rtg::broadcast{axis}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<int32_t> results_vector(4);
}
void add_broadcast_test()
{
    rtg::program p;
98
99
100
101
    rtg::shape a_shape{rtg::shape::float_type, {2, 2, 3}};
    std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
    rtg::shape b_shape{rtg::shape::float_type, {2, 2}};
    std::vector<float> b_data{0, -1, -2, -3};
102
    uint64_t axis = 0;
103
104
105
    auto l1       = p.add_literal(rtg::literal{a_shape, a_data});
    auto l2       = p.add_literal(rtg::literal{b_shape, b_data});
    auto l3       = p.add_instruction(rtg::broadcast{axis}, l1, l2);
106
107
108
    p.add_instruction(rtg::add{}, l1, l3);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
Paul's avatar
Paul committed
109
    EXPECT(result.get_shape().packed());
110
111
    std::vector<float> results_vector(12);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
112
    std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
113
114
115
    EXPECT(test::verify_range(results_vector, gold));
}

116
117
118
119
120
void sub_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
Scott Thornton's avatar
Scott Thornton committed
121
    auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    p.add_instruction(rtg::sub{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-2, -2, -2};
    EXPECT(test::verify_range(results_vector, gold));
}

void mul_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1, 0, 1}});
Scott Thornton's avatar
Scott Thornton committed
136
    auto l2 = p.add_literal(rtg::literal{s, {1, 2, 3}});
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    p.add_instruction(rtg::mul{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1, 0, 3};
    EXPECT(test::verify_range(results_vector, gold));
}

void div_test()
{
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    auto l1 = p.add_literal(rtg::literal{s, {-1.0f, 0.5f, 1.0f}});
Scott Thornton's avatar
Scott Thornton committed
151
    auto l2 = p.add_literal(rtg::literal{s, {1.0f, 2.0f, 4.0f}});
152
153
154
155
156
157
158
159
160
    p.add_instruction(rtg::div{}, l1, l2);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(3);
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {-1.f, 0.25f, 0.25f};
    EXPECT(test::verify_range(results_vector, gold));
}

161
162
163
void reshape_test()
{
    rtg::shape a_shape{rtg::shape::float_type, {24, 1, 1, 1}};
164
165
166
167
    std::vector<float> data(24);
    std::iota(data.begin(), data.end(), -3);
    {
        rtg::program p;
168
169
        auto l                         = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> new_shape = {8, 3, 1, 1};
170
171
172
173
        p.add_instruction(rtg::reshape{new_shape}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
        std::vector<float> results_vector(3);
174
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
175
        EXPECT(test::verify_range(results_vector, data));
176
177
178
    }
    {
        rtg::program p;
179
180
        auto l                         = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> new_shape = {1, 3, 4, 2};
181
182
183
184
        p.add_instruction(rtg::reshape{new_shape}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
        std::vector<float> results_vector(3);
185
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
186
        EXPECT(test::verify_range(results_vector, data));
187
188
189
    }
    {
        rtg::program p;
190
191
        auto l                         = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> new_shape = {1, 3, 4, 2};
192
193
194
195
        p.add_instruction(rtg::reshape{new_shape}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
        std::vector<float> results_vector(3);
196
        result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
197
        EXPECT(test::verify_range(results_vector, data));
198
199
200
    }
}

201
202
void gemm_test()
{
203
    rtg::program p;
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
    std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397,  0.75402161,  -0.505885,
                            1.33628943,  -0.11413,   -0.31270559, 1.59336732,  -0.19361027,
                            -0.91620867, 0.40108416, -0.06969921, 0.68483471,  -0.39906632,
                            -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
    std::vector<float> b = {6.09568541e-01,
                            -6.10527007e-01,
                            3.66646462e-01,
                            1.18951101e-01,
                            5.58777432e-01,
                            -3.21296298e-01,
                            -5.95997198e-01,
                            -5.01425721e-01,
                            -2.84606807e-01,
                            -5.73673557e-01,
                            -8.99430260e-01,
                            -4.25103093e-01,
                            1.53027987e+00,
                            -3.81407415e-04,
                            -3.29650255e-01};
    std::vector<float> c = {-1.56327541e+00,
                            -7.09570140e-01,
                            -5.37424982e-01,
                            -2.22994831e-01,
                            -2.15586437e+00,
                            2.09177941e-03,
                            -1.47279677e+00,
                            2.02627040e-01,
                            -6.04527691e-01,
                            -1.29885596e+00,
                            2.16294914e+00,
                            -1.48101497e-01};
    rtg::shape a_shape{rtg::shape::float_type, {4, 5}};
Scott Thornton's avatar
Scott Thornton committed
236
    auto al = p.add_literal(rtg::literal{a_shape, a});
237
    rtg::shape b_shape{rtg::shape::float_type, {5, 3}};
Scott Thornton's avatar
Scott Thornton committed
238
239
    auto bl = p.add_literal(rtg::literal{b_shape, b});
    p.add_instruction(rtg::gemm{}, al, bl);
240
241
242
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(12);
243
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
244
    float tol = 1e-6;
245
246
247
    for(int i = 0; i < results_vector.size(); i++)
    {
        EXPECT(std::abs(results_vector[i] - c[i]) < tol);
248
249
250
    }
}

251
252
void softmax_test()
{
253
    rtg::program p;
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
    std::vector<float> a = {
        -5.61869681e-01, 9.07827199e-01,  1.29255986e+00,  3.18533443e-02,  -1.22183852e-03,
        -2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
        -9.20038462e-01, 8.47388089e-01,  2.51734018e-01,  1.50563884e+00,  2.23056650e+00,
        -6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00,  -2.51560897e-01,
        -8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01,
        3.23284641e-02,  -1.54700470e+00, 1.38096774e+00,  5.39869189e-01,  -7.56884992e-01,
        1.81503093e+00,  -2.11269641e+00, 1.92466557e+00,  1.77230799e+00,  2.21660900e+00,
        1.56777036e+00,  -2.08995026e-03, 3.50566894e-01,  -1.15042710e+00, -1.18577778e+00,
        8.90633047e-01,  -6.63949102e-02, 1.44661188e+00,  1.59215283e+00,  -2.56262213e-01,
        9.39079225e-01,  4.07298543e-02,  3.86590779e-01,  6.09607756e-01,  8.22331488e-01,
        -2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00,
        3.27092171e-01,  -1.33315325e+00, 3.62459183e-01,  3.74710828e-01,  -1.30302286e+00,
        1.79680198e-01,  -4.51832324e-01, 4.34282750e-01,  -7.09520102e-01, 6.20333970e-01,
        -1.28712380e+00, 2.04130828e-01,  -7.70607769e-01, 1.61889160e+00,  -1.50951004e+00,
        -4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
        -8.28408226e-02, 2.73412596e-02,  5.79780899e-03,  9.87900198e-02,  -7.95276761e-01,
        -1.38536084e+00, -6.63573861e-01, 3.89783204e-01,  -1.30670881e+00, -7.62425125e-01,
        -4.04883057e-01, 6.24344349e-01,  3.68128955e-01,  -1.01577950e+00, -3.06715906e-01,
        5.67961395e-01,  2.98198581e-01,  -1.63613629e+00, -3.75131965e-01, -6.75393403e-01,
        2.59172034e+00,  6.75538957e-01,  9.07939598e-02,  1.92257717e-01,  -1.21592450e+00,
        -2.73682117e-01, 1.25232983e+00,  -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
        3.10056299e-01,  1.41833842e+00,  -1.81386679e-01, 3.92868072e-01,  -8.14771175e-01,
        2.02392387e+00,  -9.42091495e-02, -3.77683818e-01, 2.05638766e+00,  2.93796062e-01,
        -6.02131486e-01, 2.70461679e-01,  -8.92358482e-01, 1.04388881e+00,  2.66154885e-01};

    std::vector<float> s = {
        0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
        0.18750034, 0.21905553, 0.27000085, 0.0547399,  0.56318235, 0.47422904, 0.78964758,
        0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684,  0.18766427, 0.15753111,
        0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287,
        0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
        0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915,
        0.3115395,  0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762,  0.4642328,
        0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609,
        0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865,
        0.17296699, 0.46923906, 0.06921105, 0.3570261,  0.4125829,  0.73165393, 0.15302512,
        0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355,
        0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659,  0.10663581,
        0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666,
        0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339,  0.49818122, 0.10656087,
        0.1813329,  0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
        0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
        0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
        0.42914796};

    rtg::shape a_shape{rtg::shape::float_type, {5, 3, 4, 2}};
Scott Thornton's avatar
Scott Thornton committed
301
302
    auto al = p.add_literal(rtg::literal{a_shape, a});
    p.add_instruction(rtg::softmax{}, al);
303
304
305
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});
    std::vector<float> results_vector(120);
306
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
307
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
308
309
}

310
311
void conv2d_test()
{
Scott Thornton's avatar
Scott Thornton committed
312
    rtg::program p;
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        2.82721668e-02,  6.44195229e-02,  1.53499246e-02,  1.72468081e-01,  -6.33238107e-02,
        9.49496776e-02,  1.40258059e-01,  -7.92879611e-02, -1.29301161e-01, 3.11307609e-03,
        -1.90624535e-01, 1.13238767e-01,  -2.80647576e-02, 3.12882811e-02,  -3.52091640e-02,
        3.33581865e-02,  6.43158704e-02,  7.40238279e-02,  -1.00106120e-01, -9.56912562e-02,
        1.44342467e-01,  9.40258950e-02,  6.36333972e-02,  1.66158378e-03,  -8.91554281e-02,
        2.58734226e-02,  1.70919895e-02,  1.78214177e-01,  8.84564668e-02,  8.98126513e-02,
        -1.63809001e-01, 1.37802169e-01,  1.66439757e-01,  -1.45631135e-02, 1.88469887e-04,
        4.76950556e-02,  -1.91969007e-01, -1.76233292e-01, -7.70473927e-02, 1.14828631e-01,
        1.76608220e-01,  -1.50728196e-01, 1.99946314e-02,  -5.88052124e-02, 1.31612435e-01,
        1.61106288e-02,  -1.35080189e-01, 1.49512306e-01,  3.86456847e-02,  1.29330024e-01,
        -3.22975963e-02, -5.60784787e-02, -5.41997552e-02, 4.78562862e-02};

    std::vector<float> s = {0.27039781,
                            0.19105849,
                            -0.06339942,
                            -0.65087199,
                            0.40867025,
                            0.05063812,
                            -0.14907975,
                            0.49018705,
                            -0.49197209,
                            0.33236548,
                            -0.39374301,
                            0.16012701,
                            0.06574871,
                            0.71606487,
                            -0.55201721,
Scott Thornton's avatar
Scott Thornton committed
357
                            -0.46427044};
358
    rtg::shape a_shape{rtg::shape::float_type, {2, 3, 4, 4}};
Scott Thornton's avatar
Scott Thornton committed
359
360
    auto al = p.add_literal(rtg::literal{a_shape, a});

361
    rtg::shape c_shape{rtg::shape::float_type, {2, 3, 3, 3}};
Scott Thornton's avatar
Scott Thornton committed
362
363
364
365
366
367
368
    auto cl = p.add_literal(rtg::literal{c_shape, c});

    p.add_instruction(rtg::convolution{}, al, cl);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(16);
369
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
370
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
371
372
}

373
374
void conv2d_padding_test()
{
Scott Thornton's avatar
Scott Thornton committed
375
    rtg::program p;
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        -0.16115488, -0.09800646, -0.05412646, 0.10475694,  0.00555485,  -0.12667653, 0.0458357,
        -0.02656217, -0.16338061, 0.15037455,  0.0102711,   0.01303349,  0.05242859,  0.02034754,
        0.04751867,  -0.17038961, -0.1434752,  -0.10770349, 0.05676742,  -0.15838449, 0.10128359,
        -0.18958683, 0.11954515,  0.10758857,  -0.01058291, -0.12797487, 0.08971019,  0.18793164,
        -0.00881396, -0.06588994, -0.13321903, -0.03300409, 0.01439607,  0.07618178,  -0.11556662,
        0.00764295,  0.12956454,  -0.08937147, -0.12763587, 0.04674943,  0.05765297,  0.11336918,
        0.14747436,  -0.06199479, -0.01166052, -0.12432006, -0.04494537, -0.17581205, 0.09475745,
        0.1149437,   -0.1014564,  0.0274073,   -0.01323579, -0.11092556};

    std::vector<float> s = {
        -0.0201216,  0.40407312,  -0.39005592, -0.0631946,  0.37963012,  -0.64611685, 0.1349397,
        -0.54113752, 0.28533003,  0.27667275,  -0.16442731, -0.181494,   0.30564839,  0.58744538,
        0.32015014,  0.24969585,  -0.27367792, -0.53308117, 0.41236052,  0.26136363,  -0.01489828,
        0.57652152,  -0.38506854, 0.119615,    0.0437076,   0.04779706,  0.57887721,  0.23126155,
        0.05695833,  -0.68200272, 0.02063358,  -0.10267162, 0.8062973,   -0.38149622, -0.40134856,
        -0.03353126, 0.38991132,  -0.3478111,  0.03661491,  0.25783631,  0.62772679,  -0.1961118,
        0.76423508,  -0.36241418, -0.20994355, -0.12368261, -0.9406727,  0.02340185,  -0.08793129,
        -0.02471633, -0.58163726, -0.02211772, -0.42014724, 0.77525634,  0.504951,    -0.20537445,
        -0.20369984, -0.83037728, -1.40423918, -0.46160448, -0.22944322, 0.36074194,  0.49579027,
        0.46527559};

    rtg::shape a_shape{rtg::shape::float_type, {2, 3, 4, 4}};
Scott Thornton's avatar
Scott Thornton committed
415
416
    auto al = p.add_literal(rtg::literal{a_shape, a});

417
    rtg::shape c_shape{rtg::shape::float_type, {2, 3, 3, 3}};
Scott Thornton's avatar
Scott Thornton committed
418
419
    auto cl = p.add_literal(rtg::literal{c_shape, c});

420
    p.add_instruction(rtg::convolution{{{1, 1}}, {{1, 1}}}, al, cl);
Scott Thornton's avatar
Scott Thornton committed
421
422
423
424
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(64);
425
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
426
    EXPECT(test::verify_range(results_vector, s));
427
428
}

429
430
void conv2d_padding_stride_test()
{
Scott Thornton's avatar
Scott Thornton committed
431
    rtg::program p;
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
    std::vector<float> a = {
        2.71567607,  -0.9960829,  0.91671127,  0.28140706,  0.63235772,  0.08077253,  0.80927712,
        -0.59108931, -1.05421555, -2.76622486, -0.85044265, -0.52049929, 0.67726439,  -0.65290606,
        0.02345525,  -0.33579525, 0.38901961,  1.05473483,  -1.31188095, 1.8963089,   -0.07265259,
        0.947339,    0.41949373,  -0.70814759, 0.25892952,  1.07311416,  1.2571274,   -0.62318051,
        -0.19951548, -0.94232577, -0.29393643, 0.42292568,  -0.80230367, 1.40909171,  0.63617158,
        0.13900366,  1.09253144,  -0.15265895, 1.54781747,  0.72780299,  1.09189606,  -0.38068101,
        0.97057933,  -0.58958799, 1.56188643,  0.21474874,  0.58725154,  -1.27097559, -0.03024297,
        1.09437096,  -0.4897908,  0.34838957,  -1.31042492, -1.69069934, 0.86956722,  -0.40457946,
        0.46691212,  1.29273605,  0.26464137,  0.22073045,  -1.02178168, 0.22163901,  -1.84387338,
        0.75522131,  -0.45775682, -0.42241111, -1.50944722, 1.07256448,  -1.95876884, -0.28106022,
        0.3341668,   2.13129425,  -1.14728117, -1.06555498, -0.298444,   -0.88322699, -0.65866792,
        -2.06007552, 0.01374334,  0.45612028,  0.52715492,  1.01914406,  -1.72659791, 0.80650896,
        0.16860051,  2.24112225,  -0.78620857, 0.36566174,  -0.07020134, -0.47976932, -0.68230027,
        -0.94711417, -0.54506505, 1.66504931,  -0.71860826, 0.61132306};

    std::vector<float> c = {
        -0.14601797, -0.13000923, 0.06521662,  0.06178288,  -0.11083675, 0.10154136,  0.09990512,
        0.06030385,  -0.11374587, -0.17523311, -0.14344215, 0.17802463,  0.06300922,  -0.15325832,
        0.07066704,  0.05166031,  0.00615084,  -0.02606523, 0.08083995,  -0.17913306, 0.0624622,
        0.0735731,   -0.04198661, -0.0164391,  -0.06374192, 0.16569914,  0.10681538,  0.07370754,
        0.02802075,  0.00282027,  0.15104802,  -0.11084409, -0.00197773, 0.07924436,  0.03528272,
        0.04765259,  -0.15896152, 0.07917164,  0.12125669,  -0.1154705,  -0.11999125, 0.12749968,
        -0.06269585, 0.18658121,  -0.03944227, 0.0111798,   -0.17731084, 0.11789055,  -0.09982193,
        0.08142821,  0.0729029,   0.11303909,  0.12735154,  0.03885292};

    std::vector<float> s = {-0.20817225,
                            0.87965256,
                            0.14958936,
                            -1.24887264,
                            -0.06540672,
                            0.20778663,
                            0.40456355,
                            -0.99900877,
                            0.4917807,
                            0.1994698,
                            0.64205718,
                            0.37798831,
                            -0.25315839,
                            0.44276932,
                            -0.16138598,
                            0.79344082};

    rtg::shape a_shape{rtg::shape::float_type, {2, 3, 4, 4}};
Scott Thornton's avatar
Scott Thornton committed
476
477
    auto al = p.add_literal(rtg::literal{a_shape, a});

478
    rtg::shape c_shape{rtg::shape::float_type, {2, 3, 3, 3}};
Scott Thornton's avatar
Scott Thornton committed
479
480
    auto cl = p.add_literal(rtg::literal{c_shape, c});

481
    p.add_instruction(rtg::convolution{{{1, 1}}, {{2, 2}}}, al, cl);
Scott Thornton's avatar
Scott Thornton committed
482
483
484
485
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(16);
486
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Scott Thornton's avatar
Scott Thornton committed
487
    EXPECT(test::verify_range(results_vector, s));
Scott Thornton's avatar
Scott Thornton committed
488
}
489

490
491
void transpose_test()
{
Paul's avatar
Paul committed
492
    rtg::shape a_shape{rtg::shape::float_type, {1, 2, 2, 3}};
493
494
495
    std::vector<float> data(12);
    std::iota(data.begin(), data.end(), 0);

496
497
    {
        rtg::program p;
Paul's avatar
Paul committed
498
499
        auto l                    = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> perm = {0, 3, 1, 2};
500
501
502
        p.add_instruction(rtg::transpose{perm}, l);
        p.compile(rtg::cpu::cpu_target{});
        auto result = p.eval({});
503

Paul's avatar
Paul committed
504
505
506
507
508
509
        result.visit([&](auto output) {
            std::vector<size_t> new_lens    = {1, 3, 2, 2};
            std::vector<size_t> new_strides = {12, 1, 6, 3};
            EXPECT(bool{output.get_shape().lens() == new_lens});
            EXPECT(bool{output.get_shape().strides() == new_strides});
        });
510
511
512
    }
    {
        rtg::program p;
Paul's avatar
Paul committed
513
514
515
        auto l                    = p.add_literal(rtg::literal{a_shape, data});
        std::vector<int64_t> perm = {0, 3, 1, 2};
        auto result               = p.add_instruction(rtg::transpose{perm}, l);
516
517
518
519
520
521
522
523
524
        p.add_instruction(rtg::contiguous{}, result);
        p.compile(rtg::cpu::cpu_target{});
        auto result2 = p.eval({});

        std::vector<float> results_vector(12);
        result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
        std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
        EXPECT(test::verify_range(results_vector, gold));
    }
525
526
}

Paul's avatar
Paul committed
527
528
529
void contiguous_test()
{
    rtg::shape a_shape{rtg::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
530
531
532
533
534
535
536
537
538
539
    std::vector<float> data(12);
    std::iota(data.begin(), data.end(), 0);

    rtg::program p;
    auto l = p.add_literal(rtg::literal{a_shape, data});
    p.add_instruction(rtg::contiguous{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    std::vector<float> results_vector(12);
540
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
Paul's avatar
Paul committed
541
542
543
    std::vector<size_t> new_lens    = {1, 3, 2, 2};
    std::vector<size_t> new_strides = {12, 1, 6, 3};
    std::vector<float> gold         = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
544
    EXPECT(test::verify_range(results_vector, gold));
545
546
}

547
548
549
int main()
{
    exp_test();
550
551
552
    sin_test();
    cos_test();
    tan_test();
553
    add_test();
554
555
    broadcast_test();
    add_broadcast_test();
556
557
    sub_test();
    mul_test();
558
    gemm_test();
559
    reshape_test();
560
    transpose_test();
561
    contiguous_test();
562
    softmax_test();
Scott Thornton's avatar
Scott Thornton committed
563
564
565
    conv2d_test();
    conv2d_padding_test();
    conv2d_padding_stride_test();
566
}