eval_test.cpp 8.32 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
5
#include <sstream>
Paul's avatar
Paul committed
6
#include "test.hpp"
Paul's avatar
Paul committed
7
#include <basic_ops.hpp>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
10
struct id_target
{
Paul's avatar
Paul committed
11
12
13
14
15
    struct context
    {
        void finish() const {}
    };
    migraphx::context ctx = context{};
Paul's avatar
Paul committed
16
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
17
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
Paul's avatar
Paul committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    migraphx::context get_context() const { return ctx; }
};

// Pass-through operator whose compute() overload receives the id_target
// context by non-const reference.
struct id_ctx_op
{
    std::string name() const { return "id_ctx_op"; }

    // Yields the first argument, or a default-constructed argument when no
    // arguments were supplied. The context and output shape are unused.
    migraphx::argument compute(id_target::context&,
                               const migraphx::shape&,
                               std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    // The output shape mirrors the first input shape; a default-constructed
    // shape is returned when there are no inputs.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    // Always reports index 0 as the aliased input.
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

struct id_ctx_final_op
{
    std::string name() const { return "id_ctx_final_op"; }
Paul's avatar
Paul committed
44
    migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
45
46
47
48
49
50
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
51
    void finalize(id_target::context&, const migraphx::shape&, std::vector<migraphx::shape>) {}
Paul's avatar
Paul committed
52
53
54
55
56
57
58
59

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
60
61
};

Paul's avatar
Paul committed
62
63
struct reverse_pass
{
Paul's avatar
Paul committed
64
    std::string name() const { return "reverse_pass"; }
Paul's avatar
Paul committed
65

Paul's avatar
Paul committed
66
    void apply(migraphx::program& p) const
Paul's avatar
Paul committed
67
    {
Paul's avatar
Paul committed
68
        for(auto ins : migraphx::iterator_for(p))
Paul's avatar
Paul committed
69
        {
70
            if(ins->name() == "sum")
Paul's avatar
Paul committed
71
            {
72
                p.replace_instruction(ins, minus_op{}, ins->inputs());
Paul's avatar
Paul committed
73
            }
74
            else if(ins->name() == "minus")
Paul's avatar
Paul committed
75
            {
76
                p.replace_instruction(ins, sum_op{}, ins->inputs());
Paul's avatar
Paul committed
77
78
79
80
81
82
83
84
            }
        }
    }
};

struct reverse_target
{
    std::string name() const { return "reverse"; }
Paul's avatar
Paul committed
85
86
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {reverse_pass{}}; }
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
87
88
89
90
91
};

struct double_reverse_target
{
    std::string name() const { return "double_reverse"; }
Paul's avatar
Paul committed
92
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
93
94
    {
        return {reverse_pass{}, reverse_pass{}};
Paul's avatar
Paul committed
95
    }
Paul's avatar
Paul committed
96
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
97
98
};

Paul's avatar
Paul committed
99
// Evaluates sum_op over the literals 1 and 2 and expects 3.
TEST_CASE(literal_test1)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
111
// Chains two sum_op instructions, (1 + 2) + 2, and expects 5.
TEST_CASE(literal_test2)
{
    migraphx::program p;

    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{5});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
125
// Streams a small program into a stringstream and checks that the printed
// form is non-empty.
TEST_CASE(print_test)
{
    migraphx::program p;

    auto x   = p.add_parameter("x", {migraphx::shape::int64_type});
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
139
// Binds parameters "x" = 1 and "y" = 2 at eval time and expects their sum, 3.
TEST_CASE(param_test)
{
    migraphx::program p;

    auto x = p.add_parameter("x", {migraphx::shape::int64_type});
    auto y = p.add_parameter("y", {migraphx::shape::int64_type});

    p.add_instruction(sum_op{}, x, y);
    auto result = p.eval(
        {{"x", migraphx::literal{1}.get_argument()}, {"y", migraphx::literal{2}.get_argument()}});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
153
// Evaluating with only "x" bound must throw a migraphx::exception whose
// message is checked against "Parameter not found: y".
TEST_CASE(param_error_test)
{
    migraphx::program p;

    auto x = p.add_parameter("x", {migraphx::shape::int64_type});
    auto y = p.add_parameter("y", {migraphx::shape::int64_type});

    p.add_instruction(sum_op{}, x, y);
    EXPECT(test::throws<migraphx::exception>(
        [&] {
            p.eval({{"x", migraphx::literal{1}.get_argument()}});
        },
        "Parameter not found: y"));
}

Paul's avatar
Paul committed
168
// Replaces sum(1, 2) in place with minus(2, 1); the program must still
// validate and evaluate to 1.
TEST_CASE(replace_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);
    // validate() returning end() means no invalid instruction was found.
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
183
// Replaces the sum instruction with an existing minus(2, 1) instruction; the
// program must still validate and evaluate to 1.
TEST_CASE(replace_ins_test)
{
    migraphx::program p;

    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto sum   = p.add_instruction(sum_op{}, one, two);
    auto minus = p.add_instruction(minus_op{}, two, one);
    p.replace_instruction(sum, minus);
    // validate() returning end() means no invalid instruction was found.
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
199
// Replaces the literal `two` with the `sum` instruction while later
// instructions still depend on it; the program must validate and evaluate
// to 2 (presumably (1 + 2) - 1 flowing through pass_op).
TEST_CASE(replace_ins_test2)
{
    migraphx::program p;

    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto sum   = p.add_instruction(sum_op{}, one, two);
    auto minus = p.add_instruction(minus_op{}, two, one);
    p.add_instruction(pass_op{}, minus);
    p.replace_instruction(two, sum);
    // validate() returning end() means no invalid instruction was found.
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{2});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
216
// Inserts sum0 = 2 + 2 ahead of sum1, then rewrites sum1 as sum0 - 2. The
// trailing sum(sum1, 2) is therefore expected to yield (4 - 2) + 2 = 4.
TEST_CASE(insert_replace_test)
{
    migraphx::program p;

    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);
    // validate() returning end() means no invalid instruction was found.
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{4});
    EXPECT(result != migraphx::literal{5});
}

Paul's avatar
Paul committed
234
// Compiling against id_target (no passes) must leave 1 + 2 evaluating to 3.
TEST_CASE(target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
247
// reverse_target turns sum(2, 1) into minus(2, 1), so the result must be 1.
TEST_CASE(reverse_target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(reverse_target{});
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
260
// Two reverse passes cancel out, so sum(2, 1) must still evaluate to 3.
TEST_CASE(double_reverse_target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(double_reverse_target{});
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
// With a context-free operator (sum_op), the program's context is expected to
// stay shared with the target's context through both compile and eval.
TEST_CASE(eval_context1)
{
    migraphx::program p;
    id_target t{};
    // Baseline: the copy handed out by the target shares with its member.
    EXPECT(is_shared(t.ctx, t.get_context()));

    auto lhs = p.add_literal(1);
    auto rhs = p.add_literal(2);
    p.add_instruction(sum_op{}, lhs, rhs);

    p.compile(t);
    EXPECT(is_shared(t.ctx, p.get_context()));

    p.eval({});
    EXPECT(is_shared(t.ctx, p.get_context()));
}

// With id_ctx_op, whose compute() takes the context by non-const reference,
// the program's context is expected to stop being shared once eval has run.
TEST_CASE(eval_context2)
{
    migraphx::program p;
    id_target t{};
    // Baseline: the copy handed out by the target shares with its member.
    EXPECT(is_shared(t.ctx, t.get_context()));

    auto lhs = p.add_literal(1);
    auto rhs = p.add_literal(2);
    p.add_instruction(id_ctx_op{}, lhs, rhs);

    // Compilation alone leaves the context shared.
    p.compile(t);
    EXPECT(is_shared(t.ctx, p.get_context()));

    // Running id_ctx_op counts as a modification of the context.
    p.eval({});
    EXPECT(not is_shared(t.ctx, p.get_context()));
}

// With id_ctx_final_op, the finalize() hook is expected to un-share the
// context at compile time; eval must then leave the context as it was.
TEST_CASE(eval_context3)
{
    migraphx::program p;
    id_target t{};
    // Baseline: the copy handed out by the target shares with its member.
    EXPECT(is_shared(t.ctx, t.get_context()));

    auto lhs = p.add_literal(1);
    auto rhs = p.add_literal(2);
    p.add_instruction(id_ctx_final_op{}, lhs, rhs);

    // The finalizer counts as a modification of the context.
    p.compile(t);
    EXPECT(not is_shared(t.ctx, p.get_context()));

    // Snapshot taken after compile, used to check eval leaves it shared.
    auto ctx = p.get_context();
    p.eval({});
    EXPECT(is_shared(ctx, p.get_context()));
    EXPECT(not is_shared(t.ctx, p.get_context()));
}

Paul's avatar
Paul committed
319
// Entry point: forwards the command-line arguments to the test driver.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}