"vscode:/vscode.git/clone" did not exist on "9a4e7e7f09dfd2ca6538d6e2af6d7310e521ff9f"
schedule_test.cpp 17.6 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/schedule.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <basic_ops.hpp>
#include <test.hpp>

// Test operator with a single logical input. It forwards its first argument
// (and first input shape) unchanged and declares that its output aliases
// input 0. With no inputs it yields a default-constructed argument/shape.
struct unary_op
{
    std::string name() const { return "unary"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    // The result shares storage with the first input.
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

Paul's avatar
Paul committed
31
struct nary_op
Paul's avatar
Paul committed
32
{
Paul's avatar
Paul committed
33
    std::string name() const { return "nary"; }
Paul's avatar
Paul committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(args.empty())
            return {};
        return args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
};

Paul's avatar
Paul committed
50
51
struct wait_event
{
Paul's avatar
Paul committed
52
53
    std::shared_ptr<std::vector<std::size_t>> wait_for =
        std::make_shared<std::vector<std::size_t>>();
Paul's avatar
Paul committed
54
55
56
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
57
        return migraphx::pack(f(*self.wait_for, "wait_for"));
Paul's avatar
Paul committed
58
59
60
61
    }
    std::string name() const { return "wait_event"; }
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }

Paul's avatar
Paul committed
62
63
64
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
65
    {
Paul's avatar
Paul committed
66
67
        assert(wait_for != nullptr);
        assert(not wait_for->empty());
Paul's avatar
Paul committed
68
69
70
71
        return {};
    }
};

Paul's avatar
Paul committed
72
// Instruction -> index map. schedule_model_test uses it to record the stream
// number that sched() assigns to each instruction.
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
Paul's avatar
Paul committed
73
74
// Instruction -> shared list of stream ids that instruction waits on.
// schedule_model_test stores here the same shared_ptr held by the wait_event
// it inserts, so appending to it also updates the event in the program.
using wait_map =
    std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
Paul's avatar
Paul committed
75
76
77

struct schedule_model_test
{
Paul's avatar
Paul committed
78
    std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
Paul's avatar
Paul committed
79
80
    std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream =
        std::make_shared<std::unordered_map<std::size_t, std::size_t>>();
Paul's avatar
Paul committed
81
    std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
Paul's avatar
Paul committed
82
    std::size_t concurrency() const { return 4; }
Paul's avatar
Paul committed
83
    void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
Paul's avatar
Paul committed
84
85
86
    {
        (*ins2stream)[ins] = n;
    }
Paul's avatar
Paul committed
87
    void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
Paul's avatar
Paul committed
88
    {
Paul's avatar
Paul committed
89
        if(ins2wait_for->count(ins) == 0)
Paul's avatar
Paul committed
90
91
92
93
94
95
96
        {
            auto event = wait_event{};
            p.insert_instruction(ins, event);
            (*ins2wait_for)[ins] = event.wait_for;
        }
        (*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
    }
Paul's avatar
Paul committed
97
    void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const
Paul's avatar
Paul committed
98
99
    {
        (*wait2stream)[wait_id] = ins2stream->at(ins);
Paul's avatar
Paul committed
100
101
102
    }
    std::size_t weight(const migraphx::operation& op) const
    {
Paul's avatar
Paul committed
103
        if(op.name() == "binary" or op.name() == "unary")
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
            return 4;
        else
            return 1;
    }
};

struct schedule_target
{
Paul's avatar
Paul committed
112
    schedule_model_test model{};
Paul's avatar
Paul committed
113
114
115
    std::string name() const { return "schedule"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
Paul's avatar
Paul committed
116
        return {migraphx::schedule{model}};
Paul's avatar
Paul committed
117
118
    }
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
119

Paul's avatar
Paul committed
120
    std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
    bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
Paul's avatar
Paul committed
123
124
125
126
};

// Returns true if some "identity" instruction in p takes both x and y among
// its inputs. The tests use this to detect the marker the schedule pass
// leaves for instructions it treats as conflicting (hedged: inferred from how
// the tests call it).
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto candidate : migraphx::iterator_for(p))
    {
        const bool is_identity = candidate->name() == "identity";
        if(is_identity and migraphx::contains(candidate->inputs(), x) and
           migraphx::contains(candidate->inputs(), y))
            return true;
    }
    return false;
}

Paul's avatar
Paul committed
140
void check_conflicts(migraphx::program& p,
Paul's avatar
Paul committed
141
142
                     std::vector<std::vector<migraphx::instruction_ref>> conflicts,
                     bool result = true)
Paul's avatar
Paul committed
143
144
{
    migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
Paul's avatar
Paul committed
145
        if(i == j)
Paul's avatar
Paul committed
146
            return;
Paul's avatar
Paul committed
147
148
        for(auto ins1 : conflicts[i])
            for(auto ins2 : conflicts[j])
Paul's avatar
Paul committed
149
                CHECK(check_conflicts(p, ins1, ins2) == result);
Paul's avatar
Paul committed
150
151
152
    });
}

Paul's avatar
Paul committed
153
// Returns a copy of x sorted in ascending order (by operator<).
template <class T>
std::vector<T> sorted(std::vector<T> x)
{
    auto result = std::move(x);
    std::sort(result.begin(), result.end());
    return result;
}

Paul's avatar
Paul committed
160
// Returns the distinct elements of x in ascending order. Comparing the result
// with sorted(x) tells whether x contained any duplicates.
template <class T>
std::vector<T> unique(std::vector<T> x)
{
    std::sort(x.begin(), x.end());
    const auto new_end = std::unique(x.begin(), x.end());
    x.erase(new_end, x.end());
    return x;
}

// Returns the given stream list sorted ascending, for order-insensitive
// comparison against a wait list read back from the program.
std::vector<std::size_t> get_wait_for(std::vector<std::size_t> wait_for)
{
    auto streams = std::move(wait_for);
    std::sort(streams.begin(), streams.end());
    return streams;
}

Paul's avatar
Paul committed
174
175
176
177
178
179
180
181
182
183
// Builds the expected wait list for an instruction on stream `wait_on`.
// One occurrence of `wait_on` is removed (the tests pass the instruction's own
// stream here, presumably because it never waits on itself), and the remaining
// ids are sorted for order-insensitive comparison.
// If `wait_on` does not occur, the list is only sorted.
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
    auto it = std::find(wait_for.begin(), wait_for.end(), wait_on);
    // Guard required: erase(end()) is undefined behaviour.
    if(it != wait_for.end())
        wait_for.erase(it);
    std::sort(wait_for.begin(), wait_for.end());
    return wait_for;
}

// Returns the sorted stream list of the wait_event placed before ins, stepping
// over any "identity" instructions in between. Returns an empty list if the
// nearest non-identity predecessor is not a wait_event.
// NOTE(review): assumes ins has a non-identity predecessor; stepping back
// from the program's first instruction would be undefined.
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
    auto prev_ins = ins;
    do
    {
        prev_ins = std::prev(prev_ins);
    } while(prev_ins->name() == "identity");

    if(prev_ins->name() == "wait_event")
    {
        auto streams = *migraphx::any_cast<wait_event>(prev_ins->get_operator()).wait_for;
        std::sort(streams.begin(), streams.end());
        return streams;
    }
    return {};
}

Paul's avatar
Paul committed
194
195
196
template <class T>
std::vector<migraphx::instruction_ref>
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
Paul's avatar
Paul committed
197
198
{
    std::vector<migraphx::instruction_ref> result;
Paul's avatar
Paul committed
199
    for(std::size_t i = 0; i < n; i++)
Paul's avatar
Paul committed
200
201
202
203
204
205
    {
        result.push_back(p.add_instruction(x, input));
        input = result.back();
    }
    return result;
}
Paul's avatar
Paul committed
206
207
// Two unary ops read the same literal and feed one nary op.
TEST_CASE(single_entry)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    auto merge = p.add_instruction(nary_op{}, left, right);
    p.compile(t);

    // The literal gets no stream; the two branches get different streams.
    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    EXPECT(t.get_stream(merge) == 0);
    // The merge waits on the branch streams other than its own.
    EXPECT(get_wait_for(merge) ==
           get_wait_for(t.get_stream(merge), {t.get_stream(left), t.get_stream(right)}));
    EXPECT(check_conflicts(p, left, right));
}

Paul's avatar
Paul committed
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
// Two unary branches pass through identity ops before an nary op merges them.
TEST_CASE(zero_record)
{
    schedule_target t{};
    migraphx::program p;
    auto one   = p.add_literal(1);
    auto onep1 = p.add_instruction(unary_op{}, one);
    auto onep2 = p.add_instruction(unary_op{}, one);
    // Create the identities in separate statements. The evaluation order of
    // function arguments is unspecified, so building them inline inside the
    // nary_op call made the program's instruction order compiler-dependent.
    auto id1    = p.add_instruction(migraphx::op::identity{}, onep1);
    auto id2    = p.add_instruction(migraphx::op::identity{}, onep2);
    auto binary = p.add_instruction(nary_op{}, id1, id2);
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    EXPECT(t.has_stream(binary));
    EXPECT(get_wait_for(binary) ==
           get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
242
243
// Two unary branches merged directly by an identity op.
TEST_CASE(zero_merge1)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    auto merge = p.add_instruction(migraphx::op::identity{}, left, right);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    // The identity merge is not assigned a stream...
    EXPECT(not t.has_stream(merge));
    // ...and no wait_event is placed before it.
    EXPECT(get_wait_for(merge).empty());
    EXPECT(check_conflicts(p, left, right));
}

// Two unary branches, each wrapped in an identity, merged by a further identity.
TEST_CASE(zero_merge2)
{
    schedule_target t{};
    migraphx::program p;
    auto one   = p.add_literal(1);
    auto onep1 = p.add_instruction(unary_op{}, one);
    auto onep2 = p.add_instruction(unary_op{}, one);
    // Create the inner identities in separate statements. Function-argument
    // evaluation order is unspecified, so creating them inline inside the
    // outer add_instruction call made the instruction order compiler-dependent.
    auto id1    = p.add_instruction(migraphx::op::identity{}, onep1);
    auto id2    = p.add_instruction(migraphx::op::identity{}, onep2);
    auto binary = p.add_instruction(migraphx::op::identity{}, id1, id2);
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    // No stream assignment
    EXPECT(not t.has_stream(binary));
    // There is no wait
    EXPECT(get_wait_for(binary).empty());
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
280
281
282
283
// Two unary branches merged by an identity, followed by one more unary op.
TEST_CASE(zero_merge3)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    auto join  = p.add_instruction(migraphx::op::identity{}, left, right);
    auto tail  = p.add_instruction(unary_op{}, join);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    // The identity join gets neither a stream nor a wait.
    EXPECT(not t.has_stream(join));
    EXPECT(get_wait_for(join).empty());
    // The consumer after the join is scheduled and carries the wait instead.
    EXPECT(t.get_stream(tail) == 0);
    EXPECT(get_wait_for(tail) ==
           get_wait_for(t.get_stream(tail), {t.get_stream(left), t.get_stream(right)}));
    EXPECT(check_conflicts(p, left, right));
}

// Like zero_merge3, but each branch first passes through its own identity.
TEST_CASE(zero_merge4)
{
    schedule_target t{};
    migraphx::program p;
    auto one   = p.add_literal(1);
    auto onep1 = p.add_instruction(unary_op{}, one);
    auto onep2 = p.add_instruction(unary_op{}, one);
    // Create the inner identities in separate statements. Function-argument
    // evaluation order is unspecified, so creating them inline inside the
    // outer add_instruction call made the instruction order compiler-dependent.
    auto inner1 = p.add_instruction(migraphx::op::identity{}, onep1);
    auto inner2 = p.add_instruction(migraphx::op::identity{}, onep2);
    auto id     = p.add_instruction(migraphx::op::identity{}, inner1, inner2);
    auto final  = p.add_instruction(unary_op{}, id);
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    // No stream assignment
    EXPECT(not t.has_stream(id));
    // There is no wait
    EXPECT(get_wait_for(id).empty());
    // Stream assignment for final op
    EXPECT(t.get_stream(final) == 0);
    EXPECT(get_wait_for(final) ==
           get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
328
// Two unary ops, each reading its own literal, merged by one nary op.
TEST_CASE(double_entry)
{
    schedule_target t{};
    migraphx::program p;
    auto lit_a = p.add_literal(1);
    auto lit_b = p.add_literal(2);
    auto op_a  = p.add_instruction(unary_op{}, lit_a);
    auto op_b  = p.add_instruction(unary_op{}, lit_b);
    auto merge = p.add_instruction(nary_op{}, op_a, op_b);
    p.compile(t);

    // Neither literal is scheduled.
    EXPECT(not t.has_stream(lit_a));
    EXPECT(not t.has_stream(lit_b));
    EXPECT(t.get_stream(op_a) != t.get_stream(op_b));
    EXPECT(t.get_stream(merge) == 0);
    EXPECT(get_wait_for(merge) ==
           get_wait_for(t.get_stream(merge), {t.get_stream(op_a), t.get_stream(op_b)}));
    EXPECT(check_conflicts(p, op_a, op_b));
}

Paul's avatar
Paul committed
347
// A two-op chain and a single op branch off one literal and meet at an nary op.
TEST_CASE(two_branches)
{
    schedule_target t{};
    migraphx::program p;
    auto lit    = p.add_literal(1);
    auto path   = chain(p, 2, unary_op{}, lit);
    auto single = p.add_instruction(unary_op{}, lit);
    auto merge  = p.add_instruction(nary_op{}, single, path.back());
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    // The short branch goes to stream 1; the whole chain stays on stream 0.
    EXPECT(t.get_stream(single) == 1);
    for(auto link : path)
    {
        EXPECT(t.get_stream(link) == 0);
    }
    EXPECT(t.get_stream(merge) == 0);
    EXPECT(get_wait_for(merge) ==
           get_wait_for(t.get_stream(merge), {t.get_stream(path.back()), t.get_stream(single)}));
    check_conflicts(p, {path, {single}});
}

Paul's avatar
Paul committed
366
// Chains of length 4, 3 and 2 plus a single op, all from one literal,
// merged by one nary op.
TEST_CASE(four_branches)
{
    schedule_target t{};
    migraphx::program p;
    auto lit    = p.add_literal(1);
    auto path1  = chain(p, 4, unary_op{}, lit);
    auto path2  = chain(p, 3, unary_op{}, lit);
    auto path3  = chain(p, 2, unary_op{}, lit);
    auto single = p.add_instruction(unary_op{}, lit);
    auto merge  = p.add_instruction(nary_op{}, single, path1.back(), path2.back(), path3.back());
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(single) == 3);
    // Each chain stays on one stream, longest chain first.
    for(auto link : path1)
    {
        EXPECT(t.get_stream(link) == 0);
    }
    for(auto link : path2)
    {
        EXPECT(t.get_stream(link) == 1);
    }
    for(auto link : path3)
    {
        EXPECT(t.get_stream(link) == 2);
    }
    EXPECT(t.get_stream(merge) == 0);
    EXPECT(get_wait_for(merge) == get_wait_for(t.get_stream(merge),
                                               {t.get_stream(path1.back()),
                                                t.get_stream(path2.back()),
                                                t.get_stream(path3.back()),
                                                t.get_stream(single)}));
    check_conflicts(p, {path1, path2, path3, {single}});
}

Paul's avatar
Paul committed
394
// Five branches (chains of length 5, 4, 3, 2 and a single op) but only four
// streams, so two branches must share a stream.
TEST_CASE(five_branches)
{
    schedule_target t{};
    migraphx::program p;
    auto lit    = p.add_literal(1);
    auto path1  = chain(p, 5, unary_op{}, lit);
    auto path2  = chain(p, 4, unary_op{}, lit);
    auto path3  = chain(p, 3, unary_op{}, lit);
    auto path4  = chain(p, 2, unary_op{}, lit);
    auto single = p.add_instruction(unary_op{}, lit);
    auto merge =
        p.add_instruction(nary_op{}, single, path1.back(), path2.back(), path3.back(), path4.back());
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(single) == 3);
    for(auto link : path1)
    {
        EXPECT(t.get_stream(link) == 0);
    }
    for(auto link : path2)
    {
        EXPECT(t.get_stream(link) == 1);
    }
    for(auto link : path3)
    {
        EXPECT(t.get_stream(link) == 2);
    }
    for(auto link : path4)
    {
        EXPECT(t.get_stream(link) == 3);
    }
    EXPECT(t.get_stream(merge) == 0);
    // path4 and `single` share stream 3 (asserted above), so stream 3 is
    // listed once, via `single`.
    EXPECT(get_wait_for(merge) == get_wait_for(t.get_stream(merge),
                                               {t.get_stream(path1.back()),
                                                t.get_stream(path2.back()),
                                                t.get_stream(path3.back()),
                                                t.get_stream(single)}));
    check_conflicts(p, {path1, path2, path3, path4});
    check_conflicts(p, {path1, path2, path3, {single}});
}

Paul's avatar
Paul committed
426
427
// Four equal-cost unary ops off one literal, merged by one nary op: each
// should get its own stream.
TEST_CASE(four_branches_eq)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto op1   = p.add_instruction(unary_op{}, lit);
    auto op2   = p.add_instruction(unary_op{}, lit);
    auto op3   = p.add_instruction(unary_op{}, lit);
    auto op4   = p.add_instruction(unary_op{}, lit);
    auto merge = p.add_instruction(nary_op{}, op1, op2, op3, op4);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    const std::vector<std::size_t> streams = {
        t.get_stream(op1), t.get_stream(op2), t.get_stream(op3), t.get_stream(op4)};
    // sorted == unique  <=>  no two branches share a stream.
    EXPECT(sorted(streams) == unique(streams));
    EXPECT(t.get_stream(merge) == 0);
    EXPECT(get_wait_for(merge) == get_wait_for(t.get_stream(merge), streams));
    check_conflicts(p, {{op1}, {op2}, {op3}, {op4}});
}

Paul's avatar
Paul committed
452
453
// Two fork/join diamonds in sequence: the second diamond starts from the
// first one's merge.
TEST_CASE(seq_merge)
{
    schedule_target t{};
    migraphx::program p;
    auto lit = p.add_literal(1);

    // First diamond.
    auto path_a  = chain(p, 2, unary_op{}, lit);
    auto short_a = p.add_instruction(unary_op{}, lit);
    auto merge_a = p.add_instruction(nary_op{}, short_a, path_a.back());

    // Second diamond, fed by the first merge.
    auto path_b  = chain(p, 2, unary_op{}, merge_a);
    auto short_b = p.add_instruction(unary_op{}, merge_a);
    auto merge_b = p.add_instruction(nary_op{}, short_b, path_b.back());

    p.compile(t);
    EXPECT(not t.has_stream(lit));

    // First diamond.
    EXPECT(t.get_stream(short_a) == 2);
    for(auto link : path_a)
    {
        EXPECT(t.get_stream(link) == 3);
    }
    EXPECT(t.get_stream(merge_a) == 3);
    EXPECT(get_wait_for(merge_a) ==
           get_wait_for(t.get_stream(merge_a), {t.get_stream(path_a.back()), t.get_stream(short_a)}));
    check_conflicts(p, {path_a, {short_a}});

    // Second diamond.
    EXPECT(t.get_stream(short_b) == 3);
    for(auto link : path_b)
    {
        EXPECT(t.get_stream(link) == 0);
    }
    EXPECT(t.get_stream(merge_b) == 0);
    EXPECT(get_wait_for(merge_b) ==
           get_wait_for(t.get_stream(merge_b), {t.get_stream(path_b.back()), t.get_stream(short_b)}));
    check_conflicts(p, {path_b, {short_b}});
}

// Two fork/join diamonds side by side off one literal; a final nary op
// joins both diamonds' merges.
TEST_CASE(par_merge)
{
    schedule_target t{};
    migraphx::program p;
    auto lit = p.add_literal(1);

    // Left diamond.
    auto root_a  = p.add_instruction(unary_op{}, lit);
    auto path_a  = chain(p, 3, unary_op{}, root_a);
    auto short_a = p.add_instruction(unary_op{}, root_a);
    auto merge_a = p.add_instruction(nary_op{}, short_a, path_a.back());

    // Right diamond.
    auto root_b  = p.add_instruction(unary_op{}, lit);
    auto path_b  = chain(p, 2, unary_op{}, root_b);
    auto short_b = p.add_instruction(unary_op{}, root_b);
    auto merge_b = p.add_instruction(nary_op{}, short_b, path_b.back());

    auto final_merge = p.add_instruction(nary_op{}, merge_a, merge_b);

    p.compile(t);
    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(final_merge) == 0);

    // Left diamond.
    EXPECT(t.get_stream(short_a) == 2);
    for(auto link : path_a)
    {
        EXPECT(t.get_stream(link) == 0);
    }
    EXPECT(t.get_stream(merge_a) == 0);
    EXPECT(get_wait_for(merge_a) ==
           get_wait_for(t.get_stream(merge_a), {t.get_stream(path_a.back()), t.get_stream(short_a)}));
    check_conflicts(p, {path_a, {short_a}});

    // Right diamond.
    EXPECT(t.get_stream(short_b) == 1);
    for(auto link : path_b)
    {
        EXPECT(t.get_stream(link) == 3);
    }
    EXPECT(t.get_stream(merge_b) == 3);
    EXPECT(get_wait_for(merge_b) ==
           get_wait_for(t.get_stream(merge_b), {t.get_stream(path_b.back()), t.get_stream(short_b)}));
    check_conflicts(p, {path_b, {short_b}});

    // The two diamond merges also conflict with each other.
    EXPECT(check_conflicts(p, merge_a, merge_b));
}
Paul's avatar
Paul committed
524
// Entry point: hand the command line to the test driver from test.hpp.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}