eval_test.cpp 6.44 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
Paul's avatar
Paul committed
5
6
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
Paul's avatar
Paul committed
7
#include <sstream>
Paul's avatar
Paul committed
8
9
#include "test.hpp"

Paul's avatar
Paul committed
10
11
struct sum_op
{
Paul's avatar
Paul committed
12
    std::string name() const { return "sum"; }
Paul's avatar
Paul committed
13
14
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
15
    {
Paul's avatar
Paul committed
16
        migraph::argument result;
Paul's avatar
Paul committed
17
        if(args.size() != 2)
Paul's avatar
Paul committed
18
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
19
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
20
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
21
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
22
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
23
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
24
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
25
26

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
27
            args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); });
Paul's avatar
Paul committed
28
29
30
31
        });
        return result;
    }

Paul's avatar
Paul committed
32
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
33
    {
Paul's avatar
Paul committed
34
        if(inputs.size() != 2)
Paul's avatar
Paul committed
35
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
36
37
38
39
        return inputs.front();
    }
};

Paul's avatar
Paul committed
40
41
42
struct minus_op
{
    std::string name() const { return "minus"; }
Paul's avatar
Paul committed
43
44
    migraph::argument
    compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
Paul's avatar
Paul committed
45
    {
Paul's avatar
Paul committed
46
        migraph::argument result;
Paul's avatar
Paul committed
47
        if(args.size() != 2)
Paul's avatar
Paul committed
48
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
49
        if(args[0].get_shape() != args[1].get_shape())
Paul's avatar
Paul committed
50
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
51
        if(args[0].get_shape().lens().size() != 1)
Paul's avatar
Paul committed
52
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
53
        if(args[0].get_shape().lens().front() != 1)
Paul's avatar
Paul committed
54
            MIGRAPH_THROW("Wrong args");
Paul's avatar
Paul committed
55
56

        args[0].visit_at([&](auto x) {
Paul's avatar
Paul committed
57
            args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); });
Paul's avatar
Paul committed
58
59
60
61
        });
        return result;
    }

Paul's avatar
Paul committed
62
    migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
Paul's avatar
Paul committed
63
64
    {
        if(inputs.size() != 2)
Paul's avatar
Paul committed
65
            MIGRAPH_THROW("Wrong inputs");
Paul's avatar
Paul committed
66
67
68
69
        return inputs.front();
    }
};

Paul's avatar
Paul committed
70
71
struct id_target
{
Paul's avatar
Paul committed
72
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
73
    std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
Paul's avatar
Paul committed
74
    migraph::context get_context() const { return {}; }
Paul's avatar
Paul committed
75
76
};

Paul's avatar
Paul committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/// Pass that swaps every "sum" instruction for "minus" and every "minus" for
/// "sum", keeping each instruction's operands (and their order) unchanged.
struct reverse_pass
{
    std::string name() const { return "reverse_pass"; }

    void apply(migraph::program& p) const
    {
        for(auto ins : migraph::iterator_for(p))
        {
            const std::string op_name = ins->op.name();
            if(op_name == "sum")
                p.replace_instruction(ins, minus_op{}, ins->arguments);
            else if(op_name == "minus")
                p.replace_instruction(ins, sum_op{}, ins->arguments);
            // Any other instruction (e.g. literals) is left as-is.
        }
    }
};

/// Target whose only pass is a single reverse_pass.
struct reverse_target
{
    std::string name() const { return "reverse"; }

    std::vector<migraph::pass> get_passes(migraph::context&) const
    {
        return std::vector<migraph::pass>{reverse_pass{}};
    }

    migraph::context get_context() const { return {}; }
};

/// Target that runs reverse_pass twice; the second swap undoes the first.
struct double_reverse_target
{
    std::string name() const { return "double_reverse"; }

    std::vector<migraph::pass> get_passes(migraph::context&) const
    {
        return std::vector<migraph::pass>{reverse_pass{}, reverse_pass{}};
    }

    migraph::context get_context() const { return {}; }
};

Paul's avatar
Paul committed
120
void literal_test1()
Paul's avatar
Paul committed
121
{
Paul's avatar
Paul committed
122
    migraph::program p;
Paul's avatar
Paul committed
123
124
125

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
126
    p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
127
    auto result = p.eval({});
Paul's avatar
Paul committed
128
129
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
130
131
}

Paul's avatar
Paul committed
132
133
void literal_test2()
{
Paul's avatar
Paul committed
134
    migraph::program p;
Paul's avatar
Paul committed
135

Paul's avatar
Paul committed
136
137
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
138
139
140
141
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
Paul's avatar
Paul committed
142
143
    EXPECT(result == migraph::literal{5});
    EXPECT(result != migraph::literal{3});
Paul's avatar
Paul committed
144
145
}

Paul's avatar
Paul committed
146
147
void print_test()
{
Paul's avatar
Paul committed
148
    migraph::program p;
Paul's avatar
Paul committed
149

Paul's avatar
Paul committed
150
    auto x   = p.add_parameter("x", {migraph::shape::int64_type});
Paul's avatar
Paul committed
151
152
153
154
155
156
157
158
159
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
160
161
void param_test()
{
Paul's avatar
Paul committed
162
    migraph::program p;
Paul's avatar
Paul committed
163

Paul's avatar
Paul committed
164
165
    auto x = p.add_parameter("x", {migraph::shape::int64_type});
    auto y = p.add_parameter("y", {migraph::shape::int64_type});
Paul's avatar
Paul committed
166

Paul's avatar
Paul committed
167
    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
168
169
    auto result = p.eval(
        {{"x", migraph::literal{1}.get_argument()}, {"y", migraph::literal{2}.get_argument()}});
Paul's avatar
Paul committed
170
171
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
Paul's avatar
Paul committed
172
173
}

Paul's avatar
Paul committed
174
175
// Replacing sum(1, 2) with minus(2, 1) changes the result from 3 to 1.
void replace_test()
{
    migraph::program p;

    auto lit1 = p.add_literal(1);
    auto lit2 = p.add_literal(2);
    auto op   = p.add_instruction(sum_op{}, lit1, lit2);
    p.replace_instruction(op, minus_op{}, lit2, lit1);

    auto result = p.eval({});
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
}

Paul's avatar
Paul committed
188
189
// Insert sum(2, 2) before the first sum, then rewrite that sum as
// minus(inserted, 2): final result is (4 - 2) + 2 = 4 rather than 5.
void insert_replace_test()
{
    migraph::program p;

    auto lit1  = p.add_literal(1);
    auto lit2  = p.add_literal(2);
    auto first = p.add_instruction(sum_op{}, lit1, lit2);
    p.add_instruction(sum_op{}, first, lit2);

    auto inserted = p.insert_instruction(first, sum_op{}, lit2, lit2);
    p.replace_instruction(first, minus_op{}, inserted, lit2);

    auto result = p.eval({});
    EXPECT(result == migraph::literal{4});
    EXPECT(result != migraph::literal{5});
}

Paul's avatar
Paul committed
205
206
// Compiling against the identity target leaves 1 + 2 evaluating to 3.
void target_test()
{
    migraph::program p;

    auto lhs = p.add_literal(1);
    auto rhs = p.add_literal(2);
    p.add_instruction(sum_op{}, lhs, rhs);

    p.compile(id_target{});

    auto result = p.eval({});
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{4});
}

Paul's avatar
Paul committed
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
// Compiling sum(2, 1) with reverse_target turns it into minus(2, 1) == 1.
// The negative check uses 3, the value produced if the pass did not run.
void reverse_target_test()
{
    migraph::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(reverse_target{});
    auto result = p.eval({});
    EXPECT(result == migraph::literal{1});
    EXPECT(result != migraph::literal{3});
}

// Two reverse passes cancel out: sum(2, 1) stays a sum and evaluates to 3.
// The negative check uses 1, the value produced if only one reversal took effect.
void double_reverse_target_test()
{
    migraph::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(double_reverse_target{});
    auto result = p.eval({});
    EXPECT(result == migraph::literal{3});
    EXPECT(result != migraph::literal{1});
}

Paul's avatar
Paul committed
244
245
int main()
{
Paul's avatar
Paul committed
246
247
    literal_test1();
    literal_test2();
Paul's avatar
Paul committed
248
    print_test();
Paul's avatar
Paul committed
249
    param_test();
Paul's avatar
Paul committed
250
    replace_test();
Paul's avatar
Paul committed
251
    insert_replace_test();
Paul's avatar
Paul committed
252
    target_test();
Paul's avatar
Paul committed
253
    reverse_target_test();
Paul's avatar
Paul committed
254
}