eval_test.cpp 5.81 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
#include <migraph/program.hpp>
Paul's avatar
Paul committed
3
4
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
Paul's avatar
Paul committed
5
#include <sstream>
Paul's avatar
Paul committed
6
#include "test.hpp"
Paul's avatar
Paul committed
7
#include <basic_ops.hpp>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
10
struct id_target
{
Paul's avatar
Paul committed
11
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
12
    std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
Paul's avatar
Paul committed
13
    migraph::context get_context() const { return {}; }
Paul's avatar
Paul committed
14
15
};

Paul's avatar
Paul committed
16
17
struct reverse_pass
{
Paul's avatar
Paul committed
18
    std::string name() const { return "reverse_pass"; }
Paul's avatar
Paul committed
19
20
21

    void apply(migraph::program& p) const
    {
Paul's avatar
Paul committed
22
        for(auto ins : migraph::iterator_for(p))
Paul's avatar
Paul committed
23
        {
24
            if(ins->name() == "sum")
Paul's avatar
Paul committed
25
            {
26
                p.replace_instruction(ins, minus_op{}, ins->inputs());
Paul's avatar
Paul committed
27
            }
28
            else if(ins->name() == "minus")
Paul's avatar
Paul committed
29
            {
30
                p.replace_instruction(ins, sum_op{}, ins->inputs());
Paul's avatar
Paul committed
31
32
33
34
35
36
37
38
            }
        }
    }
};

struct reverse_target
{
    std::string name() const { return "reverse"; }
Paul's avatar
Paul committed
39
    std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; }
Paul's avatar
Paul committed
40
41
42
43
44
45
    migraph::context get_context() const { return {}; }
};

struct double_reverse_target
{
    std::string name() const { return "double_reverse"; }
Paul's avatar
Paul committed
46
47
48
    std::vector<migraph::pass> get_passes(migraph::context&) const
    {
        return {reverse_pass{}, reverse_pass{}};
Paul's avatar
Paul committed
49
50
51
52
    }
    migraph::context get_context() const { return {}; }
};

Paul's avatar
Paul committed
53
TEST_CASE(literal_test1)
{
    // A program built only from literals: sum(1, 2) must evaluate to 3.
    migraph::program p;

    auto lhs = p.add_literal(1);
    auto rhs = p.add_literal(2);
    p.add_instruction(sum_op{}, lhs, rhs);

    auto result = p.eval({});
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
}

Paul's avatar
Paul committed
65
TEST_CASE(literal_test2)
{
    // Chained sums: (1 + 2) + 2 = 5. The last instruction is what eval returns,
    // so the intermediate value 3 must not be the result.
    migraph::program p;

    auto lit1  = p.add_literal(1);
    auto lit2  = p.add_literal(2);
    auto first = p.add_instruction(sum_op{}, lit1, lit2);
    p.add_instruction(sum_op{}, first, lit2);

    auto result = p.eval({});
    EXPECT(result == migraph::literal{5});
    EXPECT(result != migraph::literal{3});
}

Paul's avatar
Paul committed
79
TEST_CASE(print_test)
{
    // Streaming a program (parameter + literal + sum) produces some text.
    migraph::program p;

    auto param = p.add_parameter("x", {migraph::shape::int64_type});
    auto lit2  = p.add_literal(2);
    p.add_instruction(sum_op{}, param, lit2);

    std::stringstream buffer;
    buffer << p;
    EXPECT(!buffer.str().empty());
}

Paul's avatar
Paul committed
93
TEST_CASE(param_test)
{
    // Both operands come from eval-time parameters: x = 1, y = 2 -> 3.
    migraph::program p;

    auto px = p.add_parameter("x", {migraph::shape::int64_type});
    auto py = p.add_parameter("y", {migraph::shape::int64_type});
    p.add_instruction(sum_op{}, px, py);

    auto result = p.eval({{"x", migraph::literal{1}.get_argument()},
                          {"y", migraph::literal{2}.get_argument()}});
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
}

Paul's avatar
Paul committed
107
TEST_CASE(param_error_test)
{
    // Evaluating without binding parameter "y" must raise migraph::exception
    // with the message "Parameter not found: y".
    migraph::program p;

    auto px = p.add_parameter("x", {migraph::shape::int64_type});
    auto py = p.add_parameter("y", {migraph::shape::int64_type});
    p.add_instruction(sum_op{}, px, py);

    auto eval_missing_y = [&] { p.eval({{"x", migraph::literal{1}.get_argument()}}); };
    EXPECT(test::throws<migraph::exception>(eval_missing_y, "Parameter not found: y"));
}

Paul's avatar
Paul committed
122
TEST_CASE(replace_test)
{
    // Replace sum(1, 2) in place with minus(2, 1): the program must stay valid
    // and evaluate to 1 instead of the original 3.
    migraph::program p;

    auto lit1  = p.add_literal(1);
    auto lit2  = p.add_literal(2);
    auto added = p.add_instruction(sum_op{}, lit1, lit2);
    p.replace_instruction(added, minus_op{}, lit2, lit1);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
}

Paul's avatar
Paul committed
137
TEST_CASE(replace_ins_test)
{
    // Replace the sum instruction with an already-existing minus(2, 1)
    // instruction; evaluation must yield 1 rather than the sum's 3.
    migraph::program p;

    auto lit1       = p.add_literal(1);
    auto lit2       = p.add_literal(2);
    auto added      = p.add_instruction(sum_op{}, lit1, lit2);
    auto subtracted = p.add_instruction(minus_op{}, lit2, lit1);
    p.replace_instruction(added, subtracted);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
}

Paul's avatar
Paul committed
153
TEST_CASE(replace_ins_test2)
{
    // Replace the literal `2` with the sum instruction (1 + 2 = 3). The final
    // pass_op result is expected to be 2, i.e. minus now sees 3 - 1.
    migraph::program p;

    auto lit1       = p.add_literal(1);
    auto lit2       = p.add_literal(2);
    auto added      = p.add_instruction(sum_op{}, lit1, lit2);
    auto subtracted = p.add_instruction(minus_op{}, lit2, lit1);
    p.add_instruction(pass_op{}, subtracted);
    p.replace_instruction(lit2, added);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraph::literal{2});
    EXPECT(result != migraph::literal{3});
}

Paul's avatar
Paul committed
170
TEST_CASE(insert_replace_test)
{
    // Insert sum(2, 2) = 4 before the first sum, then turn that first sum into
    // minus(4, 2) = 2. The trailing sum therefore yields 2 + 2 = 4 instead of
    // the untouched program's 5.
    migraph::program p;

    auto lit1  = p.add_literal(1);
    auto lit2  = p.add_literal(2);
    auto first = p.add_instruction(sum_op{}, lit1, lit2);
    p.add_instruction(sum_op{}, first, lit2);

    auto inserted = p.insert_instruction(first, sum_op{}, lit2, lit2);
    p.replace_instruction(first, minus_op{}, inserted, lit2);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraph::literal{4});
    EXPECT(result != migraph::literal{5});
}

Paul's avatar
Paul committed
188
TEST_CASE(target_test)
{
    // Compiling for id_target (no passes) must not change the result of
    // sum(1, 2).
    migraph::program p;

    auto lit1 = p.add_literal(1);
    auto lit2 = p.add_literal(2);
    p.add_instruction(sum_op{}, lit1, lit2);
    p.compile(id_target{});

    auto result = p.eval({});
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
}

Paul's avatar
Paul committed
201
TEST_CASE(reverse_target_test)
{
    // reverse_target runs reverse_pass once, so sum(2, 1) is compiled into
    // minus(2, 1) and the program evaluates to 1.
    migraph::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(reverse_target{});
    auto result = p.eval({});
    EXPECT(result == migraph::literal{1});
    // 3 is what the un-reversed sum would produce. Ruling it out documents
    // that the pass actually rewrote the instruction.
    EXPECT(result != migraph::literal{3});
}

Paul's avatar
Paul committed
214
TEST_CASE(double_reverse_target_test)
{
    // double_reverse_target runs reverse_pass twice: sum -> minus -> sum, so
    // sum(2, 1) is restored and the program evaluates to 3.
    migraph::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(double_reverse_target{});
    auto result = p.eval({});
    EXPECT(result == migraph::literal{3});
    // 1 is what a single reversal would yield (minus(2, 1)). Ruling it out
    // documents that the second pass must also have run.
    EXPECT(result != migraph::literal{1});
}

Paul's avatar
Paul committed
227
int main(int argc, const char* argv[])
{
    // Delegate to the harness from test.hpp (presumably runs the TEST_CASEs
    // defined above, filtered by the command line).
    test::run(argc, argv);
}