eval_test.cpp 8.57 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
5
#include <sstream>
Paul's avatar
Paul committed
6
#include "test.hpp"
Paul's avatar
Paul committed
7
#include <basic_ops.hpp>
Paul's avatar
Paul committed
8

Paul's avatar
Paul committed
9
10
struct id_target
{
Paul's avatar
Paul committed
11
12
13
    struct context
    {
        void finish() const {}
mei-ye's avatar
mei-ye committed
14
15
16
17
        void set_stream(int) {}
        void create_events(int) {}
        void record_event(int) {}
        void wait_event(int) {}
Paul's avatar
Paul committed
18
19
    };
    migraphx::context ctx = context{};
Paul's avatar
Paul committed
20
    std::string name() const { return "id"; }
Paul's avatar
Paul committed
21
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {}; }
Paul's avatar
Paul committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    migraphx::context get_context() const { return ctx; }
};

// Pass-through operator whose compute() takes a mutable id_target::context&.
// eval_context2 expects evaluating it to leave the program's context no longer shared
// with the target's.
struct id_ctx_op
{
    std::string name() const { return "id_ctx_op"; }

    // Returns the first argument unchanged, or an empty argument when there are none.
    migraphx::argument
    compute(id_target::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(not args.empty())
            return args.front();
        return {};
    }

    // Output shape mirrors the first input (default shape when there are no inputs).
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(not inputs.empty())
            return inputs.front();
        return {};
    }

    // Reports input 0 as the aliased input (presumably: output aliases input 0).
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

struct id_ctx_final_op
{
    std::string name() const { return "id_ctx_final_op"; }
Paul's avatar
Paul committed
48
    migraphx::argument compute(const migraphx::shape&, std::vector<migraphx::argument> args) const
Paul's avatar
Paul committed
49
50
51
52
53
54
    {
        if(args.empty())
            return {};
        return args.front();
    }

Paul's avatar
Paul committed
55
56
57
    void finalize(id_target::context&, const migraphx::shape&, const std::vector<migraphx::shape>&)
    {
    }
Paul's avatar
Paul committed
58
59
60
61
62
63
64
65

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
Paul's avatar
Paul committed
66
67
};

Paul's avatar
Paul committed
68
69
struct reverse_pass
{
Paul's avatar
Paul committed
70
    std::string name() const { return "reverse_pass"; }
Paul's avatar
Paul committed
71

Paul's avatar
Paul committed
72
    void apply(migraphx::program& p) const
Paul's avatar
Paul committed
73
    {
Paul's avatar
Paul committed
74
        for(auto ins : migraphx::iterator_for(p))
Paul's avatar
Paul committed
75
        {
76
            if(ins->name() == "sum")
Paul's avatar
Paul committed
77
            {
78
                p.replace_instruction(ins, minus_op{}, ins->inputs());
Paul's avatar
Paul committed
79
            }
80
            else if(ins->name() == "minus")
Paul's avatar
Paul committed
81
            {
82
                p.replace_instruction(ins, sum_op{}, ins->inputs());
Paul's avatar
Paul committed
83
84
85
86
87
88
89
90
            }
        }
    }
};

struct reverse_target
{
    std::string name() const { return "reverse"; }
Paul's avatar
Paul committed
91
92
    std::vector<migraphx::pass> get_passes(migraphx::context&) const { return {reverse_pass{}}; }
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
93
94
95
96
97
};

struct double_reverse_target
{
    std::string name() const { return "double_reverse"; }
Paul's avatar
Paul committed
98
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
Paul's avatar
Paul committed
99
100
    {
        return {reverse_pass{}, reverse_pass{}};
Paul's avatar
Paul committed
101
    }
Paul's avatar
Paul committed
102
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
103
104
};

Paul's avatar
Paul committed
105
TEST_CASE(literal_test1)
{
    // A sum built purely from literals: 1 + 2 evaluates to 3.
    migraphx::program prog;

    auto lit_one = prog.add_literal(1);
    auto lit_two = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lit_one, lit_two);

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
117
TEST_CASE(literal_test2)
{
    // Chained sums: (1 + 2) + 2 == 5. eval returns the last instruction's value,
    // not the intermediate 3.
    migraphx::program prog;

    auto lit_one   = prog.add_literal(1);
    auto lit_two   = prog.add_literal(2);
    auto first_sum = prog.add_instruction(sum_op{}, lit_one, lit_two);
    prog.add_instruction(sum_op{}, first_sum, lit_two);

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{5});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
131
TEST_CASE(print_test)
{
    // Streaming a program through operator<< produces non-empty text.
    migraphx::program prog;

    auto param_x = prog.add_parameter("x", {migraphx::shape::int64_type});
    auto lit_two = prog.add_literal(2);
    prog.add_instruction(sum_op{}, param_x, lit_two);

    std::stringstream buffer;
    buffer << prog;
    EXPECT(!buffer.str().empty());
}

Paul's avatar
Paul committed
145
TEST_CASE(param_test)
{
    // Parameters are bound by name at eval time: x = 1, y = 2 gives x + y == 3.
    migraphx::program prog;

    auto param_x = prog.add_parameter("x", {migraphx::shape::int64_type});
    auto param_y = prog.add_parameter("y", {migraphx::shape::int64_type});
    prog.add_instruction(sum_op{}, param_x, param_y);

    auto out = prog.eval({{"x", migraphx::literal{1}.get_argument()},
                          {"y", migraphx::literal{2}.get_argument()}});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
159
TEST_CASE(param_error_test)
{
    // Evaluating without binding every parameter must throw migraphx::exception and
    // name the missing parameter in its message.
    migraphx::program prog;

    auto param_x = prog.add_parameter("x", {migraphx::shape::int64_type});
    auto param_y = prog.add_parameter("y", {migraphx::shape::int64_type});
    prog.add_instruction(sum_op{}, param_x, param_y);

    // Only "x" is supplied; "y" is deliberately left unbound.
    auto eval_without_y = [&] { prog.eval({{"x", migraphx::literal{1}.get_argument()}}); };
    EXPECT(test::throws<migraphx::exception>(eval_without_y, "Parameter not found: y"));
}

Paul's avatar
Paul committed
174
TEST_CASE(replace_test)
{
    // Replacing an instruction with a new operator and inputs: 1 + 2 becomes 2 - 1.
    // The program must still validate, and the old value 3 must be gone.
    migraphx::program prog;

    auto lit_one = prog.add_literal(1);
    auto lit_two = prog.add_literal(2);
    auto summed  = prog.add_instruction(sum_op{}, lit_one, lit_two);
    prog.replace_instruction(summed, minus_op{}, lit_two, lit_one);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{1});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
189
TEST_CASE(replace_ins_test)
{
    // Replacing the sum with an existing minus instruction: the program still validates
    // and evaluates to 2 - 1 rather than 1 + 2.
    migraphx::program prog;

    auto lit_one    = prog.add_literal(1);
    auto lit_two    = prog.add_literal(2);
    auto summed     = prog.add_instruction(sum_op{}, lit_one, lit_two);
    auto difference = prog.add_instruction(minus_op{}, lit_two, lit_one);
    prog.replace_instruction(summed, difference);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{1});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
205
TEST_CASE(replace_ins_test2)
{
    // Replacing the literal 2 with the sum (1 + 2) rewires its users, so the
    // pass-through of the difference evaluates to (1 + 2) - 1 == 2.
    migraphx::program prog;

    auto lit_one    = prog.add_literal(1);
    auto lit_two    = prog.add_literal(2);
    auto summed     = prog.add_instruction(sum_op{}, lit_one, lit_two);
    auto difference = prog.add_instruction(minus_op{}, lit_two, lit_one);
    prog.add_instruction(pass_op{}, difference);
    prog.replace_instruction(lit_two, summed);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{2});
    EXPECT(out != migraphx::literal{3});
}

Paul's avatar
Paul committed
222
TEST_CASE(insert_replace_test)
{
    // A new 2 + 2 is inserted before the first sum, which is then replaced by
    // (2 + 2) - 2. The final sum becomes 2 + 2 == 4 instead of (1 + 2) + 2 == 5.
    migraphx::program prog;

    auto lit_one   = prog.add_literal(1);
    auto lit_two   = prog.add_literal(2);
    auto first_sum = prog.add_instruction(sum_op{}, lit_one, lit_two);
    prog.add_instruction(sum_op{}, first_sum, lit_two);

    auto inserted_sum = prog.insert_instruction(first_sum, sum_op{}, lit_two, lit_two);
    prog.replace_instruction(first_sum, minus_op{}, inserted_sum, lit_two);
    EXPECT(bool{prog.validate() == prog.end()});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{4});
    EXPECT(out != migraphx::literal{5});
}

Paul's avatar
Paul committed
240
TEST_CASE(target_test)
{
    // Compiling with the pass-free id_target must not change the result of 1 + 2.
    migraphx::program prog;

    auto lit_one = prog.add_literal(1);
    auto lit_two = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lit_one, lit_two);
    prog.compile(id_target{});

    auto out = prog.eval({});
    EXPECT(out == migraphx::literal{3});
    EXPECT(out != migraphx::literal{4});
}

Paul's avatar
Paul committed
253
TEST_CASE(reverse_target_test)
{
    // reverse_target's single reverse_pass rewrites the sum into a minus, so the
    // program built as 2 + 1 evaluates as 2 - 1.
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(reverse_target{});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{1});
    // 3 is what the program would produce had the pass not run; check against that
    // (like replace_test does) rather than an arbitrary value.
    EXPECT(result != migraphx::literal{3});
}

Paul's avatar
Paul committed
266
TEST_CASE(double_reverse_target_test)
{
    // Two reverse_passes cancel out: sum -> minus -> sum, so 2 + 1 still gives 3.
    migraphx::program p;

    auto one = p.add_literal(1);
    auto two = p.add_literal(2);
    p.add_instruction(sum_op{}, two, one);
    p.compile(double_reverse_target{});

    auto result = p.eval({});
    EXPECT(result == migraphx::literal{3});
    // 1 is what a single (uncancelled) reverse would produce; check against that
    // failure mode rather than an arbitrary value.
    EXPECT(result != migraphx::literal{1});
}

Paul's avatar
Paul committed
279
280
// Check that the program doesn't modify the context directly; only the operators modify the
// context.
Paul's avatar
Paul committed
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
TEST_CASE(eval_context1)
{
    // With only a sum_op in the program, the program's context stays shared with the
    // target's ctx through both compile and eval.
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));

    auto lit_one = prog.add_literal(1);
    auto lit_two = prog.add_literal(2);
    prog.add_instruction(sum_op{}, lit_one, lit_two);

    prog.compile(tgt);
    EXPECT(is_shared(tgt.ctx, prog.get_context()));

    prog.eval({});
    EXPECT(is_shared(tgt.ctx, prog.get_context()));
}

TEST_CASE(eval_context2)
{
    // Compiling keeps the context shared; evaluating id_ctx_op, whose compute() takes a
    // mutable context, leaves the program's context no longer shared with the target's.
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));

    auto lit_one = prog.add_literal(1);
    auto lit_two = prog.add_literal(2);
    prog.add_instruction(id_ctx_op{}, lit_one, lit_two);

    prog.compile(tgt);
    EXPECT(is_shared(tgt.ctx, prog.get_context()));

    prog.eval({});
    // id_ctx_op will modify the context
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
}

TEST_CASE(eval_context3)
{
    // id_ctx_final_op touches the context only via finalize(). The split happens at
    // compile time, and eval keeps the program's (already separate) context shared with
    // the copy taken after compile.
    migraphx::program prog;
    id_target tgt{};
    EXPECT(is_shared(tgt.ctx, tgt.get_context()));

    auto lit_one = prog.add_literal(1);
    auto lit_two = prog.add_literal(2);
    prog.add_instruction(id_ctx_final_op{}, lit_one, lit_two);

    prog.compile(tgt);
    // Finalizer will modify the context
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));

    auto compiled_ctx = prog.get_context();
    prog.eval({});
    EXPECT(is_shared(compiled_ctx, prog.get_context()));
    EXPECT(not is_shared(tgt.ctx, prog.get_context()));
}

Paul's avatar
Paul committed
327
int main(int argc, const char* argv[])
{
    // Delegate to the harness from test.hpp, forwarding the command line unchanged.
    test::run(argc, argv);
}