eval_test.cpp 10.6 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "test.hpp"
#include <basic_ops.hpp>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
10
struct id_target
{
Paul's avatar
Paul committed
11
12
13
14
15
    struct context
    {
        void finish() const {}
    };
    migraphx::context ctx = context{};
Paul's avatar
Paul committed
16
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
17
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
Paul's avatar
Paul committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    migraphx::context get_context() const { return ctx; }
};

// Operator whose compute() receives the target context and forwards its first
// input; with no inputs it yields a default-constructed result.
struct id_ctx_op
{
    std::string name() const { return "id_ctx_op"; }

    // Returns the first argument, or an empty argument when `args` is empty.
    migraphx::argument
    compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return {};
    }

    // Output shape mirrors the first input shape (empty shape when there are none).
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

struct id_ctx_final_op
{
    std::string name() const { return "id_ctx_final_op"; }
Paul's avatar
Paul committed
44
    migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
45
46
47
48
49
50
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
51
52
53
    void finalize(id_target::context&, const migraphx::shape&, const std::vector<migraphx::shape>&)
    {
    }
Paul's avatar
Paul committed
54
55
56
57
58
59
60
61

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
62
63
};

Paul's avatar
Paul committed
64
65
struct reverse_pass
{
Paul's avatar
Paul committed
66
    std::string name() const { return "reverse_pass"; }
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    void apply(migraphx::program& p) const
    {
        std::reverse(p.begin(), p.end());
    }
};

// Target whose only pass is reverse_pass; its context is default-constructed.
struct reverse_target
{
    std::string name() const
    {
        return "reverse";
    }

    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
        return {reverse_pass{}};
    }

    migraphx::context get_context() const
    {
        return {};
    }
};

struct invert_pass
{
    std::string name() const { return "invert_pass"; }

Paul's avatar
Paul committed
85
    void apply(migraphx::program& p) const
Paul's avatar
Paul committed
86
    {
Paul's avatar
Paul committed
87
        for(auto ins : migraphx::iterator_for(p))
Paul's avatar
Paul committed
88
        {
89
            if(ins->name() == "sum")
Paul's avatar
Paul committed
90
            {
91
                p.replace_instruction(ins, minus_op{}, ins->inputs());
Paul's avatar
Paul committed
92
            }
93
            else if(ins->name() == "minus")
Paul's avatar
Paul committed
94
            {
95
                p.replace_instruction(ins, sum_op{}, ins->inputs());
Paul's avatar
Paul committed
96
97
98
99
100
            }
        }
    }
};

Paul's avatar
Paul committed
101
struct invert_target
Paul's avatar
Paul committed
102
{
Paul's avatar
Paul committed
103
104
    std::string name() const { return "invert"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {invert_pass{}}; }
Paul's avatar
Paul committed
105
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
106
107
};

Paul's avatar
Paul committed
108
struct double_invert_target
Paul's avatar
Paul committed
109
{
Paul's avatar
Paul committed
110
    std::string name() const { return "double_invert"; }
Paul's avatar
Paul committed
111
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
112
    {
Paul's avatar
Paul committed
113
        return {invert_pass{}, invert_pass{}};
Paul's avatar
Paul committed
114
    }
Paul's avatar
Paul committed
115
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
116
117
};

Paul's avatar
Paul committed
118
// Evaluates sum_op over the literals 1 and 2 and expects the result 3.
TEST_CASE(literal_test1)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
130
// Chains two sum_op instructions, (1 + 2) + 2, and expects the result 5.
TEST_CASE(literal_test2)
{
    migraphx::program p;

    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{5});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
144
// Streams a small program into a stringstream and checks that some text was produced.
TEST_CASE(print_test)
{
    migraphx::program p;

    auto x   = p.add_parameter("x", {migraphx::shape::int32_type});
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, x, two);

    std::stringstream ss;
    ss << p;
    std::string s = ss.str();
    EXPECT(!s.empty());
}

Paul's avatar
Paul committed
158
// Evaluates sum_op over two parameters bound to 1 and 2 and expects 3.
TEST_CASE(param_test)
{
    migraphx::program p;

    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
    auto y = p.add_parameter("y", {migraphx::shape::int32_type});

    p.add_instruction(sum_op{}, x, y);
    auto result = p.eval(
        {{"x", migraphx::literal{1}.get_argument()}, {"y", migraphx::literal{2}.get_argument()}});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
172
// Evaluating with only "x" bound must throw a migraphx::exception that
// reports the missing parameter "y".
TEST_CASE(param_error_test)
{
    migraphx::program p;

    auto x = p.add_parameter("x", {migraphx::shape::int32_type});
    auto y = p.add_parameter("y", {migraphx::shape::int32_type});

    p.add_instruction(sum_op{}, x, y);
    EXPECT(test::throws<migraphx::exception>(
        [&] {
            p.eval({{"x", migraphx::literal{1}.get_argument()}});
        },
        "Parameter not found: y"));
}

Paul's avatar
Paul committed
187
// get_parameter returns the instruction of a named parameter, and p.end()
// for a name that was never added.
TEST_CASE(get_param1)
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(sum_op{}, x, y);
    EXPECT(bool{p.get_parameter("x") == x});
    EXPECT(bool{p.get_parameter("y") == y});
    EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
}
Paul's avatar
Paul committed
198

Paul's avatar
Paul committed
199
200
201
202
203
204
205
206
// A program built only from literals has no parameters, so any lookup
// is expected to come back as p.end().
TEST_CASE(get_param2)
{
    migraphx::program p;

    auto lit_one = p.add_literal(1);
    auto lit_two = p.add_literal(2);
    p.add_instruction(sum_op{}, lit_one, lit_two);

    EXPECT(bool{p.get_parameter("nonexistent") == p.end()});
}
Paul's avatar
Paul committed
207

Paul's avatar
Paul committed
208
209
210
211
212
213
// get_parameter_shapes maps each parameter name to the shape it was added with.
TEST_CASE(get_param_shapes)
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::int32_type, {1, 2}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(sum_op{}, x, y);
    auto m = p.get_parameter_shapes();
    EXPECT(m.count("nonexistent") == 0);
    EXPECT(m.at("x") == s);
    EXPECT(m.at("y") == s);
}

Paul's avatar
Paul committed
221
// Replaces sum(1, 2) in place with minus(2, 1); the program must still
// validate and evaluate to 1.
TEST_CASE(replace_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    auto sum = p.add_instruction(sum_op{}, one, two);
    p.replace_instruction(sum, minus_op{}, two, one);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
236
// Replaces the sum instruction with an existing minus instruction; the
// program must still validate and evaluate to 1.
TEST_CASE(replace_ins_test)
{
    migraphx::program p;

    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto sum   = p.add_instruction(sum_op{}, one, two);
    auto minus = p.add_instruction(minus_op{}, two, one);
    p.replace_instruction(sum, minus);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
252
// Replaces the literal `two` with the sum instruction, while `two` is
// itself an input of that sum; the program must still validate and the
// trailing pass_op must evaluate to 2.
TEST_CASE(replace_ins_test2)
{
    migraphx::program p;

    auto one   = p.add_literal(1);
    auto two   = p.add_literal(2);
    auto sum   = p.add_instruction(sum_op{}, one, two);
    auto minus = p.add_instruction(minus_op{}, two, one);
    p.add_instruction(pass_op{}, minus);
    p.replace_instruction(two, sum);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{2});
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
269
// Inserts sum0 = sum(2, 2) before sum1, then rewrites sum1 as
// minus(sum0, 2); the program must still validate and evaluate to 4.
TEST_CASE(insert_replace_test)
{
    migraphx::program p;

    auto one  = p.add_literal(1);
    auto two  = p.add_literal(2);
    auto sum1 = p.add_instruction(sum_op{}, one, two);
    p.add_instruction(sum_op{}, sum1, two);

    auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
    p.replace_instruction(sum1, minus_op{}, sum0, two);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{4});
    EXPECT(result != migraphx::literal{5});
}

Paul's avatar
Paul committed
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
// Removes the trailing minus instruction; the remaining sum(1, 2) must
// still validate and evaluate to 3.
TEST_CASE(remove_test1)
{
    migraphx::program p;

    auto lit_one = p.add_literal(1);
    auto lit_two = p.add_literal(2);
    auto total   = p.add_instruction(sum_op{}, lit_one, lit_two);
    auto doomed  = p.add_instruction(minus_op{}, total, lit_one);

    p.remove_instruction(doomed);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{1});
}

// Removes a minus instruction that sits before the final sum; the program
// must still validate and evaluate to 3.
TEST_CASE(remove_test2)
{
    migraphx::program p;

    auto lit_one = p.add_literal(1);
    auto lit_two = p.add_literal(2);
    auto doomed  = p.add_instruction(minus_op{}, lit_two, lit_one);
    p.add_instruction(sum_op{}, lit_one, lit_two);

    p.remove_instruction(doomed);
    EXPECT(bool{p.validate() == p.end()});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{1});
}

Paul's avatar
Paul committed
321
// Compiling with id_target (which has no passes) leaves sum(1, 2) evaluating to 3.
TEST_CASE(target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, one, two);
    p.compile(id_target{});
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
334
// Compiling with invert_target turns sum(2, 1) into minus(2, 1), so the
// program is expected to evaluate to 1.
TEST_CASE(invert_target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(invert_target{});
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
347
// Two invert passes cancel out, so sum(2, 1) is expected to evaluate to 3.
TEST_CASE(double_invert_target_test)
{
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(double_invert_target{});
    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    EXPECT(result != migraphx::literal{4});
}

Paul's avatar
Paul committed
360
361
// Check that the program doesn't modify the context directly, and that only the operators
// modify the context
Paul's avatar
Paul committed
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
// With an operator that never takes the context (sum_op), the program's
// context is expected to stay shared with the target's through compile and eval.
TEST_CASE(eval_context1)
{
    migraphx::program p;
    id_target target{};
    EXPECT(is_shared(target.ctx, target.get_context()));

    auto lit_one = p.add_literal(1);
    auto lit_two = p.add_literal(2);
    p.add_instruction(sum_op{}, lit_one, lit_two);

    p.compile(target);
    EXPECT(is_shared(target.ctx, p.get_context()));

    p.eval({});
    EXPECT(is_shared(target.ctx, p.get_context()));
}

// With id_ctx_op, whose compute() takes the context, the program's context
// is expected to stop being shared with the target's once eval has run.
TEST_CASE(eval_context2)
{
    migraphx::program p;
    id_target target{};
    EXPECT(is_shared(target.ctx, target.get_context()));

    auto lit_one = p.add_literal(1);
    auto lit_two = p.add_literal(2);
    p.add_instruction(id_ctx_op{}, lit_one, lit_two);

    p.compile(target);
    EXPECT(is_shared(target.ctx, p.get_context()));

    p.eval({});
    // id_ctx_op will modify the context
    EXPECT(not is_shared(target.ctx, p.get_context()));
}

// With id_ctx_final_op, the context is expected to be unshared from the
// target's right after compile (via finalize) and left untouched by eval.
TEST_CASE(eval_context3)
{
    migraphx::program p;
    id_target target{};
    EXPECT(is_shared(target.ctx, target.get_context()));

    auto lit_one = p.add_literal(1);
    auto lit_two = p.add_literal(2);
    p.add_instruction(id_ctx_final_op{}, lit_one, lit_two);

    p.compile(target);
    // Finalizer will modify the context
    EXPECT(not is_shared(target.ctx, p.get_context()));

    // Snapshot taken before eval to show that eval leaves the context alone.
    auto ctx_before_eval = p.get_context();
    p.eval({});
    EXPECT(is_shared(ctx_before_eval, p.get_context()));
    EXPECT(not is_shared(target.ctx, p.get_context()));
}

Paul's avatar
Paul committed
408
// Entry point: hands the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}