#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "test.hpp"

#include <basic_ops.hpp>

struct id_target
{
Paul's avatar
Paul committed
13
14
15
16
17
    struct context
    {
        void finish() const {}
    };
    migraphx::context ctx = context{};
Paul's avatar
Paul committed
18
    std::string name() const { return "id"; }
19
20
21
22
23
    std::vector<migraphx::pass> get_passes(migraphx::context&,
                                           const migraphx::compile_options&) const
    {
        return {};
    }
Paul's avatar
Paul committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    migraphx::context get_context() const { return ctx; }
};

// Identity operator whose compute() receives the id_target context by mutable
// reference; eval_context2 expects this to leave the program's context
// unshared after evaluation.
struct id_ctx_op
{
    std::string name() const { return "id_ctx_op"; }

    // Forwards the first input unchanged, or an empty argument when there are
    // no inputs.
    migraphx::argument
    compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    // Output shape is the first input's shape (default shape when no inputs).
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    // Reports that the output aliases input 0.
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

struct id_ctx_final_op
{
    std::string name() const { return "id_ctx_final_op"; }
Paul's avatar
Paul committed
50
    migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
51
52
53
54
55
56
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
57
58
59
    void finalize(id_target::context&, const migraphx::shape&, const std::vector<migraphx::shape>&)
    {
    }
Paul's avatar
Paul committed
60
61
62
63
64
65
66
67

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
68
69
};

Paul's avatar
Paul committed
70
71
struct reverse_pass
{
Paul's avatar
Paul committed
72
    std::string name() const { return "reverse_pass"; }
Paul's avatar
Paul committed
73

Paul's avatar
Paul committed
74
    void apply(migraphx::program& p) const { std::reverse(p.begin(), p.end()); }
Paul's avatar
Paul committed
75
76
77
78
79
};

struct reverse_target
{
    std::string name() const { return "reverse"; }
80
81
82
83
84
    std::vector<migraphx::pass> get_passes(migraphx::context&,
                                           const migraphx::compile_options&) const
    {
        return {reverse_pass{}};
    }
Paul's avatar
Paul committed
85
86
87
88
89
90
91
    migraphx::context get_context() const { return {}; }
};

struct invert_pass
{
    std::string name() const { return "invert_pass"; }

Paul's avatar
Paul committed
92
    void apply(migraphx::program& p) const
Paul's avatar
Paul committed
93
    {
Paul's avatar
Paul committed
94
        for(auto ins : migraphx::iterator_for(p))
Paul's avatar
Paul committed
95
        {
96
            if(ins->name() == "sum")
Paul's avatar
Paul committed
97
            {
98
                p.replace_instruction(ins, minus_op{}, ins->inputs());
Paul's avatar
Paul committed
99
            }
100
            else if(ins->name() == "minus")
Paul's avatar
Paul committed
101
            {
102
                p.replace_instruction(ins, sum_op{}, ins->inputs());
Paul's avatar
Paul committed
103
104
105
106
107
            }
        }
    }
};

Paul's avatar
Paul committed
108
struct invert_target
Paul's avatar
Paul committed
109
{
Paul's avatar
Paul committed
110
    std::string name() const { return "invert"; }
111
112
113
114
115
    std::vector<migraphx::pass> get_passes(migraphx::context&,
                                           const migraphx::compile_options&) const
    {
        return {invert_pass{}};
    }
Paul's avatar
Paul committed
116
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
117
118
};

Paul's avatar
Paul committed
119
struct double_invert_target
Paul's avatar
Paul committed
120
{
Paul's avatar
Paul committed
121
    std::string name() const { return "double_invert"; }
122
123
    std::vector<migraphx::pass> get_passes(migraphx::context&,
                                           const migraphx::compile_options&) const
Paul's avatar
Paul committed
124
    {
Paul's avatar
Paul committed
125
        return {invert_pass{}, invert_pass{}};
Paul's avatar
Paul committed
126
    }
Paul's avatar
Paul committed
127
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
128
129
};

Paul's avatar
Paul committed
130
TEST_CASE(literal_test1)
Paul's avatar
Paul committed
131
{
Paul's avatar
Paul committed
132
    migraphx::program p;
Paul's avatar
Paul committed
133
134
135

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
Paul's avatar
Paul committed
136
    p.add_instruction(sum_op{}, one, two);
137
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
138
139
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
Paul's avatar
Paul committed
140
141
}

Paul's avatar
Paul committed
142
TEST_CASE(literal_test2)
Paul's avatar
Paul committed
143
{
Paul's avatar
Paul committed
144
    migraphx::program p;
Paul's avatar
Paul committed
145

Paul's avatar
Paul committed
146
147
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
148
149
150
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

151
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
152
153
    EXPECT(result == migraphx::literal{5});
    EXPECT(result != migraphx::literal{3});
Paul's avatar
Paul committed
154
155
}

Paul's avatar
Paul committed
156
TEST_CASE(print_test)
Paul's avatar
Paul committed
157
{
Paul's avatar
Paul committed
158
    migraphx::program p;
Paul's avatar
Paul committed
159

Paul's avatar
Paul committed
160
    auto x   = p.add_parameter("x", {migraphx::shape::int32_type});
Paul's avatar
Paul committed
161
162
163
164
165
166
167
168
169
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
170
TEST_CASE(param_test)
Paul's avatar
Paul committed
171
{
Paul's avatar
Paul committed
172
    migraphx::program p;
Paul's avatar
Paul committed
173

Paul's avatar
Paul committed
174
175
    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
    auto y = p.add_parameter("y", {migraphx::shape::int32_type});
Paul's avatar
Paul committed
176

Paul's avatar
Paul committed
177
    p.add_instruction(sum_op{}, x, y);
178
179
180
    auto result = p.eval({{"x", migraphx::literal{1}.get_argument()},
                          {"y", migraphx::literal{2}.get_argument()}})
                      .back();
Paul's avatar
Paul committed
181
182
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
Paul's avatar
Paul committed
183
184
}

Paul's avatar
Paul committed
185
TEST_CASE(param_error_test)
Khalique's avatar
Khalique committed
186
{
Paul's avatar
Paul committed
187
    migraphx::program p;
Khalique's avatar
Khalique committed
188

Paul's avatar
Paul committed
189
190
    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
    auto y = p.add_parameter("y", {migraphx::shape::int32_type});
Khalique's avatar
Khalique committed
191
192

    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
193
    EXPECT(test::throws<migraphx::exception>(
Khalique's avatar
Khalique committed
194
        [&] {
Paul's avatar
Paul committed
195
            p.eval({{"x", migraphx::literal{1}.get_argument()}});
Khalique's avatar
Khalique committed
196
        },
197
        "Parameter not found: y"));
Khalique's avatar
Khalique committed
198
199
}

Paul's avatar
Paul committed
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
TEST_CASE(param_error_shape_test)
{
    migraphx::program p;

    auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
    auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});

    p.add_instruction(sum_op{}, x, y);
    EXPECT(test::throws<migraphx::exception>(
        [&] {
            p.eval({
                {"x", migraphx::literal{1}.get_argument()},
                {"y", migraphx::literal{{migraphx::shape::int32_type, {1, 1}}, {2}}.get_argument()},
            });
        },
        "Incorrect shape {int32_type, {1}, {0}} for parameter: x"));
}

Paul's avatar
Paul committed
218
TEST_CASE(get_param1)
Paul's avatar
Paul committed
219
220
{
    migraphx::program p;
Paul's avatar
Paul committed
221
222
223
224
225
226
227
228
    migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(sum_op{}, x, y);
    EXPECT(bool{p.get_parameter("x") == x});
    EXPECT(bool{p.get_parameter("y") == y});
    EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
}
Paul's avatar
Paul committed
229

Paul's avatar
Paul committed
230
231
232
233
234
235
236
237
TEST_CASE(get_param2)
{
    migraphx::program p;
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
}
Paul's avatar
Paul committed
238

Paul's avatar
Paul committed
239
240
241
242
243
244
TEST_CASE(get_param_shapes)
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
Paul's avatar
Paul committed
245
    p.add_instruction(sum_op{}, x, y);
Paul's avatar
Paul committed
246
247
248
249
    auto m = p.get_parameter_shapes();
    EXPECT(m.count("nonexistent") == 0);
    EXPECT(m.at("x") == s);
    EXPECT(m.at("y") == s);
Paul's avatar
Paul committed
250
251
}

Paul's avatar
Paul committed
252
TEST_CASE(replace_test)
Paul's avatar
Paul committed
253
{
Paul's avatar
Paul committed
254
    migraphx::program p;
Paul's avatar
Paul committed
255
256
257
258
259

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);
Paul's avatar
Paul committed
260
    EXPECT(bool{p.validate() == p.end()});
Paul's avatar
Paul committed
261

262
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
263
264
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
Paul's avatar
Paul committed
265
266
}

Paul's avatar
Paul committed
267
TEST_CASE(replace_ins_test)
Paul's avatar
Paul committed
268
{
Paul's avatar
Paul committed
269
    migraphx::program p;
Paul's avatar
Paul committed
270

Paul's avatar
Paul committed
271
272
273
    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto sum   = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
274
275
    auto minus = p.add_instruction(minus_op{}, two, one);
    p.replace_instruction(sum, minus);
Paul's avatar
Paul committed
276
    EXPECT(bool{p.validate() == p.end()});
Paul's avatar
Paul committed
277

278
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
279
280
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
Paul's avatar
Paul committed
281
282
}

Paul's avatar
Paul committed
283
TEST_CASE(replace_ins_test2)
Paul's avatar
Paul committed
284
{
Paul's avatar
Paul committed
285
    migraphx::program p;
Paul's avatar
Paul committed
286

Paul's avatar
Paul committed
287
288
289
    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto sum   = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
290
291
    auto minus = p.add_instruction(minus_op{}, two, one);
    p.add_instruction(pass_op{}, minus);
Paul's avatar
Paul committed
292
    p.replace_instruction(two, sum);
Paul's avatar
Paul committed
293
    EXPECT(bool{p.validate() == p.end()});
Paul's avatar
Paul committed
294

295
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
296
297
    EXPECT(result == migraphx::literal{2});
    EXPECT(result != migraphx::literal{3});
Paul's avatar
Paul committed
298
299
}

Paul's avatar
Paul committed
300
301
302
303
304
305
306
307
308
309
TEST_CASE(replace_op_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, two, one);
    sum->replace(minus_op{});
    EXPECT(bool{p.validate() == p.end()});

310
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
311
312
313
314
315
316
317
318
319
320
321
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
}

TEST_CASE(replace_op_recompute_shape_throw)
{
    // Replacing a two-input sum with unary_pass_op must fail when the shape
    // is recomputed.
    migraphx::program prog;

    auto lit1  = prog.add_literal(1);
    auto lit2  = prog.add_literal(2);
    auto added = prog.add_instruction(sum_op{}, lit1, lit2);

    auto swap_to_unary = [&] { added->replace(unary_pass_op{}); };
    EXPECT(test::throws<migraphx::exception>(swap_to_unary));
}

Paul's avatar
Paul committed
325
TEST_CASE(insert_replace_test)
Paul's avatar
Paul committed
326
{
Paul's avatar
Paul committed
327
    migraphx::program p;
Paul's avatar
Paul committed
328

Paul's avatar
Paul committed
329
330
    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
Paul's avatar
Paul committed
331
332
333
334
335
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);
Paul's avatar
Paul committed
336
    EXPECT(bool{p.validate() == p.end()});
Paul's avatar
Paul committed
337

338
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
339
340
    EXPECT(result == migraphx::literal{4});
    EXPECT(result != migraphx::literal{5});
Paul's avatar
Paul committed
341
342
}

Paul's avatar
Paul committed
343
344
345
346
TEST_CASE(remove_test1)
{
    migraphx::program p;

Paul's avatar
Paul committed
347
348
349
    auto one     = p.add_literal(1);
    auto two     = p.add_literal(2);
    auto sum     = p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
350
351
352
353
    auto removed = p.add_instruction(minus_op{}, sum, one);
    p.remove_instruction(removed);
    EXPECT(bool{p.validate() == p.end()});

354
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
355
356
357
358
359
360
361
362
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{1});
}

TEST_CASE(remove_test2)
{
    // Removing an unused minus in the middle leaves sum(1, 2) = 3.
    migraphx::program prog;

    auto lit1   = prog.add_literal(1);
    auto lit2   = prog.add_literal(2);
    auto doomed = prog.add_instruction(minus_op{}, lit2, lit1);
    prog.add_instruction(sum_op{}, lit1, lit2);
    prog.remove_instruction(doomed);
    EXPECT(bool{prog.validate() == prog.end()});

    auto outputs = prog.eval({});
    auto actual  = outputs.back();
    EXPECT(actual == migraphx::literal{3});
    EXPECT(actual != migraphx::literal{1});
}

Paul's avatar
Paul committed
375
TEST_CASE(target_test)
Paul's avatar
Paul committed
376
{
Paul's avatar
Paul committed
377
    migraphx::program p;
Paul's avatar
Paul committed
378
379
380
381
382

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
383
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
384
385
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
Paul's avatar
Paul committed
386
387
}

Paul's avatar
Paul committed
388
TEST_CASE(invert_target_test)
Paul's avatar
Paul committed
389
{
Paul's avatar
Paul committed
390
    migraphx::program p;
Paul's avatar
Paul committed
391
392
393
394

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
Paul's avatar
Paul committed
395
    p.compile(invert_target{});
396
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
397
398
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{4});
Paul's avatar
Paul committed
399
400
}

Paul's avatar
Paul committed
401
TEST_CASE(double_invert_target_test)
Paul's avatar
Paul committed
402
{
Paul's avatar
Paul committed
403
    migraphx::program p;
Paul's avatar
Paul committed
404
405
406
407

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
Paul's avatar
Paul committed
408
    p.compile(double_invert_target{});
409
    auto result = p.eval({}).back();
Paul's avatar
Paul committed
410
411
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
Paul's avatar
Paul committed
412
413
}

Paul's avatar
Paul committed
414
415
416
417
418
419
420
TEST_CASE(reverse_target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
Paul's avatar
Paul committed
421
    EXPECT(test::throws<migraphx::exception>([&] { p.compile(reverse_target{}); }));
Paul's avatar
Paul committed
422
423
}

Paul's avatar
Paul committed
424
425
// Check that the program doesnt modify the context directly, and only the operators modify the
// context
Paul's avatar
Paul committed
426
427
428
429
430
431
432
433
434
435
TEST_CASE(eval_context1)
{
    migraphx::program p;
    id_target t{};
    EXPECT(is_shared(t.ctx, t.get_context()));
    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(t);
    EXPECT(is_shared(t.ctx, p.get_context()));
436
    p.eval({}).back();
Paul's avatar
Paul committed
437
438
439
440
441
442
443
444
445
446
447
448
449
    EXPECT(is_shared(t.ctx, p.get_context()));
}

TEST_CASE(eval_context2)
{
    // id_ctx_op's compute() takes the context by mutable reference, so after
    // eval the program's context is no longer shared with the target's.
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));
    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(id_ctx_op{}, lit1, lit2);
    prog.compile(tgt);
    EXPECT(is_shared(tgt.ctx, prog.get_context()));
    prog.eval({}).back();
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
}

TEST_CASE(eval_context3)
{
    // id_ctx_final_op touches the context only in finalize(): compilation
    // unshares it, and evaluation must not change it further.
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));
    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(id_ctx_final_op{}, lit1, lit2);
    prog.compile(tgt);
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
    auto compiled_ctx = prog.get_context();
    prog.eval({}).back();
    EXPECT(is_shared(compiled_ctx, prog.get_context()));
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
}

// RAII guard that points std::cout at another stream's buffer for its
// lifetime and restores the previous buffer when destroyed. Non-copyable and
// non-assignable: two guards restoring the same saved buffer would be wrong.
struct cout_redirect
{
    cout_redirect()                                = delete;
    cout_redirect(const cout_redirect&)            = delete;
    cout_redirect& operator=(const cout_redirect&) = delete;

    // Redirects std::cout into `stream`'s buffer. `stream` must outlive this
    // guard. Explicit so that arbitrary streams do not convert implicitly.
    template <class T>
    explicit cout_redirect(T& stream) : old(std::cout.rdbuf(stream.rdbuf()))
    {
    }

    // Restores the buffer std::cout used before construction.
    ~cout_redirect() { std::cout.rdbuf(old); }

    private:
    std::streambuf* old;
};

Paul's avatar
Paul committed
486
template <class F>
Paul's avatar
Paul committed
487
488
489
490
491
492
493
494
495
496
497
std::string capture_output(F f)
{
    std::stringstream ss;
    cout_redirect cr{ss};
    f();
    return ss.str();
}

TEST_CASE(debug_print_test)
{
    // debug_print of the whole (single-instruction) program, of that
    // instruction, and of a vector holding it must agree. end() and an
    // instruction from another program produce fixed diagnostic messages.
    migraphx::program prog;
    auto lit                                    = prog.add_literal(1);
    std::vector<migraphx::instruction_ref> lits = {lit};

    migraphx::program other;
    auto foreign = other.add_literal(1);

    auto program_out = migraphx::trim(capture_output([&] { prog.debug_print(); }));
    auto ins_out     = migraphx::trim(capture_output([&] { prog.debug_print(lit); }));
    auto inss_out    = migraphx::trim(capture_output([&] { prog.debug_print(lits); }));
    auto end_out     = migraphx::trim(capture_output([&] { prog.debug_print(prog.end()); }));
    auto foreign_out = migraphx::trim(capture_output([&] { prog.debug_print(foreign); }));

    EXPECT(program_out == ins_out);
    EXPECT(inss_out == ins_out);
    EXPECT(end_out == "End instruction");
    EXPECT(foreign_out == "Instruction not part of program");
}

Paul's avatar
Paul committed
516
int main(int argc, const char* argv[]) { test::run(argc, argv); }