eval_test.cpp 10.6 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
5
#include <algorithm>
#include <sstream>
Paul's avatar
Paul committed
6
#include "test.hpp"
Paul's avatar
Paul committed
7
#include <basic_ops.hpp>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
10
struct id_target
{
Paul's avatar
Paul committed
11
12
13
14
15
    struct context
    {
        void finish() const {}
    };
    migraphx::context ctx = context{};
Paul's avatar
Paul committed
16
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
17
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
Paul's avatar
Paul committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    migraphx::context get_context() const { return ctx; }
};

// Pass-through operator whose compute() receives the id_target context.
// eval_context2 relies on evaluating this op leaving the program with a context
// that is no longer shared with id_target::ctx.
struct id_ctx_op
{
    std::string name() const { return "id_ctx_op"; }

    // Returns the first argument unchanged, or a default argument when given none.
    migraphx::argument
    compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    // The output shape mirrors the first input; an empty input list yields a default shape.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    // Reports input 0 as the aliased input.
    int output_alias(const std::vector<migraphx::shape>&) const
    {
        return 0;
    }
};

struct id_ctx_final_op
{
    std::string name() const { return "id_ctx_final_op"; }
Paul's avatar
Paul committed
44
    migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
45
46
47
48
49
50
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
51
52
53
    void finalize(id_target::context&, const migraphx::shape&, const std::vector<migraphx::shape>&)
    {
    }
Paul's avatar
Paul committed
54
55
56
57
58
59
60
61

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
62
63
};

Paul's avatar
Paul committed
64
65
struct reverse_pass
{
Paul's avatar
Paul committed
66
    std::string name() const { return "reverse_pass"; }
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
    void apply(migraphx::program& p) const { std::reverse(p.begin(), p.end()); }
Paul's avatar
Paul committed
69
70
71
72
73
74
75
76
77
78
79
80
81
};

// Target whose single pass is reverse_pass; it hands out a default-constructed context.
struct reverse_target
{
    std::string name() const { return "reverse"; }

    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
        return std::vector<migraphx::pass>{reverse_pass{}};
    }

    migraphx::context get_context() const { return migraphx::context{}; }
};

struct invert_pass
{
    std::string name() const { return "invert_pass"; }

Paul's avatar
Paul committed
82
    void apply(migraphx::program& p) const
Paul's avatar
Paul committed
83
    {
Paul's avatar
Paul committed
84
        for(auto ins : migraphx::iterator_for(p))
Paul's avatar
Paul committed
85
        {
86
            if(ins->name() == "sum")
Paul's avatar
Paul committed
87
            {
88
                p.replace_instruction(ins, minus_op{}, ins->inputs());
Paul's avatar
Paul committed
89
            }
90
            else if(ins->name() == "minus")
Paul's avatar
Paul committed
91
            {
92
                p.replace_instruction(ins, sum_op{}, ins->inputs());
Paul's avatar
Paul committed
93
94
95
96
97
            }
        }
    }
};

Paul's avatar
Paul committed
98
struct invert_target
Paul's avatar
Paul committed
99
{
Paul's avatar
Paul committed
100
101
    std::string name() const { return "invert"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {invert_pass{}}; }
Paul's avatar
Paul committed
102
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
103
104
};

Paul's avatar
Paul committed
105
struct double_invert_target
Paul's avatar
Paul committed
106
{
Paul's avatar
Paul committed
107
    std::string name() const { return "double_invert"; }
Paul's avatar
Paul committed
108
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
109
    {
Paul's avatar
Paul committed
110
        return {invert_pass{}, invert_pass{}};
Paul's avatar
Paul committed
111
    }
Paul's avatar
Paul committed
112
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
113
114
};

Paul's avatar
Paul committed
115
// Evaluating a program that sums two literals (1 + 2) yields 3.
TEST_CASE(literal_test1)
{
    migraphx::program prog;

    auto lhs = prog.add_literal(1);
    auto rhs = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lhs, rhs);

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
127
// Chained sums over literals: (1 + 2) + 2 evaluates to 5.
TEST_CASE(literal_test2)
{
    migraphx::program prog;

    auto lit1    = prog.add_literal(1);
    auto lit2    = prog.add_literal(2);
    auto partial = prog.add_instruction(sum_op{}, lit1, lit2);
    prog.add_instruction(sum_op{}, partial, lit2);

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{5});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
141
// Streaming a program that contains a parameter produces non-empty text.
TEST_CASE(print_test)
{
    migraphx::program prog;

    auto param = prog.add_parameter("x", {migraphx::shape::int32_type});
    auto lit   = prog.add_literal(2);
    prog.add_instruction(sum_op{}, param, lit);

    std::stringstream stream;
    stream << prog;
    EXPECT(!stream.str().empty());
}

Paul's avatar
Paul committed
155
// Parameters bound at eval time: x = 1, y = 2, so x + y evaluates to 3.
TEST_CASE(param_test)
{
    migraphx::program prog;

    auto px = prog.add_parameter("x", {migraphx::shape::int32_type});
    auto py = prog.add_parameter("y", {migraphx::shape::int32_type});
    prog.add_instruction(sum_op{}, px, py);

    // The temporary literals must stay inside this full-expression so they outlive the eval call.
    auto out = prog.eval({{"x", migraphx::literal{1}.get_argument()},
                          {"y", migraphx::literal{2}.get_argument()}});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
169
// Evaluating without binding every parameter throws migraphx::exception,
// with a message naming the missing parameter.
TEST_CASE(param_error_test)
{
    migraphx::program prog;

    auto px = prog.add_parameter("x", {migraphx::shape::int32_type});
    auto py = prog.add_parameter("y", {migraphx::shape::int32_type});
    prog.add_instruction(sum_op{}, px, py);

    auto eval_without_y = [&] { prog.eval({{"x", migraphx::literal{1}.get_argument()}}); };
    EXPECT(test::throws<migraphx::exception>(eval_without_y, "Parameter not found: y"));
}

Paul's avatar
Paul committed
184
// get_parameter returns the matching parameter instruction, or end() for an unknown name.
TEST_CASE(get_param1)
{
    migraphx::program prog;
    const migraphx::shape pshape{migraphx::shape::int32_type, {1, 2}};

    auto px = prog.add_parameter("x", pshape);
    auto py = prog.add_parameter("y", pshape);
    prog.add_instruction(sum_op{}, px, py);

    EXPECT(bool{prog.get_parameter("x") == px});
    EXPECT(bool{prog.get_parameter("y") == py});
    EXPECT(bool{prog.get_parameter("nonexistent") == prog.end()});
}
Paul's avatar
Paul committed
195

Paul's avatar
Paul committed
196
197
198
199
200
201
202
203
// In a program with no parameters at all, get_parameter yields end().
TEST_CASE(get_param2)
{
    migraphx::program prog;

    auto lhs = prog.add_literal(1);
    auto rhs = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lhs, rhs);

    EXPECT(bool{prog.get_parameter("nonexistent") == prog.end()});
}
Paul's avatar
Paul committed
204

Paul's avatar
Paul committed
205
206
207
208
209
210
// get_parameter_shapes maps each declared parameter name to its shape and has no other keys checked here.
TEST_CASE(get_param_shapes)
{
    migraphx::program prog;
    const migraphx::shape pshape{migraphx::shape::int32_type, {1, 2}};

    auto px = prog.add_parameter("x", pshape);
    auto py = prog.add_parameter("y", pshape);
    prog.add_instruction(sum_op{}, px, py);

    const auto shapes = prog.get_parameter_shapes();
    EXPECT(shapes.count("nonexistent") == 0);
    EXPECT(shapes.at("x") == pshape);
    EXPECT(shapes.at("y") == pshape);
}

Paul's avatar
Paul committed
218
// Rewriting an instruction in place: sum(1, 2) becomes minus(2, 1), so the
// program stays valid and evaluates to 1.
TEST_CASE(replace_test)
{
    migraphx::program prog;

    auto lit1  = prog.add_literal(1);
    auto lit2  = prog.add_literal(2);
    auto total = prog.add_instruction(sum_op{}, lit1, lit2);
    prog.replace_instruction(total, minus_op{}, lit2, lit1);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{1});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
233
// Replacing one instruction with another existing instruction: the sum is
// replaced by minus(2, 1), and the program evaluates to 1.
TEST_CASE(replace_ins_test)
{
    migraphx::program prog;

    auto lit1  = prog.add_literal(1);
    auto lit2  = prog.add_literal(2);
    auto total = prog.add_instruction(sum_op{}, lit1, lit2);
    auto diff  = prog.add_instruction(minus_op{}, lit2, lit1);
    prog.replace_instruction(total, diff);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{1});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
249
// Replacing the literal 2 with sum(1, 2). The expected result of 2 is
// consistent with minus now reading the sum: (1 + 2) - 1.
TEST_CASE(replace_ins_test2)
{
    migraphx::program prog;

    auto lit1  = prog.add_literal(1);
    auto lit2  = prog.add_literal(2);
    auto total = prog.add_instruction(sum_op{}, lit1, lit2);
    auto diff  = prog.add_instruction(minus_op{}, lit2, lit1);
    prog.add_instruction(pass_op{}, diff);
    prog.replace_instruction(lit2, total);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{2});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
266
// Inserts sum(2, 2) ahead of the first sum, then rewrites that first sum as
// minus(inserted, 2). The final sum becomes (4 - 2) + 2 = 4.
TEST_CASE(insert_replace_test)
{
    migraphx::program prog;

    auto lit1  = prog.add_literal(1);
    auto lit2  = prog.add_literal(2);
    auto first = prog.add_instruction(sum_op{}, lit1, lit2);
    prog.add_instruction(sum_op{}, first, lit2);

    auto inserted = prog.insert_instruction(first, sum_op{}, lit2, lit2);
    prog.replace_instruction(first, minus_op{}, inserted, lit2);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{4});
    EXPECT(out != migraphx::literal{5});
}

Paul's avatar
Paul committed
284
285
286
287
// Removing the trailing instruction leaves the sum (3) as the program's result.
TEST_CASE(remove_test1)
{
    migraphx::program prog;

    auto lit1  = prog.add_literal(1);
    auto lit2  = prog.add_literal(2);
    auto total = prog.add_instruction(sum_op{}, lit1, lit2);
    auto doomed = prog.add_instruction(minus_op{}, total, lit1);
    prog.remove_instruction(doomed);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{1});
}

// Removing an unused instruction in the middle keeps the program valid; the sum (3) is the result.
TEST_CASE(remove_test2)
{
    migraphx::program prog;

    auto lit1   = prog.add_literal(1);
    auto lit2   = prog.add_literal(2);
    auto doomed = prog.add_instruction(minus_op{}, lit2, lit1);
    prog.add_instruction(sum_op{}, lit1, lit2);
    prog.remove_instruction(doomed);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{1});
}

Paul's avatar
Paul committed
316
// Compiling against the pass-free id_target does not change the result: 1 + 2 = 3.
TEST_CASE(target_test)
{
    migraphx::program prog;

    auto lhs = prog.add_literal(1);
    auto rhs = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lhs, rhs);
    prog.compile(id_target{});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
329
// invert_target turns sum(2, 1) into minus(2, 1), so the result is 1.
TEST_CASE(invert_target_test)
{
    migraphx::program prog;

    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lit2, lit1);
    prog.compile(invert_target{});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{1});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
342
// Two inversions cancel out: sum(2, 1) stays a sum, so the result is 3.
TEST_CASE(double_invert_target_test)
{
    migraphx::program prog;

    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lit2, lit1);
    prog.compile(double_invert_target{});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
355
356
// Check that the program doesn't modify the context directly, and that only the operators modify
// the context.
Paul's avatar
Paul committed
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
// With only a plain sum_op, the program's context stays shared with the
// target's context through both compile and eval.
TEST_CASE(eval_context1)
{
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));

    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lit1, lit2);

    prog.compile(tgt);
    EXPECT(is_shared(tgt.ctx, prog.get_context()));

    prog.eval({});
    EXPECT(is_shared(tgt.ctx, prog.get_context()));
}

// id_ctx_op receives the context during compute. The context is still shared
// after compile, but no longer shared once eval has run.
TEST_CASE(eval_context2)
{
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));

    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(id_ctx_op{}, lit1, lit2);

    prog.compile(tgt);
    EXPECT(is_shared(tgt.ctx, prog.get_context()));

    prog.eval({});
    // Evaluating id_ctx_op modifies the context.
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
}

// id_ctx_final_op has a finalize hook. Compile already leaves the program with
// an unshared context, and eval then keeps sharing that post-compile context.
TEST_CASE(eval_context3)
{
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));

    auto lit1 = prog.add_literal(1);
    auto lit2 = prog.add_literal(2);
    prog.add_instruction(id_ctx_final_op{}, lit1, lit2);

    prog.compile(tgt);
    // The finalizer modifies the context.
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));

    auto compiled_ctx = prog.get_context();
    prog.eval({});
    EXPECT(is_shared(compiled_ctx, prog.get_context()));
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
}

Paul's avatar
Paul committed
403
// Entry point: forwards the command line to test::run, which presumably
// selects and runs the TEST_CASEs registered in this file.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}