schedule_test.cpp 17.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
#include <migraphx/schedule.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include <basic_ops.hpp>
#include <test.hpp>

// Test-only single-input operator named "unary". It forwards its first input:
// compute() returns the first argument and compute_shape() the first input
// shape; with no inputs both return a default-constructed value.
struct unary_op
{
    std::string name() const { return "unary"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    // Declares the output as aliasing input 0.
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

Paul's avatar
Paul committed
31
struct nary_op
Paul's avatar
Paul committed
32
{
Paul's avatar
Paul committed
33
    std::string name() const { return "nary"; }
Paul's avatar
Paul committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(args.empty())
            return {};
        return args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
};

Paul's avatar
Paul committed
50
51
struct wait_event
{
Paul's avatar
Paul committed
52
53
    std::shared_ptr<std::vector<std::size_t>> wait_for =
        std::make_shared<std::vector<std::size_t>>();
Paul's avatar
Paul committed
54
55
56
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
57
        return migraphx::pack(f(*self.wait_for, "wait_for"));
Paul's avatar
Paul committed
58
59
60
61
    }
    std::string name() const { return "wait_event"; }
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }

Paul's avatar
Paul committed
62
63
64
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
65
    {
Paul's avatar
Paul committed
66
67
        assert(wait_for != nullptr);
        assert(not wait_for->empty());
Paul's avatar
Paul committed
68
69
70
71
        return {};
    }
};

Paul's avatar
Paul committed
72
// Instruction -> stream index chosen for it by the schedule pass.
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;

// Instruction -> the wait list shared with the wait_event inserted before it.
using wait_map =
    std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
Paul's avatar
Paul committed
75
76
77

struct schedule_model_test
{
Paul's avatar
Paul committed
78
    std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
Paul's avatar
Paul committed
79
80
    std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream =
        std::make_shared<std::unordered_map<std::size_t, std::size_t>>();
Paul's avatar
Paul committed
81
    std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
Paul's avatar
Paul committed
82
    std::size_t concurrency() const { return 4; }
Paul's avatar
Paul committed
83
    void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
Paul's avatar
Paul committed
84
85
86
    {
        (*ins2stream)[ins] = n;
    }
Paul's avatar
Paul committed
87
    void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
Paul's avatar
Paul committed
88
    {
Paul's avatar
Paul committed
89
        if(ins2wait_for->count(ins) == 0)
Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
98
99
        {
            auto event = wait_event{};
            p.insert_instruction(ins, event);
            (*ins2wait_for)[ins] = event.wait_for;
        }
        (*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
    }
    void record(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
    {
        (*wait2stream)[wait_id] = ins2stream->at(ins);
Paul's avatar
Paul committed
100
101
102
    }
    std::size_t weight(const migraphx::operation& op) const
    {
Paul's avatar
Paul committed
103
        if(op.name() == "binary" or op.name() == "unary")
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
            return 4;
        else
            return 1;
    }
};

struct schedule_target
{
Paul's avatar
Paul committed
112
    schedule_model_test model{};
Paul's avatar
Paul committed
113
114
115
    std::string name() const { return "schedule"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
Paul's avatar
Paul committed
116
        return {migraphx::schedule{model}};
Paul's avatar
Paul committed
117
118
    }
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
119

Paul's avatar
Paul committed
120
    std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
    bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
Paul's avatar
Paul committed
123
124
125
126
};

// Returns true when some "identity" instruction in `p` takes both `x` and `y`
// among its inputs. The tests treat such an identity as evidence that the two
// instructions were marked as conflicting (presumably inserted by the schedule
// pass; that code is not visible here).
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto ins : migraphx::iterator_for(p))
    {
        if(ins->name() != "identity")
            continue;
        const auto& args = ins->inputs();
        if(migraphx::contains(args, x) and migraphx::contains(args, y))
            return true;
    }
    return false;
}

Paul's avatar
Paul committed
140
void check_conflicts(migraphx::program& p,
Paul's avatar
Paul committed
141
142
                     std::vector<std::vector<migraphx::instruction_ref>> conflicts,
                     bool result = true)
Paul's avatar
Paul committed
143
144
{
    migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
Paul's avatar
Paul committed
145
        if(i == j)
Paul's avatar
Paul committed
146
            return;
Paul's avatar
Paul committed
147
148
        for(auto ins1 : conflicts[i])
            for(auto ins2 : conflicts[j])
Paul's avatar
Paul committed
149
                CHECK(check_conflicts(p, ins1, ins2) == result);
Paul's avatar
Paul committed
150
151
152
    });
}

Paul's avatar
Paul committed
153
// Returns a copy of `values` sorted ascending with operator<.
template <class T>
std::vector<T> sorted(std::vector<T> values)
{
    auto first = values.begin();
    auto last  = values.end();
    std::sort(first, last);
    return values;
}

Paul's avatar
Paul committed
160
// Returns the distinct elements of `values` in ascending order. Sorting first
// makes duplicates adjacent so std::unique can collapse them.
template <class T>
std::vector<T> unique(std::vector<T> values)
{
    std::sort(values.begin(), values.end());
    const auto new_end = std::unique(values.begin(), values.end());
    values.erase(new_end, values.end());
    return values;
}

// Returns the given stream list sorted ascending, so wait lists can be compared
// regardless of the order their entries were added.
std::vector<std::size_t> get_wait_for(std::vector<std::size_t> wait_for)
{
    auto first = wait_for.begin();
    auto last  = wait_for.end();
    std::sort(first, last);
    return wait_for;
}

Paul's avatar
Paul committed
174
175
176
177
178
179
180
181
182
183
// Builds the expected wait list for an instruction on stream `wait_on` whose
// inputs ran on the streams in `wait_for`. The first occurrence of `wait_on`
// is removed (the instruction is presumably not expected to wait on its own
// stream). The rest is sorted so it compares equal to
// get_wait_for(instruction_ref).
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
    // std::find returns end() when `wait_on` is absent, and erasing end() is
    // undefined behavior, so only erase an element that was actually found.
    auto own = std::find(wait_for.begin(), wait_for.end(), wait_on);
    if(own != wait_for.end())
        wait_for.erase(own);
    std::sort(wait_for.begin(), wait_for.end());
    return wait_for;
}

// Returns the sorted stream list carried by the wait_event immediately before
// `ins`, or an empty list when the preceding instruction is not a wait_event.
// NOTE(review): std::prev(ins) assumes `ins` is not the first instruction of
// its program.
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
    const auto before = std::prev(ins);
    if(before->name() == "wait_event")
    {
        auto streams = *migraphx::any_cast<wait_event>(before->get_operator()).wait_for;
        std::sort(streams.begin(), streams.end());
        return streams;
    }
    return {};
}

Paul's avatar
Paul committed
191
192
193
template <class T>
std::vector<migraphx::instruction_ref>
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
Paul's avatar
Paul committed
194
195
{
    std::vector<migraphx::instruction_ref> result;
Paul's avatar
Paul committed
196
    for(std::size_t i = 0; i < n; i++)
Paul's avatar
Paul committed
197
198
199
200
201
202
    {
        result.push_back(p.add_instruction(x, input));
        input = result.back();
    }
    return result;
}
Paul's avatar
Paul committed
203
204
// One literal fans out to two unary ops that are joined by a single nary op.
TEST_CASE(single_entry)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    auto join  = p.add_instruction(nary_op{}, left, right);
    p.compile(t);

    // Literals get no stream; the two branches run on different streams and
    // the join runs on stream 0.
    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    EXPECT(t.get_stream(join) == 0);
    // Expected wait list: the branch streams other than the join's own.
    EXPECT(get_wait_for(join) ==
           get_wait_for(t.get_stream(join), {t.get_stream(left), t.get_stream(right)}));
    EXPECT(check_conflicts(p, left, right));
}

Paul's avatar
Paul committed
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
// Like single_entry, but each branch passes through an identity before the join.
TEST_CASE(zero_record)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    // NOTE: C++ leaves argument evaluation order unspecified, so the insertion
    // order of these two identities is compiler-dependent. Hoisting them into
    // named variables would change the program being scheduled.
    auto join = p.add_instruction(nary_op{},
                                  p.add_instruction(migraphx::op::identity{}, left),
                                  p.add_instruction(migraphx::op::identity{}, right));
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    EXPECT(t.has_stream(join));
    EXPECT(get_wait_for(join) ==
           get_wait_for(t.get_stream(join), {t.get_stream(left), t.get_stream(right)}));
    EXPECT(check_conflicts(p, left, right));
}

Paul's avatar
Paul committed
239
240
// Two branches merged by an identity instead of a real op.
TEST_CASE(zero_merge1)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    auto merge = p.add_instruction(migraphx::op::identity{}, left, right);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    // The identity merge is not assigned a stream...
    EXPECT(not t.has_stream(merge));
    // ...and no wait_event precedes it.
    EXPECT(get_wait_for(merge).empty());
    EXPECT(check_conflicts(p, left, right));
}

// Like zero_merge1, but each branch first passes through its own identity.
TEST_CASE(zero_merge2)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    // NOTE: the inner identities are created inside the argument list, whose
    // evaluation order is unspecified; keep them there to preserve the
    // original insertion order.
    auto merge = p.add_instruction(migraphx::op::identity{},
                                   p.add_instruction(migraphx::op::identity{}, left),
                                   p.add_instruction(migraphx::op::identity{}, right));
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    // The identity merge is not assigned a stream...
    EXPECT(not t.has_stream(merge));
    // ...and no wait_event precedes it.
    EXPECT(get_wait_for(merge).empty());
    EXPECT(check_conflicts(p, left, right));
}

Paul's avatar
Paul committed
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
// zero_merge1 followed by a real unary op that consumes the identity merge.
TEST_CASE(zero_merge3)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    auto merge = p.add_instruction(migraphx::op::identity{}, left, right);
    auto sink  = p.add_instruction(unary_op{}, merge);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    // The identity merge gets neither a stream nor a wait_event.
    EXPECT(not t.has_stream(merge));
    EXPECT(get_wait_for(merge).empty());
    // The consumer after the merge is scheduled on stream 0 and waits there.
    EXPECT(t.get_stream(sink) == 0);
    EXPECT(get_wait_for(sink) ==
           get_wait_for(t.get_stream(sink), {t.get_stream(left), t.get_stream(right)}));
    EXPECT(check_conflicts(p, left, right));
}

// zero_merge2 followed by a real unary op that consumes the identity merge.
TEST_CASE(zero_merge4)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto left  = p.add_instruction(unary_op{}, lit);
    auto right = p.add_instruction(unary_op{}, lit);
    // NOTE: keep the inner identities inside the argument list; their
    // evaluation (and hence insertion) order is unspecified by the language.
    auto merge = p.add_instruction(migraphx::op::identity{},
                                   p.add_instruction(migraphx::op::identity{}, left),
                                   p.add_instruction(migraphx::op::identity{}, right));
    auto sink  = p.add_instruction(unary_op{}, merge);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(left) != t.get_stream(right));
    // The identity merge gets neither a stream nor a wait_event.
    EXPECT(not t.has_stream(merge));
    EXPECT(get_wait_for(merge).empty());
    // The consumer after the merge is scheduled on stream 0 and waits there.
    EXPECT(t.get_stream(sink) == 0);
    EXPECT(get_wait_for(sink) ==
           get_wait_for(t.get_stream(sink), {t.get_stream(left), t.get_stream(right)}));
    EXPECT(check_conflicts(p, left, right));
}

Paul's avatar
Paul committed
323
// Two separate literals, each feeding one unary op, joined by a nary op.
TEST_CASE(double_entry)
{
    schedule_target t{};
    migraphx::program p;
    auto lit_one  = p.add_literal(1);
    auto lit_two  = p.add_literal(2);
    auto from_one = p.add_instruction(unary_op{}, lit_one);
    auto from_two = p.add_instruction(unary_op{}, lit_two);
    auto join     = p.add_instruction(nary_op{}, from_one, from_two);
    p.compile(t);

    EXPECT(not t.has_stream(lit_one));
    EXPECT(not t.has_stream(lit_two));
    EXPECT(t.get_stream(from_one) != t.get_stream(from_two));
    EXPECT(t.get_stream(join) == 0);
    EXPECT(get_wait_for(join) ==
           get_wait_for(t.get_stream(join), {t.get_stream(from_one), t.get_stream(from_two)}));
    // NOTE(review): the conflict check is disabled for this case; confirm
    // whether it is expected to hold before re-enabling it.
    // EXPECT(check_conflicts(p, from_one, from_two));
}

Paul's avatar
Paul committed
342
// A 2-long chain and a single op from the same literal, joined by a nary op.
TEST_CASE(two_branches)
{
    schedule_target t{};
    migraphx::program p;
    auto lit    = p.add_literal(1);
    auto longer = chain(p, 2, unary_op{}, lit);
    auto single = p.add_instruction(unary_op{}, lit);
    auto join   = p.add_instruction(nary_op{}, single, longer.back());
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    // The longer branch stays on stream 0 with the join; the short one moves to 1.
    EXPECT(t.get_stream(single) == 1);
    for(auto link : longer)
        EXPECT(t.get_stream(link) == 0);
    EXPECT(t.get_stream(join) == 0);
    EXPECT(get_wait_for(join) ==
           get_wait_for(t.get_stream(join), {t.get_stream(longer.back()), t.get_stream(single)}));
    check_conflicts(p, {longer, {single}});
}

Paul's avatar
Paul committed
361
// Chains of length 4, 3, 2 and a single op, all joined by one nary op.
TEST_CASE(four_branches)
{
    schedule_target t{};
    migraphx::program p;
    auto lit     = p.add_literal(1);
    auto chain_4 = chain(p, 4, unary_op{}, lit);
    auto chain_3 = chain(p, 3, unary_op{}, lit);
    auto chain_2 = chain(p, 2, unary_op{}, lit);
    auto single  = p.add_instruction(unary_op{}, lit);
    auto join    = p.add_instruction(
        nary_op{}, single, chain_4.back(), chain_3.back(), chain_2.back());
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    // Branches are spread over streams 0..3, longest first.
    EXPECT(t.get_stream(single) == 3);
    for(auto link : chain_4)
        EXPECT(t.get_stream(link) == 0);
    for(auto link : chain_3)
        EXPECT(t.get_stream(link) == 1);
    for(auto link : chain_2)
        EXPECT(t.get_stream(link) == 2);
    EXPECT(t.get_stream(join) == 0);
    EXPECT(get_wait_for(join) == get_wait_for(t.get_stream(join),
                                              {t.get_stream(chain_4.back()),
                                               t.get_stream(chain_3.back()),
                                               t.get_stream(chain_2.back()),
                                               t.get_stream(single)}));
    check_conflicts(p, {chain_4, chain_3, chain_2, {single}});
}

Paul's avatar
Paul committed
389
// Five branches (chains of 5, 4, 3, 2 and a single op) but only four streams.
TEST_CASE(five_branches)
{
    schedule_target t{};
    migraphx::program p;
    auto lit     = p.add_literal(1);
    auto chain_5 = chain(p, 5, unary_op{}, lit);
    auto chain_4 = chain(p, 4, unary_op{}, lit);
    auto chain_3 = chain(p, 3, unary_op{}, lit);
    auto chain_2 = chain(p, 2, unary_op{}, lit);
    auto single  = p.add_instruction(unary_op{}, lit);
    auto join    = p.add_instruction(
        nary_op{}, single, chain_5.back(), chain_4.back(), chain_3.back(), chain_2.back());
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    // The two shortest branches have to share stream 3.
    EXPECT(t.get_stream(single) == 3);
    for(auto link : chain_5)
        EXPECT(t.get_stream(link) == 0);
    for(auto link : chain_4)
        EXPECT(t.get_stream(link) == 1);
    for(auto link : chain_3)
        EXPECT(t.get_stream(link) == 2);
    for(auto link : chain_2)
        EXPECT(t.get_stream(link) == 3);
    EXPECT(t.get_stream(join) == 0);
    // NOTE(review): chain_2.back() is an input of `join` but is not listed here.
    // It is asserted above to share stream 3 with `single`; confirm that the
    // omission is intentional.
    EXPECT(get_wait_for(join) == get_wait_for(t.get_stream(join),
                                              {t.get_stream(chain_5.back()),
                                               t.get_stream(chain_4.back()),
                                               t.get_stream(chain_3.back()),
                                               t.get_stream(single)}));
    check_conflicts(p, {chain_5, chain_4, chain_3, chain_2});
    check_conflicts(p, {chain_5, chain_4, chain_3, {single}});
}

Paul's avatar
Paul committed
421
422
// Four identical single-op branches joined by one nary op.
TEST_CASE(four_branches_eq)
{
    schedule_target t{};
    migraphx::program p;
    auto lit   = p.add_literal(1);
    auto br1   = p.add_instruction(unary_op{}, lit);
    auto br2   = p.add_instruction(unary_op{}, lit);
    auto br3   = p.add_instruction(unary_op{}, lit);
    auto br4   = p.add_instruction(unary_op{}, lit);
    auto join  = p.add_instruction(nary_op{}, br1, br2, br3, br4);
    p.compile(t);

    EXPECT(not t.has_stream(lit));
    const std::vector<std::size_t> streams = {
        t.get_stream(br1), t.get_stream(br2), t.get_stream(br3), t.get_stream(br4)};
    // Every branch got a distinct stream: sorting then deduplicating removes nothing.
    EXPECT(sorted<std::size_t>(streams) == unique<std::size_t>(streams));
    EXPECT(t.get_stream(join) == 0);
    EXPECT(get_wait_for(join) == get_wait_for(t.get_stream(join), streams));
    check_conflicts(p, {{br1}, {br2}, {br3}, {br4}});
}

Paul's avatar
Paul committed
447
448
// Two fork/join diamonds in sequence: the second forks from the first's join.
TEST_CASE(seq_merge)
{
    schedule_target t{};
    migraphx::program p;
    auto lit = p.add_literal(1);

    auto first_chain  = chain(p, 2, unary_op{}, lit);
    auto first_single = p.add_instruction(unary_op{}, lit);
    auto first_join   = p.add_instruction(nary_op{}, first_single, first_chain.back());

    auto second_chain  = chain(p, 2, unary_op{}, first_join);
    auto second_single = p.add_instruction(unary_op{}, first_join);
    auto second_join   = p.add_instruction(nary_op{}, second_single, second_chain.back());

    p.compile(t);
    EXPECT(not t.has_stream(lit));

    // First diamond.
    EXPECT(t.get_stream(first_single) == 2);
    for(auto link : first_chain)
        EXPECT(t.get_stream(link) == 3);
    EXPECT(t.get_stream(first_join) == 3);
    EXPECT(get_wait_for(first_join) ==
           get_wait_for(t.get_stream(first_join),
                        {t.get_stream(first_chain.back()), t.get_stream(first_single)}));
    check_conflicts(p, {first_chain, {first_single}});

    // Second diamond.
    EXPECT(t.get_stream(second_single) == 3);
    for(auto link : second_chain)
        EXPECT(t.get_stream(link) == 0);
    EXPECT(t.get_stream(second_join) == 0);
    EXPECT(get_wait_for(second_join) ==
           get_wait_for(t.get_stream(second_join),
                        {t.get_stream(second_chain.back()), t.get_stream(second_single)}));
    check_conflicts(p, {second_chain, {second_single}});
}

// Two independent fork/join diamonds whose joins are merged by a final nary op.
TEST_CASE(par_merge)
{
    schedule_target t{};
    migraphx::program p;
    auto lit = p.add_literal(1);

    auto left_root   = p.add_instruction(unary_op{}, lit);
    auto left_chain  = chain(p, 3, unary_op{}, left_root);
    auto left_single = p.add_instruction(unary_op{}, left_root);
    auto left_join   = p.add_instruction(nary_op{}, left_single, left_chain.back());

    auto right_root   = p.add_instruction(unary_op{}, lit);
    auto right_chain  = chain(p, 2, unary_op{}, right_root);
    auto right_single = p.add_instruction(unary_op{}, right_root);
    auto right_join   = p.add_instruction(nary_op{}, right_single, right_chain.back());

    auto final_join = p.add_instruction(nary_op{}, left_join, right_join);

    p.compile(t);
    EXPECT(not t.has_stream(lit));
    EXPECT(t.get_stream(final_join) == 0);

    // Left diamond.
    EXPECT(t.get_stream(left_single) == 2);
    for(auto link : left_chain)
        EXPECT(t.get_stream(link) == 0);
    EXPECT(t.get_stream(left_join) == 0);
    EXPECT(get_wait_for(left_join) ==
           get_wait_for(t.get_stream(left_join),
                        {t.get_stream(left_chain.back()), t.get_stream(left_single)}));
    check_conflicts(p, {left_chain, {left_single}});

    // Right diamond.
    EXPECT(t.get_stream(right_single) == 1);
    for(auto link : right_chain)
        EXPECT(t.get_stream(link) == 3);
    EXPECT(t.get_stream(right_join) == 3);
    EXPECT(get_wait_for(right_join) ==
           get_wait_for(t.get_stream(right_join),
                        {t.get_stream(right_chain.back()), t.get_stream(right_single)}));
    check_conflicts(p, {right_chain, {right_single}});

    // The two diamonds themselves run concurrently.
    EXPECT(check_conflicts(p, left_join, right_join));
}
Paul's avatar
Paul committed
519
// Entry point: hands the command line to the test driver, which runs the
// TEST_CASEs defined above.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}