// cpu_ops_test.cpp: operator tests for the rtg CPU target.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>
#include <rtg/literal.hpp>
#include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>

using rtg::shape;
using rtg::argument;

// Checks that rtg::exp compiled for the CPU target matches precomputed
// exp() values for the inputs {-1, 0, 1}.
void exp_test() {
    rtg::program p;
    rtg::shape s{rtg::shape::float_type, {3}};
    // Float initializers so the literal's element type matches the float_type shape.
    auto l = p.add_literal(rtg::literal{s, {-1.0f, 0.0f, 1.0f}});
    p.add_instruction(rtg::exp{}, l);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    // Reference values: exp(-1), exp(0), exp(1).
    std::vector<float> gold = {0.36787944f, 1.f, 2.71828183f};
    // Size the output from the reference so the two can never drift apart.
    std::vector<float> results_vector(gold.size());
    std::memcpy(results_vector.data(), result.data(), results_vector.size() * sizeof(float));

    // A float carries ~7 significant digits; one ULP near e is ~2.4e-7, so an
    // absolute tolerance of 1e-8 would reject even correctly rounded results.
    const float tol = 1e-6f;
    for (std::size_t i = 0; i < results_vector.size(); i++) {
        assert(std::abs(results_vector[i] - gold[i]) < tol);
    }
}

// Checks rtg::gemm compiled for the CPU target: A (4x5) times B (5x3) is
// compared element-wise against the precomputed 4x3 product C.
void gemm_test() {
    rtg::program p;
    std::vector<float> A = {-0.00925222,  0.56250403,  0.70107397,  0.75402161, -0.505885  ,
                             1.33628943, -0.11413   , -0.31270559,  1.59336732, -0.19361027,
                            -0.91620867,  0.40108416, -0.06969921,  0.68483471, -0.39906632,
                            -1.66423624,  0.69040076, -1.31490171, -0.11282616, -0.79391814};
    std::vector<float> B = { 6.09568541e-01,  -6.10527007e-01,   3.66646462e-01,
                             1.18951101e-01,   5.58777432e-01,  -3.21296298e-01,
                            -5.95997198e-01,  -5.01425721e-01,  -2.84606807e-01,
                            -5.73673557e-01,  -8.99430260e-01,  -4.25103093e-01,
                             1.53027987e+00,  -3.81407415e-04,  -3.29650255e-01};
    std::vector<float> C = {-1.56327541e+00,  -7.09570140e-01,  -5.37424982e-01,
                            -2.22994831e-01,  -2.15586437e+00,   2.09177941e-03,
                            -1.47279677e+00,   2.02627040e-01,  -6.04527691e-01,
                            -1.29885596e+00,   2.16294914e+00,  -1.48101497e-01};
    rtg::shape a_shape{rtg::shape::float_type, {4,5}};
    auto a = p.add_literal(rtg::literal{a_shape, A});
    rtg::shape b_shape{rtg::shape::float_type, {5,3}};
    auto b = p.add_literal(rtg::literal{b_shape, B});
    p.add_instruction(rtg::gemm{}, a, b);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    // Size the output from the reference product instead of a magic count.
    std::vector<float> results_vector(C.size());
    std::memcpy(results_vector.data(), result.data(), results_vector.size() * sizeof(float));
    const float tol = 1e-6f;
    for (std::size_t i = 0; i < results_vector.size(); i++) {
        assert(std::abs(results_vector[i] - C[i]) < tol);
    }
}

// Checks rtg::softmax compiled for the CPU target on a {5,3,4,2} float tensor.
// The reference S normalizes over axis 1 (the size-3 dimension), e.g.
// S[0] + S[8] + S[16] == 1.
void softmax_test() {
    rtg::program p;
    std::vector<float> A = {-5.61869681e-01,   9.07827199e-01,   1.29255986e+00,
                             3.18533443e-02,  -1.22183852e-03,  -2.83830553e-01,
                            -1.03245842e+00,  -9.28322077e-01,  -8.82696748e-01,
                             1.11327164e-01,  -9.20038462e-01,   8.47388089e-01,
                             2.51734018e-01,   1.50563884e+00,   2.23056650e+00,
                            -6.17576987e-02,  -1.00264274e-01,  -6.10369384e-01,
                             1.17537189e+00,  -2.51560897e-01,  -8.50333512e-01,
                            -8.03578615e-01,  -6.51194930e-01,  -2.58137047e-01,
                             4.65528190e-01,   3.23284641e-02,  -1.54700470e+00,
                             1.38096774e+00,   5.39869189e-01,  -7.56884992e-01,
                             1.81503093e+00,  -2.11269641e+00,   1.92466557e+00,
                             1.77230799e+00,   2.21660900e+00,   1.56777036e+00,
                            -2.08995026e-03,   3.50566894e-01,  -1.15042710e+00,
                            -1.18577778e+00,   8.90633047e-01,  -6.63949102e-02,
                             1.44661188e+00,   1.59215283e+00,  -2.56262213e-01,
                             9.39079225e-01,   4.07298543e-02,   3.86590779e-01,
                             6.09607756e-01,   8.22331488e-01,  -2.82126725e-01,
                            -9.49052632e-01,  -4.24012303e-01,  -5.32990396e-01,
                            -3.18386006e+00,   3.27092171e-01,  -1.33315325e+00,
                             3.62459183e-01,   3.74710828e-01,  -1.30302286e+00,
                             1.79680198e-01,  -4.51832324e-01,   4.34282750e-01,
                            -7.09520102e-01,   6.20333970e-01,  -1.28712380e+00,
                             2.04130828e-01,  -7.70607769e-01,   1.61889160e+00,
                            -1.50951004e+00,  -4.10505563e-01,  -3.56566496e-02,
                            -1.29747534e+00,  -1.49967879e-01,   7.77626812e-01,
                            -8.28408226e-02,   2.73412596e-02,   5.79780899e-03,
                             9.87900198e-02,  -7.95276761e-01,  -1.38536084e+00,
                            -6.63573861e-01,   3.89783204e-01,  -1.30670881e+00,
                            -7.62425125e-01,  -4.04883057e-01,   6.24344349e-01,
                             3.68128955e-01,  -1.01577950e+00,  -3.06715906e-01,
                             5.67961395e-01,   2.98198581e-01,  -1.63613629e+00,
                            -3.75131965e-01,  -6.75393403e-01,   2.59172034e+00,
                             6.75538957e-01,   9.07939598e-02,   1.92257717e-01,
                            -1.21592450e+00,  -2.73682117e-01,   1.25232983e+00,
                            -1.39969170e+00,  -1.91483587e-01,   2.57732719e-01,
                             3.10056299e-01,   1.41833842e+00,  -1.81386679e-01,
                             3.92868072e-01,  -8.14771175e-01,   2.02392387e+00,
                            -9.42091495e-02,  -3.77683818e-01,   2.05638766e+00,
                             2.93796062e-01,  -6.02131486e-01,   2.70461679e-01,
                            -8.92358482e-01,   1.04388881e+00,   2.66154885e-01};

    std::vector<float> S = {0.30191708,  0.59879845,  0.50029165,  0.24915339,  0.36823985,
                            0.13190967,  0.0349741 ,  0.18750034,  0.21905553,  0.27000085,
                            0.0547399 ,  0.56318235,  0.47422904,  0.78964758,  0.91381913,
                            0.44601166,  0.47902739,  0.13120073,  0.4449684 ,  0.18766427,
                            0.15753111,  0.07844277,  0.05120674,  0.36648798,  0.14637007,
                            0.13152322,  0.01560997,  0.29065287,  0.49196178,  0.10550152,
                            0.81890774,  0.06369215,  0.62972021,  0.74931765,  0.67285055,
                            0.35034987,  0.28612873,  0.31931475,  0.04220394,  0.16093165,
                            0.22390974,  0.11915915,  0.3115395 ,  0.35899726,  0.22190949,
                            0.57518375,  0.13888834,  0.7753762 ,  0.4642328 ,  0.57055861,
                            0.21954368,  0.34515455,  0.09486015,  0.40631217,  0.01842281,
                            0.48770609,  0.06652815,  0.36023033,  0.42343026,  0.24226256,
                            0.17348589,  0.44066274,  0.6865865 ,  0.17296699,  0.46923906,
                            0.06921105,  0.3570261 ,  0.4125829 ,  0.73165393,  0.15302512,
                            0.29499072,  0.33932695,  0.30852377,  0.40762195,  0.40170741,
                            0.36259529,  0.60848355,  0.42618036,  0.31721094,  0.02960522,
                            0.28256637,  0.24389413,  0.2725659 ,  0.10663581,  0.27622163,
                            0.28264219,  0.53652936,  0.09476089,  0.40890986,  0.34848392,
                            0.32572666,  0.53076893,  0.11529481,  0.29117745,  0.14625968,
                            0.8756339 ,  0.49818122,  0.10656087,  0.1813329 ,  0.17664003,
                            0.21410346,  0.80408043,  0.02315119,  0.27155462,  0.32804728,
                            0.13268511,  0.61795473,  0.49703068,  0.41696799,  0.10175809,
                            0.71028161,  0.29929739,  0.17377149,  0.76075399,  0.20071237,
                            0.32632929,  0.36892858,  0.09416146,  0.26656723,  0.42914796};

    rtg::shape a_shape{rtg::shape::float_type, {5,3,4,2}};
    auto a = p.add_literal(rtg::literal{a_shape, A});
    p.add_instruction(rtg::softmax{}, a);
    p.compile(rtg::cpu::cpu_target{});
    auto result = p.eval({});

    // Size the output from the reference instead of a magic count (5*3*4*2).
    std::vector<float> results_vector(S.size());
    std::memcpy(results_vector.data(), result.data(), results_vector.size() * sizeof(float));
    const float tol = 1e-6f;
    for (std::size_t i = 0; i < results_vector.size(); i++) {
        assert(std::abs(results_vector[i] - S[i]) < tol);
    }
}

int main()
{
    exp_test();
    gemm_test();
142
    softmax_test();
143
}