eval_test.cpp 4.43 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4

#include <rtg/program.hpp>
#include <rtg/argument.hpp>
#include <rtg/shape.hpp>
Paul's avatar
Paul committed
5
#include <sstream>
Paul's avatar
Paul committed
6
7
#include "test.hpp"

Paul's avatar
Paul committed
8
9
// Test operator used by the cases in this file: adds its two arguments.
struct sum_op
{
    // Name this operator reports for itself.
    std::string name() const { return "sum"; }

    // Checks the inputs and produces `args[0] + args[1]`.
    //
    // RTG_THROW("Wrong args") is invoked unless exactly two arguments are
    // given, their shapes compare equal, and the first argument's shape has a
    // single dimension whose length is 1. The context and shape parameters
    // are not used.
    rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
    {
        // The count is checked first so that args[0]/args[1] are only touched
        // once both are known to exist.
        if(args.size() != 2)
            RTG_THROW("Wrong args");

        auto& lhs = args[0];
        auto& rhs = args[1];
        if(lhs.get_shape() != rhs.get_shape())
            RTG_THROW("Wrong args");

        // `||` short-circuits, so front() is only read when there is exactly
        // one dimension.
        const auto& lhs_shape = lhs.get_shape();
        const auto& dims      = lhs_shape.lens();
        if(dims.size() != 1 || dims.front() != 1)
            RTG_THROW("Wrong args");

        // visit_at passes each argument's value to the callback (presumably the
        // single element of the buffer -- confirm against rtg::argument); the
        // sum is wrapped in a literal to obtain an argument to return.
        rtg::argument result;
        lhs.visit_at([&](auto x) {
            rhs.visit_at([&](auto y) { result = rtg::literal{x + y}.get_argument(); });
        });
        return result;
    }

    // The output shape is the shape of the first input.
    // RTG_THROW("Wrong inputs") is invoked unless exactly two shapes are given.
    rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
    {
        const bool has_two_inputs = inputs.size() == 2;
        if(!has_two_inputs)
            RTG_THROW("Wrong inputs");
        return inputs.front();
    }
};

Paul's avatar
Paul committed
37
38
39
// Test operator used by the cases in this file: subtracts its second argument
// from its first.
struct minus_op
{
    // Name this operator reports for itself.
    std::string name() const { return "minus"; }

    // Checks the inputs and produces `args[0] - args[1]`.
    //
    // RTG_THROW("Wrong args") is invoked unless exactly two arguments are
    // given, their shapes compare equal, and the first argument's shape has a
    // single dimension whose length is 1. The context and shape parameters
    // are not used.
    rtg::argument compute(rtg::context&, rtg::shape, std::vector<rtg::argument> args) const
    {
        // The count is checked first so that args[0]/args[1] are only touched
        // once both are known to exist.
        if(args.size() != 2)
            RTG_THROW("Wrong args");

        auto& minuend    = args[0];
        auto& subtrahend = args[1];
        if(minuend.get_shape() != subtrahend.get_shape())
            RTG_THROW("Wrong args");

        // `||` short-circuits, so front() is only read when there is exactly
        // one dimension.
        const auto& minuend_shape = minuend.get_shape();
        const auto& dims          = minuend_shape.lens();
        if(dims.size() != 1 || dims.front() != 1)
            RTG_THROW("Wrong args");

        // visit_at passes each argument's value to the callback (presumably the
        // single element of the buffer -- confirm against rtg::argument); the
        // difference is wrapped in a literal to obtain an argument to return.
        rtg::argument result;
        minuend.visit_at([&](auto x) {
            subtrahend.visit_at([&](auto y) { result = rtg::literal{x - y}.get_argument(); });
        });
        return result;
    }

    // The output shape is the shape of the first input.
    // RTG_THROW("Wrong inputs") is invoked unless exactly two shapes are given.
    rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
    {
        const bool has_two_inputs = inputs.size() == 2;
        if(!has_two_inputs)
            RTG_THROW("Wrong inputs");
        return inputs.front();
    }
};

Paul's avatar
Paul committed
66
67
struct id_target
{
Paul's avatar
Paul committed
68
69
    std::string name() const { return "id"; }
    void apply(rtg::program&) const {}
Paul's avatar
Paul committed
70
    rtg::context get_context() const { return {}; }
Paul's avatar
Paul committed
71
72
};

Paul's avatar
Paul committed
73
void literal_test1()
Paul's avatar
Paul committed
74
{
Paul's avatar
Paul committed
75
    rtg::program p;
Paul's avatar
Paul committed
76
77
78

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
79
    p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
80
81
82
83
84
    auto result = p.eval({});
    EXPECT(result == rtg::literal{3});
    EXPECT(result != rtg::literal{4});
}

Paul's avatar
Paul committed
85
86
87
88
void literal_test2()
{
    rtg::program p;

Paul's avatar
Paul committed
89
90
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
91
92
93
94
95
96
97
98
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
    EXPECT(result == rtg::literal{5});
    EXPECT(result != rtg::literal{3});
}

Paul's avatar
Paul committed
99
100
101
102
void print_test()
{
    rtg::program p;

Paul's avatar
Paul committed
103
    auto x   = p.add_parameter("x", {rtg::shape::int64_type});
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
112
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
113
114
void param_test()
{
Paul's avatar
Paul committed
115
116
    rtg::program p;

Paul's avatar
Paul committed
117
118
    auto x = p.add_parameter("x", {rtg::shape::int64_type});
    auto y = p.add_parameter("y", {rtg::shape::int64_type});
Paul's avatar
Paul committed
119

Paul's avatar
Paul committed
120
    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
121
122
    auto result =
        p.eval({{"x", rtg::literal{1}.get_argument()}, {"y", rtg::literal{2}.get_argument()}});
Paul's avatar
Paul committed
123
124
125
126
    EXPECT(result == rtg::literal{3});
    EXPECT(result != rtg::literal{4});
}

Paul's avatar
Paul committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
void replace_test()
{
    rtg::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);

    auto result = p.eval({});
    EXPECT(result == rtg::literal{1});
    EXPECT(result != rtg::literal{3});
}

Paul's avatar
Paul committed
141
142
143
144
void insert_replace_test()
{
    rtg::program p;

Paul's avatar
Paul committed
145
146
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
147
148
149
150
151
152
153
154
155
156
157
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);

    auto result = p.eval({});
    EXPECT(result == rtg::literal{4});
    EXPECT(result != rtg::literal{5});
}

Paul's avatar
Paul committed
158
159
160
161
162
163
164
165
166
167
168
169
170
void target_test()
{
    rtg::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
    auto result = p.eval({});
    EXPECT(result == rtg::literal{3});
    EXPECT(result != rtg::literal{4});
}

Paul's avatar
Paul committed
171
172
int main()
{
Paul's avatar
Paul committed
173
174
    literal_test1();
    literal_test2();
Paul's avatar
Paul committed
175
    print_test();
Paul's avatar
Paul committed
176
    param_test();
Paul's avatar
Paul committed
177
    replace_test();
Paul's avatar
Paul committed
178
    insert_replace_test();
Paul's avatar
Paul committed
179
    target_test();
Paul's avatar
Paul committed
180
}