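// schedule_test.cpp
//
// Unit tests for the migraphx::schedule pass: stream assignment, wait-event
// insertion and the identity "conflict" markers placed between instructions
// that run concurrently.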
#include <migraphx/schedule.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <basic_ops.hpp>
#include <test.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Single-input test operator. Its result (and result shape) is simply its first
// argument, or an empty value when called without arguments, and it declares
// that its output aliases input 0.
struct unary_op
{
    std::string name() const { return "unary"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        return args.empty() ? migraphx::argument{} : args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    // Output shares storage with input 0.
    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

Paul's avatar
Paul committed
31
struct nary_op
Paul's avatar
Paul committed
32
{
Paul's avatar
Paul committed
33
    std::string name() const { return "nary"; }
Paul's avatar
Paul committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(args.empty())
            return {};
        return args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
};

Paul's avatar
Paul committed
50
51
struct wait_event
{
Paul's avatar
Paul committed
52
53
    std::shared_ptr<std::vector<std::size_t>> wait_for =
        std::make_shared<std::vector<std::size_t>>();
Paul's avatar
Paul committed
54
55
56
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
57
        return migraphx::pack(f(*self.wait_for, "wait_for"));
Paul's avatar
Paul committed
58
59
60
61
    }
    std::string name() const { return "wait_event"; }
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }

Paul's avatar
Paul committed
62
63
64
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
65
    {
Paul's avatar
Paul committed
66
67
        assert(wait_for != nullptr);
        assert(not wait_for->empty());
Paul's avatar
Paul committed
68
69
70
71
        return {};
    }
};

Paul's avatar
Paul committed
72
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
Paul's avatar
Paul committed
73
74
using wait_map =
    std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
Paul's avatar
Paul committed
75
76
77

struct schedule_model_test
{
Paul's avatar
Paul committed
78
    std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
Paul's avatar
Paul committed
79
80
    std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream =
        std::make_shared<std::unordered_map<std::size_t, std::size_t>>();
Paul's avatar
Paul committed
81
    std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
Paul's avatar
Paul committed
82
    std::size_t concurrency() const { return 4; }
Paul's avatar
Paul committed
83
    void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
Paul's avatar
Paul committed
84
85
86
    {
        (*ins2stream)[ins] = n;
    }
Paul's avatar
Paul committed
87
    void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
Paul's avatar
Paul committed
88
    {
Paul's avatar
Paul committed
89
        if(ins2wait_for->count(ins) == 0)
Paul's avatar
Paul committed
90
91
92
93
94
95
96
        {
            auto event = wait_event{};
            p.insert_instruction(ins, event);
            (*ins2wait_for)[ins] = event.wait_for;
        }
        (*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
    }
Paul's avatar
Paul committed
97
    void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const
Paul's avatar
Paul committed
98
99
    {
        (*wait2stream)[wait_id] = ins2stream->at(ins);
Paul's avatar
Paul committed
100
101
102
    }
    std::size_t weight(const migraphx::operation& op) const
    {
Paul's avatar
Paul committed
103
        if(op.name() == "binary" or op.name() == "unary")
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
            return 4;
        else
            return 1;
    }
};

struct schedule_target
{
Paul's avatar
Paul committed
112
    schedule_model_test model{};
Paul's avatar
Paul committed
113
114
115
    std::string name() const { return "schedule"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
Paul's avatar
Paul committed
116
        return {migraphx::schedule{model}};
Paul's avatar
Paul committed
117
118
    }
    migraphx::context get_context() const { return {}; }
Paul's avatar
Paul committed
119

Paul's avatar
Paul committed
120
    std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
Paul's avatar
Paul committed
121

Paul's avatar
Paul committed
122
    bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
Paul's avatar
Paul committed
123
124
125
126
};

// Returns true if p contains an "identity" instruction that has both x and y
// among its inputs. The tests treat such an instruction as the schedule pass's
// marker that x and y conflict (may run concurrently).
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto ins : migraphx::iterator_for(p))
    {
        if(ins->name() != "identity")
            continue;
        const auto& args = ins->inputs();
        if(migraphx::contains(args, x) and migraphx::contains(args, y))
            return true;
    }
    return false;
}

Paul's avatar
Paul committed
140
void check_conflicts(migraphx::program& p,
Paul's avatar
Paul committed
141
142
                     std::vector<std::vector<migraphx::instruction_ref>> conflicts,
                     bool result = true)
Paul's avatar
Paul committed
143
144
{
    migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
Paul's avatar
Paul committed
145
        if(i == j)
Paul's avatar
Paul committed
146
            return;
Paul's avatar
Paul committed
147
148
        for(auto ins1 : conflicts[i])
            for(auto ins2 : conflicts[j])
Paul's avatar
Paul committed
149
                CHECK(check_conflicts(p, ins1, ins2) == result);
Paul's avatar
Paul committed
150
151
152
    });
}

// Returns a copy of x in ascending order.
template <class T>
std::vector<T> sorted(std::vector<T> x)
{
    auto first = std::begin(x);
    auto last  = std::end(x);
    std::sort(first, last);
    return x;
}

// Returns the distinct elements of x in ascending order.
template <class T>
std::vector<T> unique(std::vector<T> x)
{
    std::sort(std::begin(x), std::end(x));
    // std::unique only compacts adjacent duplicates, hence the sort above.
    auto new_end = std::unique(std::begin(x), std::end(x));
    x.erase(new_end, std::end(x));
    return x;
}

// Returns the given list of stream ids in ascending order.
std::vector<std::size_t> get_wait_for(std::vector<std::size_t> wait_for)
{
    std::sort(std::begin(wait_for), std::end(wait_for));
    return wait_for;
}

// Builds the expected wait list for an instruction that runs on stream
// `wait_on` and whose inputs ran on the streams in `wait_for`. It removes one
// occurrence of wait_on (presumably an instruction never waits on its own
// stream) and sorts the rest, so the result can be compared with
// get_wait_for(ins).
//
// If wait_on is not present, the list is returned sorted but otherwise
// unchanged. Erasing the end() iterator from a failed find would be
// undefined behaviour.
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
    auto it = std::find(wait_for.begin(), wait_for.end(), wait_on);
    if(it != wait_for.end())
        wait_for.erase(it);
    std::sort(wait_for.begin(), wait_for.end());
    return wait_for;
}

// Returns, sorted, the stream ids held by the wait_event that precedes ins,
// skipping any identity instructions in between. Returns an empty list when
// the nearest non-identity predecessor is not a wait_event.
// NOTE(review): assumes some non-identity instruction precedes ins; stepping
// before the program's first instruction would be undefined.
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
    auto candidate = std::prev(ins);
    while(candidate->name() == "identity")
        --candidate;

    if(candidate->name() == "wait_event")
    {
        auto streams = *migraphx::any_cast<wait_event>(candidate->get_operator()).wait_for;
        std::sort(streams.begin(), streams.end());
        return streams;
    }
    return {};
}

Paul's avatar
Paul committed
194
195
196
template <class T>
std::vector<migraphx::instruction_ref>
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
Paul's avatar
Paul committed
197
198
{
    std::vector<migraphx::instruction_ref> result;
Paul's avatar
Paul committed
199
    for(std::size_t i = 0; i < n; i++)
Paul's avatar
Paul committed
200
201
202
203
204
205
    {
        result.push_back(p.add_instruction(x, input));
        input = result.back();
    }
    return result;
}
Paul's avatar
Paul committed
206
207
TEST_CASE(single_entry)
{
Paul's avatar
Paul committed
208
    schedule_target t{};
Paul's avatar
Paul committed
209
210
    migraphx::program p;
    auto one    = p.add_literal(1);
Paul's avatar
Paul committed
211
212
    auto onep1  = p.add_instruction(unary_op{}, one);
    auto onep2  = p.add_instruction(unary_op{}, one);
Paul's avatar
Paul committed
213
    auto binary = p.add_instruction(nary_op{}, onep1, onep2);
Paul's avatar
Paul committed
214
215
216
217
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    EXPECT(t.get_stream(binary) == 0);
Paul's avatar
Paul committed
218
219
    EXPECT(get_wait_for(binary) ==
           get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
Paul's avatar
Paul committed
220
221
222
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
TEST_CASE(zero_record)
{
    schedule_target t{};
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto onep1  = p.add_instruction(unary_op{}, one);
    auto onep2  = p.add_instruction(unary_op{}, one);
    auto binary = p.add_instruction(nary_op{},
                                    p.add_instruction(migraphx::op::identity{}, onep1),
                                    p.add_instruction(migraphx::op::identity{}, onep2));
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    EXPECT(t.has_stream(binary));
    EXPECT(get_wait_for(binary) ==
           get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
242
243
TEST_CASE(zero_merge1)
{
Paul's avatar
Paul committed
244
    schedule_target t{};
Paul's avatar
Paul committed
245
246
247
248
249
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto onep1  = p.add_instruction(unary_op{}, one);
    auto onep2  = p.add_instruction(unary_op{}, one);
    auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
Paul's avatar
Paul committed
250
251
252
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
Paul's avatar
Paul committed
253
    // No stream assignment
Paul's avatar
Paul committed
254
    EXPECT(not t.has_stream(binary));
Paul's avatar
Paul committed
255
256
257
258
259
260
261
    // There is no wait
    EXPECT(get_wait_for(binary).empty());
    EXPECT(check_conflicts(p, onep1, onep2));
}

// Like zero_merge1, but each branch first passes through its own identity op.
TEST_CASE(zero_merge2)
{
    schedule_target t{};
    migraphx::program p;
    migraphx::instruction_ref one   = p.add_literal(1);
    migraphx::instruction_ref onep1 = p.add_instruction(unary_op{}, one);
    migraphx::instruction_ref onep2 = p.add_instruction(unary_op{}, one);
    // NOTE: the inner identities are created inside the argument list, so their
    // relative order in p follows the compiler's argument evaluation order.
    migraphx::instruction_ref binary =
        p.add_instruction(migraphx::op::identity{},
                          p.add_instruction(migraphx::op::identity{}, onep1),
                          p.add_instruction(migraphx::op::identity{}, onep2));

    p.compile(t);

    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    // The identity merge gets no stream ...
    EXPECT(not t.has_stream(binary));
    // ... and no wait_event in front of it.
    EXPECT(get_wait_for(binary).empty());
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
280
281
282
283
TEST_CASE(zero_merge3)
{
    schedule_target t{};
    migraphx::program p;
Paul's avatar
Paul committed
284
285
286
287
    auto one   = p.add_literal(1);
    auto onep1 = p.add_instruction(unary_op{}, one);
    auto onep2 = p.add_instruction(unary_op{}, one);
    auto id    = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
Paul's avatar
Paul committed
288
289
290
291
292
293
294
295
296
297
    auto final = p.add_instruction(unary_op{}, id);
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    // No stream assignment
    EXPECT(not t.has_stream(id));
    // There is no wait
    EXPECT(get_wait_for(id).empty());
    // Stream assignment for final op
    EXPECT(t.get_stream(final) == 0);
Paul's avatar
Paul committed
298
299
    EXPECT(get_wait_for(final) ==
           get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
Paul's avatar
Paul committed
300
301
302
303
304
305
306
    EXPECT(check_conflicts(p, onep1, onep2));
}

// Like zero_merge3, but each branch first passes through its own identity op.
TEST_CASE(zero_merge4)
{
    schedule_target t{};
    migraphx::program p;
    migraphx::instruction_ref one   = p.add_literal(1);
    migraphx::instruction_ref onep1 = p.add_instruction(unary_op{}, one);
    migraphx::instruction_ref onep2 = p.add_instruction(unary_op{}, one);
    // NOTE: the inner identities are created inside the argument list, so their
    // relative order in p follows the compiler's argument evaluation order.
    migraphx::instruction_ref id =
        p.add_instruction(migraphx::op::identity{},
                          p.add_instruction(migraphx::op::identity{}, onep1),
                          p.add_instruction(migraphx::op::identity{}, onep2));
    migraphx::instruction_ref final = p.add_instruction(unary_op{}, id);

    p.compile(t);

    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
    // The identity merge gets neither a stream nor a wait.
    EXPECT(not t.has_stream(id));
    EXPECT(get_wait_for(id).empty());
    // The consumer is scheduled on stream 0 and carries the wait instead.
    EXPECT(t.get_stream(final) == 0);
    EXPECT(get_wait_for(final) ==
           get_wait_for(t.get_stream(final), {t.get_stream(onep1), t.get_stream(onep2)}));
    EXPECT(check_conflicts(p, onep1, onep2));
}

Paul's avatar
Paul committed
328
TEST_CASE(double_entry)
Paul's avatar
Paul committed
329
{
Paul's avatar
Paul committed
330
    schedule_target t{};
Paul's avatar
Paul committed
331
    migraphx::program p;
Paul's avatar
Paul committed
332
333
334
335
    auto one    = p.add_literal(1);
    auto two    = p.add_literal(2);
    auto onep   = p.add_instruction(unary_op{}, one);
    auto twop   = p.add_instruction(unary_op{}, two);
Paul's avatar
Paul committed
336
    auto binary = p.add_instruction(nary_op{}, onep, twop);
Paul's avatar
Paul committed
337
338
339
340
341
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(not t.has_stream(two));
    EXPECT(t.get_stream(onep) != t.get_stream(twop));
    EXPECT(t.get_stream(binary) == 0);
Paul's avatar
Paul committed
342
343
    EXPECT(get_wait_for(binary) ==
           get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
Paul's avatar
Paul committed
344
    EXPECT(check_conflicts(p, onep, twop));
Paul's avatar
Paul committed
345
346
}

Paul's avatar
Paul committed
347
TEST_CASE(two_branches)
Paul's avatar
Paul committed
348
{
Paul's avatar
Paul committed
349
    schedule_target t{};
Paul's avatar
Paul committed
350
351
    migraphx::program p;
    auto one    = p.add_literal(1);
Paul's avatar
Paul committed
352
353
    auto c1     = chain(p, 2, unary_op{}, one);
    auto i1     = p.add_instruction(unary_op{}, one);
Paul's avatar
Paul committed
354
    auto binary = p.add_instruction(nary_op{}, i1, c1.back());
Paul's avatar
Paul committed
355
356
357
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(i1) == 1);
Paul's avatar
Paul committed
358
    for(auto ins : c1)
Paul's avatar
Paul committed
359
360
        EXPECT(t.get_stream(ins) == 0);
    EXPECT(t.get_stream(binary) == 0);
Paul's avatar
Paul committed
361
362
    EXPECT(get_wait_for(binary) ==
           get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
Paul's avatar
Paul committed
363
364
365
    check_conflicts(p, {c1, {i1}});
}

Paul's avatar
Paul committed
366
TEST_CASE(four_branches)
Paul's avatar
Paul committed
367
{
Paul's avatar
Paul committed
368
    schedule_target t{};
Paul's avatar
Paul committed
369
370
    migraphx::program p;
    auto one    = p.add_literal(1);
Paul's avatar
Paul committed
371
372
373
374
    auto c1     = chain(p, 4, unary_op{}, one);
    auto c2     = chain(p, 3, unary_op{}, one);
    auto c3     = chain(p, 2, unary_op{}, one);
    auto i1     = p.add_instruction(unary_op{}, one);
Paul's avatar
Paul committed
375
    auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
Paul's avatar
Paul committed
376
377
378
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(i1) == 3);
Paul's avatar
Paul committed
379
    for(auto ins : c1)
Paul's avatar
Paul committed
380
        EXPECT(t.get_stream(ins) == 0);
Paul's avatar
Paul committed
381
    for(auto ins : c2)
Paul's avatar
Paul committed
382
        EXPECT(t.get_stream(ins) == 1);
Paul's avatar
Paul committed
383
    for(auto ins : c3)
Paul's avatar
Paul committed
384
385
        EXPECT(t.get_stream(ins) == 2);
    EXPECT(t.get_stream(binary) == 0);
Paul's avatar
Paul committed
386
387
388
389
390
    EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
                                                {t.get_stream(c1.back()),
                                                 t.get_stream(c2.back()),
                                                 t.get_stream(c3.back()),
                                                 t.get_stream(i1)}));
Paul's avatar
Paul committed
391
    check_conflicts(p, {c1, c2, c3, {i1}});
Paul's avatar
Paul committed
392
393
}

Paul's avatar
Paul committed
394
TEST_CASE(five_branches)
Paul's avatar
Paul committed
395
{
Paul's avatar
Paul committed
396
    schedule_target t{};
Paul's avatar
Paul committed
397
398
399
400
401
402
403
404
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto c1     = chain(p, 5, unary_op{}, one);
    auto c2     = chain(p, 4, unary_op{}, one);
    auto c3     = chain(p, 3, unary_op{}, one);
    auto c4     = chain(p, 2, unary_op{}, one);
    auto i1     = p.add_instruction(unary_op{}, one);
    auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
Paul's avatar
Paul committed
405
406
407
    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(i1) == 3);
Paul's avatar
Paul committed
408
    for(auto ins : c1)
Paul's avatar
Paul committed
409
        EXPECT(t.get_stream(ins) == 0);
Paul's avatar
Paul committed
410
    for(auto ins : c2)
Paul's avatar
Paul committed
411
        EXPECT(t.get_stream(ins) == 1);
Paul's avatar
Paul committed
412
    for(auto ins : c3)
Paul's avatar
Paul committed
413
        EXPECT(t.get_stream(ins) == 2);
Paul's avatar
Paul committed
414
    for(auto ins : c4)
Paul's avatar
Paul committed
415
416
        EXPECT(t.get_stream(ins) == 3);
    EXPECT(t.get_stream(binary) == 0);
Paul's avatar
Paul committed
417
418
419
420
421
    EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
                                                {t.get_stream(c1.back()),
                                                 t.get_stream(c2.back()),
                                                 t.get_stream(c3.back()),
                                                 t.get_stream(i1)}));
Paul's avatar
Paul committed
422
423
424
425
    check_conflicts(p, {c1, c2, c3, c4});
    check_conflicts(p, {c1, c2, c3, {i1}});
}

Paul's avatar
Paul committed
426
427
TEST_CASE(four_branches_eq)
{
Paul's avatar
Paul committed
428
    schedule_target t{};
Paul's avatar
Paul committed
429
430
431
432
433
434
435
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto onep1  = p.add_instruction(unary_op{}, one);
    auto onep2  = p.add_instruction(unary_op{}, one);
    auto onep3  = p.add_instruction(unary_op{}, one);
    auto onep4  = p.add_instruction(unary_op{}, one);
    auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
Paul's avatar
Paul committed
436
437
    p.compile(t);
    EXPECT(not t.has_stream(one));
Paul's avatar
Paul committed
438
439
440
441
442
    EXPECT(
        sorted<std::size_t>(
            {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) ==
        unique<std::size_t>(
            {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
Paul's avatar
Paul committed
443
    EXPECT(t.get_stream(binary) == 0);
Paul's avatar
Paul committed
444
445
    EXPECT(
        get_wait_for(binary) ==
Paul's avatar
Paul committed
446
447
448
        get_wait_for(
            t.get_stream(binary),
            {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
Paul's avatar
Paul committed
449
450
451
    check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
}

Paul's avatar
Paul committed
452
453
TEST_CASE(seq_merge)
{
Paul's avatar
Paul committed
454
    schedule_target t{};
Paul's avatar
Paul committed
455
    migraphx::program p;
Paul's avatar
Paul committed
456
457
458
    auto one     = p.add_literal(1);
    auto c1      = chain(p, 2, unary_op{}, one);
    auto i1      = p.add_instruction(unary_op{}, one);
Paul's avatar
Paul committed
459
460
    auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());

Paul's avatar
Paul committed
461
462
    auto c2      = chain(p, 2, unary_op{}, binary1);
    auto i2      = p.add_instruction(unary_op{}, binary1);
Paul's avatar
Paul committed
463
464
    auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());

Paul's avatar
Paul committed
465
466
    p.compile(t);
    EXPECT(not t.has_stream(one));
Paul's avatar
Paul committed
467

Paul's avatar
Paul committed
468
    EXPECT(t.get_stream(i1) == 2);
Paul's avatar
Paul committed
469
    for(auto ins : c1)
Paul's avatar
Paul committed
470
471
        EXPECT(t.get_stream(ins) == 3);
    EXPECT(t.get_stream(binary1) == 3);
Paul's avatar
Paul committed
472
473
    EXPECT(get_wait_for(binary1) ==
           get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
Paul's avatar
Paul committed
474
475
    check_conflicts(p, {c1, {i1}});

Paul's avatar
Paul committed
476
    EXPECT(t.get_stream(i2) == 3);
Paul's avatar
Paul committed
477
    for(auto ins : c2)
Paul's avatar
Paul committed
478
479
        EXPECT(t.get_stream(ins) == 0);
    EXPECT(t.get_stream(binary2) == 0);
Paul's avatar
Paul committed
480
481
    EXPECT(get_wait_for(binary2) ==
           get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
Paul's avatar
Paul committed
482
483
484
485
486
    check_conflicts(p, {c2, {i2}});
}

// Two fork/join diamonds that hang off a shared literal and are merged by a
// final nary op.
TEST_CASE(par_merge)
{
    schedule_target t{};
    migraphx::program p;
    migraphx::instruction_ref one = p.add_literal(1);

    // First diamond (3-instruction chain vs. single instruction).
    migraphx::instruction_ref start1          = p.add_instruction(unary_op{}, one);
    std::vector<migraphx::instruction_ref> c1 = chain(p, 3, unary_op{}, start1);
    migraphx::instruction_ref i1              = p.add_instruction(unary_op{}, start1);
    migraphx::instruction_ref binary1         = p.add_instruction(nary_op{}, i1, c1.back());

    // Second diamond (2-instruction chain vs. single instruction).
    migraphx::instruction_ref start2          = p.add_instruction(unary_op{}, one);
    std::vector<migraphx::instruction_ref> c2 = chain(p, 2, unary_op{}, start2);
    migraphx::instruction_ref i2              = p.add_instruction(unary_op{}, start2);
    migraphx::instruction_ref binary2         = p.add_instruction(nary_op{}, i2, c2.back());

    // Merge both diamonds.
    migraphx::instruction_ref binary3 = p.add_instruction(nary_op{}, binary1, binary2);

    p.compile(t);

    EXPECT(not t.has_stream(one));
    EXPECT(t.get_stream(binary3) == 0);

    // First diamond.
    EXPECT(t.get_stream(i1) == 2);
    for(const auto& ins : c1)
    {
        EXPECT(t.get_stream(ins) == 0);
    }
    EXPECT(t.get_stream(binary1) == 0);
    EXPECT(get_wait_for(binary1) ==
           get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
    check_conflicts(p, {c1, {i1}});

    // Second diamond.
    EXPECT(t.get_stream(i2) == 1);
    for(const auto& ins : c2)
    {
        EXPECT(t.get_stream(ins) == 3);
    }
    EXPECT(t.get_stream(binary2) == 3);
    EXPECT(get_wait_for(binary2) ==
           get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
    check_conflicts(p, {c2, {i2}});

    // The two diamonds also run concurrently with each other.
    EXPECT(check_conflicts(p, binary1, binary2));
    check_conflicts(p, {c1, {i1}, c2, {i2}});
}

// Same graph as par_merge, except each diamond starts from its own literal.
// The second literal now holds 2, matching its name and double_entry. Its value
// does not influence scheduling; only the graph structure and op weights do.
TEST_CASE(par_merge_multi_entry)
{
    schedule_target t{};
    migraphx::program p;

    // First diamond, rooted at literal `one`.
    auto one     = p.add_literal(1);
    auto start1  = p.add_instruction(unary_op{}, one);
    auto c1      = chain(p, 3, unary_op{}, start1);
    auto i1      = p.add_instruction(unary_op{}, start1);
    auto binary1 = p.add_instruction(nary_op{}, i1, c1.back());

    // Second diamond, rooted at a separate literal `two`.
    auto two     = p.add_literal(2);
    auto start2  = p.add_instruction(unary_op{}, two);
    auto c2      = chain(p, 2, unary_op{}, start2);
    auto i2      = p.add_instruction(unary_op{}, start2);
    auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());

    // Merge both diamonds.
    auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);

    p.compile(t);
    EXPECT(not t.has_stream(one));
    EXPECT(not t.has_stream(two));
    EXPECT(t.get_stream(binary3) == 0);

    EXPECT(t.get_stream(i1) == 2);
    for(auto ins : c1)
        EXPECT(t.get_stream(ins) == 0);
    EXPECT(t.get_stream(binary1) == 0);
    EXPECT(get_wait_for(binary1) ==
           get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
    check_conflicts(p, {c1, {i1}});

    EXPECT(t.get_stream(i2) == 1);
    for(auto ins : c2)
        EXPECT(t.get_stream(ins) == 3);
    EXPECT(t.get_stream(binary2) == 3);
    EXPECT(get_wait_for(binary2) ==
           get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
    check_conflicts(p, {c2, {i2}});

    EXPECT(check_conflicts(p, binary1, binary2));
    check_conflicts(p, {c1, {i1}, c2, {i2}});
}
Paul's avatar
Paul committed
568
int main(int argc, const char* argv[]) { test::run(argc, argv); }