schedule.cpp 14.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
7
8
9
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
Paul's avatar
Paul committed
10
#include <unordered_set>
11
#include <set>
Paul's avatar
Paul committed
12
#include <deque>
Paul's avatar
Paul committed
13
14
15
16

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
17
18
19
20
21
22
23
24
25
26
/// Builds a selector that, given an instruction-like handle, yields the
/// result of calling `inputs()` on it.
auto get_inputs()
{
    auto select_inputs = [](auto ins) { return ins->inputs(); };
    return select_inputs;
}

/// Builds a selector that, given an instruction-like handle, yields the
/// result of calling `outputs()` on it.
auto get_outputs()
{
    auto select_outputs = [](auto ins) { return ins->outputs(); };
    return select_outputs;
}

27
28
29
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
30
    std::unordered_map<instruction_ref, std::size_t> weights;
31
    std::unordered_map<instruction_ref, std::size_t> iweights;
Paul's avatar
Paul committed
32
33
34
35

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
Paul's avatar
Paul committed
36
            if(not contains(weights, ins))
Paul's avatar
Paul committed
37
            {
Paul's avatar
Paul committed
38
                std::size_t weight = 0;
Paul's avatar
Paul committed
39
                auto&& op          = ins->get_operator();
Paul's avatar
Paul committed
40
41
                if(not is_context_free(op) and op.name()[0] != '@')
                    weight = model.weight(op);
42
                iweights[ins] = weight;
Paul's avatar
Paul committed
43
44
45
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
Paul's avatar
Paul committed
46
                                    weight,
Paul's avatar
Paul committed
47
48
49
50
51
52
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

Paul's avatar
Paul committed
53
54
    std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
    {
Paul's avatar
Paul committed
55
        if(args.size() < 2)
Paul's avatar
Paul committed
56
57
58
59
        {
            return args.end();
        }

Paul's avatar
Paul committed
60
61
62
63
        const std::size_t min_partition_threshold = 2;
        auto compare                              = by(std::greater<>{}, [&](auto x) {
            return std::make_tuple(this->weights[x], x->inputs().size());
        });
Paul's avatar
Paul committed
64
65
        std::sort(args.begin(), args.end(), compare);

Paul's avatar
Paul committed
66
        auto it = std::lower_bound(std::next(args.begin()),
Paul's avatar
Paul committed
67
68
69
                                   args.end(),
                                   min_partition_threshold,
                                   [&](auto i, std::size_t w) { return this->weights[i] > w; });
Paul's avatar
Paul committed
70
        assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
Paul's avatar
Paul committed
71
72
        assert(it == args.end() or std::prev(it) == args.begin() or
               this->weights[*std::prev(it)] > min_partition_threshold);
Paul's avatar
Paul committed
73
        return it;
Paul's avatar
Paul committed
74
75
    }

Paul's avatar
Paul committed
76
    struct partition
Paul's avatar
Paul committed
77
    {
Paul's avatar
Paul committed
78
79
80
81
        std::size_t weight = 0;
        std::vector<instruction_ref> instructions{};

        void add(instruction_ref ins, std::size_t w)
Paul's avatar
Paul committed
82
        {
Paul's avatar
Paul committed
83
84
85
86
87
88
89
90
91
            weight += w;
            instructions.push_back(ins);
        }
    };

    void assign_streams(program& p, std::size_t n)
    {
        partition critical;
        std::unordered_map<instruction_ref, std::deque<partition>> partitions;
Paul's avatar
Paul committed
92
        partitions.reserve(weights.size());
Paul's avatar
Paul committed
93
        fix([&](auto self, auto ins, auto& part) {
Paul's avatar
Paul committed
94
            assert(ins != p.end());
Paul's avatar
Paul committed
95
            if(contains(partitions, ins))
Paul's avatar
Paul committed
96
                return;
Paul's avatar
Paul committed
97
98
            assert(p.has_instruction(ins));
            // Add an entry so we know the instruction was visited
Paul's avatar
Paul committed
99
            partitions[ins];
Paul's avatar
Paul committed
100
101
            part.add(ins, this->iweights[ins]);

Paul's avatar
Paul committed
102
            auto args         = ins->inputs();
Paul's avatar
Paul committed
103
            auto threshold_it = this->sort_args(args);
Paul's avatar
Paul committed
104

Paul's avatar
Paul committed
105
            if(not args.empty())
Paul's avatar
Paul committed
106
            {
Paul's avatar
Paul committed
107
108
109
                assert(threshold_it != args.begin());
                self(args.front(), part);
                for(auto i : range(std::next(args.begin()), threshold_it))
Paul's avatar
Paul committed
110
111
112
113
                {
                    partitions[ins].emplace_back();
                    self(i, partitions[ins].back());
                }
Paul's avatar
Paul committed
114
115
116
117
                for(auto i : range(threshold_it, args.end()))
                {
                    self(i, part);
                }
Paul's avatar
Paul committed
118
            }
Paul's avatar
Paul committed
119
120
            // Sort instructions
            p.move_instruction(ins, p.end());
Paul's avatar
Paul committed
121
122
123
124
        })(std::prev(p.end()), critical);

        // Set the critical partition to stream 0
        set_stream(critical, 0);
Paul's avatar
Paul committed
125
        std::vector<std::size_t> streams(n - 1);
Paul's avatar
Paul committed
126
        // Assign streams for the other partitions
Paul's avatar
Paul committed
127
        for(auto&& ins_part : partitions)
Paul's avatar
Paul committed
128
        {
Paul's avatar
Paul committed
129
130
131
132
133
            std::sort(
                ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) {
                    return std::make_tuple(x.weight, x.instructions.size());
                }));
            for(auto&& part : ins_part.second)
Paul's avatar
Paul committed
134
135
            {
                auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
Paul's avatar
Paul committed
136
                set_stream(part, stream + 1);
Paul's avatar
Paul committed
137
138
                streams[stream] += part.weight;
            }
Paul's avatar
Paul committed
139
140
        }
    }
141

Paul's avatar
Paul committed
142
143
    void set_stream(const partition& p, std::size_t n)
    {
Paul's avatar
Paul committed
144
145
        for(auto ins : p.instructions)
            if(iweights[ins] > 0)
Paul's avatar
Paul committed
146
147
148
                set_stream(ins, n);
    }

Paul's avatar
Paul committed
149
    void set_stream(instruction_ref ins, std::size_t n)
150
    {
Paul's avatar
Paul committed
151
152
        assert(iweights[ins] > 0);
        ins2stream[ins] = n;
153
    }
154

Paul's avatar
Paul committed
155
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
156

Paul's avatar
Paul committed
157
    bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); }
158

Paul's avatar
Paul committed
159
    template <class F>
Paul's avatar
Paul committed
160
    bool different(F f, std::size_t stream) const
Paul's avatar
Paul committed
161
    {
Paul's avatar
Paul committed
162
        bool result = false;
Paul's avatar
Paul committed
163
        f([&](auto s) {
Paul's avatar
Paul committed
164
            if(s != stream)
Paul's avatar
Paul committed
165
            {
Paul's avatar
Paul committed
166
167
                result = true;
                return false;
Paul's avatar
Paul committed
168
            }
Paul's avatar
Paul committed
169
            // cppcheck-suppress uselessAssignmentArg
Paul's avatar
Paul committed
170
171
172
            stream = s;
            return true;
        });
Paul's avatar
Paul committed
173
174
        return result;
    }
175

Paul's avatar
Paul committed
176
177
178
179
180
    template <class F>
    bool different(F f) const
    {
        bool result = false;
        f([&](auto s) {
Paul's avatar
Paul committed
181
            result = this->different(f, s);
Paul's avatar
Paul committed
182
183
184
185
186
            return false;
        });
        return result;
    }

Paul's avatar
Paul committed
187
    template <class Selector>
Paul's avatar
Paul committed
188
    auto get_streams_from(instruction_ref start, Selector select) const
Paul's avatar
Paul committed
189
190
191
192
193
    {
        return [=](auto f) {
            return fix<bool>([&](auto self, auto ins) {
                for(auto i : select(ins))
                {
194
                    if(iweights.at(i) == 0)
Paul's avatar
Paul committed
195
                    {
Paul's avatar
Paul committed
196
                        if(not self(i))
Paul's avatar
Paul committed
197
198
199
200
                            return false;
                    }
                    else
                    {
Paul's avatar
Paul committed
201
                        if(not f(this->get_stream(i)))
Paul's avatar
Paul committed
202
203
204
205
206
207
208
209
                            return false;
                    }
                }
                return true;
            })(start);
        };
    }

Paul's avatar
Paul committed
210
    std::unordered_set<std::size_t> get_streams(instruction_ref ins) const
Paul's avatar
Paul committed
211
    {
Paul's avatar
Paul committed
212
        if(has_stream(ins))
Paul's avatar
Paul committed
213
214
215
216
217
218
219
220
221
            return {get_stream(ins)};
        std::unordered_set<std::size_t> result;
        get_streams_from(ins, get_inputs())([&](auto s) {
            result.insert(s);
            return true;
        });
        return result;
    }

Paul's avatar
Paul committed
222
223
224
    template <class... Ts>
    bool is_merge_point(instruction_ref ins, Ts... xs) const
    {
Paul's avatar
Paul committed
225
        return different(get_streams_from(ins, get_inputs()), xs...);
Paul's avatar
Paul committed
226
    }
Paul's avatar
Paul committed
227

Paul's avatar
Paul committed
228
229
230
    template <class... Ts>
    bool is_split_point(instruction_ref ins, Ts... xs) const
    {
Paul's avatar
Paul committed
231
        return different(get_streams_from(ins, get_outputs()), xs...);
Paul's avatar
Paul committed
232
    }
233

Paul's avatar
Paul committed
234
235
236
237
238
239
240
241
242
243
244
245
    std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
    {
        std::vector<instruction_ref> result;
        std::unordered_map<std::size_t, instruction_ref> m;
        fix([&](auto self, auto ins) {
            for(auto i : ins->inputs())
            {
                if(iweights.at(i) == 0)
                {
                    self(i);
                    continue;
                }
Paul's avatar
Paul committed
246
                auto stream = this->get_stream(i);
Paul's avatar
Paul committed
247
                if(not contains(m, stream))
Paul's avatar
Paul committed
248
249
                    m[stream] = i;
                else
Paul's avatar
Paul committed
250
251
252
                    m[stream] = std::min(m[stream], i, by(std::less<>{}, [&](auto x) {
                                             return std::distance(x, start);
                                         }));
Paul's avatar
Paul committed
253
254
            }
        })(start);
Paul's avatar
Paul committed
255
256
        std::transform(
            m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.second; });
Paul's avatar
Paul committed
257
258
259
        return result;
    }

Paul's avatar
Paul committed
260
261
    std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
    find_concurrent_instructions(program& p)
Paul's avatar
Paul committed
262
    {
Paul's avatar
Paul committed
263
        std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
Paul's avatar
Paul committed
264
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
Paul's avatar
Paul committed
265
266
        result.reserve(p.size());
        merge_from.reserve(p.size());
Paul's avatar
Paul committed
267
        for(auto ins : reverse_iterator_for(p))
Paul's avatar
Paul committed
268
        {
Paul's avatar
Paul committed
269
            for(auto&& arg : ins->outputs())
Paul's avatar
Paul committed
270
            {
Paul's avatar
Paul committed
271
272
273
                if(is_merge_point(arg))
                    merge_from[ins].insert(arg);
                merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
Paul's avatar
Paul committed
274
275
            }

Paul's avatar
Paul committed
276
            auto streams = this->get_streams(ins);
Paul's avatar
Paul committed
277
278
279

            // Collect concur instructions for each merge point.
            for(auto& merge : merge_from[ins])
Paul's avatar
Paul committed
280
            {
Paul's avatar
Paul committed
281
                for(auto stream : streams)
Paul's avatar
Paul committed
282
283
284
                {
                    if(result[merge].size() <= stream)
                        result[merge].resize(stream + 1);
Paul's avatar
Paul committed
285
286
                    auto&& r = result[merge][stream];
                    r.push_back(ins);
Paul's avatar
Paul committed
287
                    // Copy inputs if they dont have a stream(and are not a builtin and context
Paul's avatar
Paul committed
288
                    // free). Inputs without a stream can have a implicit dependency
Paul's avatar
Paul committed
289
290
291
292
293
294
295
296
                    std::copy_if(ins->inputs().begin(),
                                 ins->inputs().end(),
                                 std::back_inserter(r),
                                 [&](auto x) {
                                     return not this->has_stream(x) and
                                            not is_context_free(x->get_operator()) and
                                            x->name().front() != '@';
                                 });
Paul's avatar
Paul committed
297
                }
Paul's avatar
Paul committed
298
299
            }
        }
Paul's avatar
Paul committed
300
        return result;
Paul's avatar
Paul committed
301
    }
Paul's avatar
Paul committed
302

Paul's avatar
Paul committed
303
304
    std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
    get_conflicts(program& p)
Paul's avatar
Paul committed
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
    {
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table;
        auto concur_ins = this->find_concurrent_instructions(p);
        for(auto&& merge : concur_ins)
        {
            dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
                if(i == j)
                    return;
                for(auto ins1 : merge.second[i])
                {
                    auto p1 = std::distance(ins1, merge.first);
                    for(auto ins2 : merge.second[j])
                    {
                        if(ins1 == ins2)
                            continue;
                        auto p2 = std::distance(ins2, merge.first);
                        // The smaller distance means the instruction occurs later
                        if(p1 > p2)
                            conflict_table[ins2].insert(ins1);
                        else
                            conflict_table[ins1].insert(ins2);
                    }
                }
            });
        }
        // Remove duplicates
        for(auto&& ip : conflict_table)
        {
            auto ins1 = ip.first;
            for(auto ins2 : ip.second)
                if(contains(conflict_table[ins2], ins1))
                    conflict_table[ins2].erase(ins1);
        }
        return conflict_table;
    }
340
341
};

Paul's avatar
Paul committed
342
343
void schedule::apply(program& p) const
{
344
    stream_info si;
Paul's avatar
Paul committed
345
346
347
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());
348

Paul's avatar
Paul committed
349
350
351
352
353
    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
    {
        p.annotate(std::cout, [&](auto ins) {
            std::cout << ":";
            std::cout << " weight=" << si.weights.at(ins);
Paul's avatar
Paul committed
354
            std::cout << " input={";
Paul's avatar
Paul committed
355
            si.get_streams_from(ins, get_inputs())([&](auto s) {
Paul's avatar
Paul committed
356
357
358
359
                std::cout << s << ",";
                return true;
            });
            std::cout << "}";
Paul's avatar
Paul committed
360
            if(si.has_stream(ins))
Paul's avatar
Paul committed
361
362
363
364
365
                std::cout << " stream=" << si.get_stream(ins);
        });
        std::cout << std::endl;
    }

366
    // Schedule instructions
Paul's avatar
Paul committed
367
    std::size_t wait_id = 0;
Paul's avatar
Paul committed
368
    std::unordered_map<instruction_ref, std::size_t> ins2wait;
Paul's avatar
Paul committed
369
370
    std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
    std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
Paul's avatar
Paul committed
371
372
    ins2wait.reserve(p.size());
    ins2waited.reserve(p.size());
Paul's avatar
Paul committed
373
    for(auto ins : iterator_for(p))
374
    {
Paul's avatar
Paul committed
375
        // Only schedule instructions that have a stream
Paul's avatar
Paul committed
376
        if(not si.has_stream(ins))
Paul's avatar
Paul committed
377
            continue;
378
        assert(si.weights[ins] > 0);
Paul's avatar
Paul committed
379
        // Schedule instruction on the stream
Paul's avatar
Paul committed
380
        auto stream = si.get_stream(ins);
Paul's avatar
Paul committed
381
        assert(stream < model.concurrency());
Paul's avatar
Paul committed
382
383
        model.sched(p, ins, stream);
        // Insert wait instructions
Paul's avatar
Paul committed
384
        if(si.is_merge_point(ins, stream))
Paul's avatar
Paul committed
385
        {
Paul's avatar
Paul committed
386
            for(auto i : si.get_recorded_instructions(ins))
Paul's avatar
Paul committed
387
            {
Paul's avatar
Paul committed
388
                if(not si.has_stream(i))
Paul's avatar
Paul committed
389
                    continue;
Paul's avatar
Paul committed
390
391
                auto istream = si.get_stream(i);
                if(stream == istream)
Paul's avatar
Paul committed
392
393
                    continue;
                // Create a new event if it hasn't been recorded
Paul's avatar
Paul committed
394
                if(not contains(ins2wait, i))
Paul's avatar
Paul committed
395
396
397
398
399
                {
                    ins2wait[i] = wait_id;
                    model.record(p, i, wait_id);
                    wait_id++;
                }
Paul's avatar
Paul committed
400
401
402
                auto w = ins2wait.at(i);
                // If we already waited for the event on this stream then dont
                // insert another wait event
Paul's avatar
Paul committed
403
                if(not contains(waited_for[stream], w))
Paul's avatar
Paul committed
404
405
406
407
408
                    model.wait(p, ins, w);
                // Store the event as waited
                waited_for[stream].insert(w);
                // Store all wait events that have been waited on prior to the recorded instruction
                waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end());
Paul's avatar
Paul committed
409
            }
Paul's avatar
Paul committed
410
        }
Paul's avatar
Paul committed
411
412
413
414
415
        // Store wait events that have already been waited on
        if(si.is_split_point(ins, stream))
        {
            ins2waited[ins] = waited_for[stream];
        }
416
    }
Paul's avatar
Paul committed
417

Paul's avatar
Paul committed
418
    // Add memory conflicts
Paul's avatar
Paul committed
419
    auto conflict_table = si.get_conflicts(p);
Paul's avatar
Paul committed
420
    for(auto&& ip : conflict_table)
Paul's avatar
Paul committed
421
    {
Paul's avatar
Paul committed
422
        if(ip.second.empty())
Paul's avatar
Paul committed
423
424
425
426
427
428
            continue;
        std::vector<instruction_ref> args;
        args.push_back(ip.first);
        args.insert(args.end(), ip.second.begin(), ip.second.end());
        p.insert_instruction(std::next(ip.first), op::identity{}, args);
    }
Paul's avatar
Paul committed
429
430
431
432
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx