schedule.cpp 14.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
7
8
9
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
Paul's avatar
Paul committed
10
#include <unordered_set>
11
#include <set>
Paul's avatar
Paul committed
12
#include <deque>
Paul's avatar
Paul committed
13
14
15
16

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
17
18
19
20
21
22
23
24
25
26
// Builds a selector that, given an instruction handle, yields that
// instruction's inputs. The selector returns by value because the lambda's
// deduced return type drops any reference.
auto get_inputs()
{
    auto select_inputs = [](auto ins) { return ins->inputs(); };
    return select_inputs;
}

// Builds a selector that, given an instruction handle, yields that
// instruction's outputs. The selector returns by value because the lambda's
// deduced return type drops any reference.
auto get_outputs()
{
    auto select_outputs = [](auto ins) { return ins->outputs(); };
    return select_outputs;
}

27
28
29
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
30
    std::unordered_map<instruction_ref, std::size_t> weights;
31
    std::unordered_map<instruction_ref, std::size_t> iweights;
Paul's avatar
Paul committed
32
33
34
35

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
Paul's avatar
Paul committed
36
            if(not contains(weights, ins))
Paul's avatar
Paul committed
37
            {
Paul's avatar
Paul committed
38
                std::size_t weight = 0;
Paul's avatar
Paul committed
39
                auto&& op          = ins->get_operator();
Paul's avatar
Paul committed
40
41
                if(not is_context_free(op) and op.name()[0] != '@')
                    weight = model.weight(op);
42
                iweights[ins] = weight;
Paul's avatar
Paul committed
43
44
45
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
Paul's avatar
Paul committed
46
                                    weight,
Paul's avatar
Paul committed
47
48
49
50
51
52
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

Paul's avatar
Paul committed
53
54
    std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
    {
Paul's avatar
Paul committed
55
        if(args.size() < 2)
Paul's avatar
Paul committed
56
57
58
59
        {
            return args.end();
        }

Paul's avatar
Paul committed
60
61
62
63
        const std::size_t min_partition_threshold = 2;
        auto compare                              = by(std::greater<>{}, [&](auto x) {
            return std::make_tuple(this->weights[x], x->inputs().size());
        });
Paul's avatar
Paul committed
64
65
        std::sort(args.begin(), args.end(), compare);

Paul's avatar
Paul committed
66
        auto it = std::lower_bound(std::next(args.begin()),
Paul's avatar
Paul committed
67
68
69
                                   args.end(),
                                   min_partition_threshold,
                                   [&](auto i, std::size_t w) { return this->weights[i] > w; });
Paul's avatar
Paul committed
70
        assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
Paul's avatar
Paul committed
71
72
        assert(it == args.end() or std::prev(it) == args.begin() or
               this->weights[*std::prev(it)] > min_partition_threshold);
Paul's avatar
Paul committed
73
        return it;
Paul's avatar
Paul committed
74
75
    }

Paul's avatar
Paul committed
76
    struct partition
Paul's avatar
Paul committed
77
    {
Paul's avatar
Paul committed
78
79
80
81
        std::size_t weight = 0;
        std::vector<instruction_ref> instructions{};

        void add(instruction_ref ins, std::size_t w)
Paul's avatar
Paul committed
82
        {
Paul's avatar
Paul committed
83
84
85
86
87
88
89
90
91
            weight += w;
            instructions.push_back(ins);
        }
    };

    void assign_streams(program& p, std::size_t n)
    {
        partition critical;
        std::unordered_map<instruction_ref, std::deque<partition>> partitions;
Paul's avatar
Paul committed
92
        partitions.reserve(weights.size());
Paul's avatar
Paul committed
93
        fix([&](auto self, auto ins, auto& part) {
Paul's avatar
Paul committed
94
            assert(ins != p.end());
Paul's avatar
Paul committed
95
            if(contains(partitions, ins))
Paul's avatar
Paul committed
96
                return;
Paul's avatar
Paul committed
97
98
            assert(p.has_instruction(ins));
            // Add an entry so we know the instruction was visited
Paul's avatar
Paul committed
99
            partitions[ins];
Paul's avatar
Paul committed
100
101
            part.add(ins, this->iweights[ins]);

Paul's avatar
Paul committed
102
            auto args         = ins->inputs();
Paul's avatar
Paul committed
103
            auto threshold_it = this->sort_args(args);
Paul's avatar
Paul committed
104

Paul's avatar
Paul committed
105
            if(not args.empty())
Paul's avatar
Paul committed
106
            {
Paul's avatar
Paul committed
107
108
109
                assert(threshold_it != args.begin());
                self(args.front(), part);
                for(auto i : range(std::next(args.begin()), threshold_it))
Paul's avatar
Paul committed
110
111
112
113
                {
                    partitions[ins].emplace_back();
                    self(i, partitions[ins].back());
                }
Paul's avatar
Paul committed
114
115
116
117
                for(auto i : range(threshold_it, args.end()))
                {
                    self(i, part);
                }
Paul's avatar
Paul committed
118
            }
Paul's avatar
Paul committed
119
120
            // Sort instructions
            p.move_instruction(ins, p.end());
Paul's avatar
Paul committed
121
122
123
124
        })(std::prev(p.end()), critical);

        // Set the critical partition to stream 0
        set_stream(critical, 0);
Paul's avatar
Paul committed
125
        std::vector<std::size_t> streams(n - 1);
Paul's avatar
Paul committed
126
        // Assign streams for the other partitions
Paul's avatar
Paul committed
127
        for(auto&& ins_part : partitions)
Paul's avatar
Paul committed
128
        {
Paul's avatar
Paul committed
129
130
131
132
133
            std::sort(
                ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) {
                    return std::make_tuple(x.weight, x.instructions.size());
                }));
            for(auto&& part : ins_part.second)
Paul's avatar
Paul committed
134
135
            {
                auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
Paul's avatar
Paul committed
136
                set_stream(part, stream + 1);
Paul's avatar
Paul committed
137
138
                streams[stream] += part.weight;
            }
Paul's avatar
Paul committed
139
140
        }
    }
141

Paul's avatar
Paul committed
142
143
    void set_stream(const partition& p, std::size_t n)
    {
Paul's avatar
Paul committed
144
145
        for(auto ins : p.instructions)
            if(iweights[ins] > 0)
Paul's avatar
Paul committed
146
147
148
                set_stream(ins, n);
    }

Paul's avatar
Paul committed
149
    void set_stream(instruction_ref ins, std::size_t n)
150
    {
Paul's avatar
Paul committed
151
152
        assert(iweights[ins] > 0);
        ins2stream[ins] = n;
153
    }
154

Paul's avatar
Paul committed
155
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
156

Paul's avatar
Paul committed
157
    bool has_stream(instruction_ref ins) const { return contains(ins2stream, ins); }
158

Paul's avatar
Paul committed
159
    template <class F>
Paul's avatar
Paul committed
160
    bool different(F f, std::size_t stream) const
Paul's avatar
Paul committed
161
    {
Paul's avatar
Paul committed
162
        bool result = false;
Paul's avatar
Paul committed
163
        f([&](auto s) {
Paul's avatar
Paul committed
164
            if(s != stream)
Paul's avatar
Paul committed
165
            {
Paul's avatar
Paul committed
166
167
                result = true;
                return false;
Paul's avatar
Paul committed
168
            }
Paul's avatar
Paul committed
169
            // cppcheck-suppress uselessAssignmentArg
Paul's avatar
Paul committed
170
171
172
            stream = s;
            return true;
        });
Paul's avatar
Paul committed
173
174
        return result;
    }
175

Paul's avatar
Paul committed
176
177
178
179
180
    template <class F>
    bool different(F f) const
    {
        bool result = false;
        f([&](auto s) {
Paul's avatar
Paul committed
181
            result = this->different(f, s);
Paul's avatar
Paul committed
182
183
184
185
186
            return false;
        });
        return result;
    }

Paul's avatar
Paul committed
187
    template <class Selector>
Paul's avatar
Paul committed
188
    auto get_streams_from(instruction_ref start, Selector select) const
Paul's avatar
Paul committed
189
190
191
192
193
    {
        return [=](auto f) {
            return fix<bool>([&](auto self, auto ins) {
                for(auto i : select(ins))
                {
194
                    if(iweights.at(i) == 0)
Paul's avatar
Paul committed
195
                    {
Paul's avatar
Paul committed
196
                        if(not self(i))
Paul's avatar
Paul committed
197
198
199
200
                            return false;
                    }
                    else
                    {
Paul's avatar
Paul committed
201
                        if(not f(this->get_stream(i)))
Paul's avatar
Paul committed
202
203
204
205
206
207
208
209
                            return false;
                    }
                }
                return true;
            })(start);
        };
    }

Paul's avatar
Paul committed
210
    std::unordered_set<std::size_t> get_streams(instruction_ref ins) const
Paul's avatar
Paul committed
211
    {
Paul's avatar
Paul committed
212
        if(has_stream(ins))
Paul's avatar
Paul committed
213
214
215
216
217
218
219
220
221
            return {get_stream(ins)};
        std::unordered_set<std::size_t> result;
        get_streams_from(ins, get_inputs())([&](auto s) {
            result.insert(s);
            return true;
        });
        return result;
    }

Paul's avatar
Paul committed
222
223
224
    template <class... Ts>
    bool is_merge_point(instruction_ref ins, Ts... xs) const
    {
Paul's avatar
Paul committed
225
        return different(get_streams_from(ins, get_inputs()), xs...);
Paul's avatar
Paul committed
226
    }
Paul's avatar
Paul committed
227

Paul's avatar
Paul committed
228
229
230
    template <class... Ts>
    bool is_split_point(instruction_ref ins, Ts... xs) const
    {
Paul's avatar
Paul committed
231
        return different(get_streams_from(ins, get_outputs()), xs...);
Paul's avatar
Paul committed
232
    }
233

Paul's avatar
Paul committed
234
235
236
237
238
239
240
241
242
243
244
245
    std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
    {
        std::vector<instruction_ref> result;
        std::unordered_map<std::size_t, instruction_ref> m;
        fix([&](auto self, auto ins) {
            for(auto i : ins->inputs())
            {
                if(iweights.at(i) == 0)
                {
                    self(i);
                    continue;
                }
Paul's avatar
Paul committed
246
                auto stream = this->get_stream(i);
Paul's avatar
Paul committed
247
                if(not contains(m, stream))
Paul's avatar
Paul committed
248
249
                    m[stream] = i;
                else
Paul's avatar
Paul committed
250
251
252
                    m[stream] = std::min(m[stream], i, by(std::less<>{}, [&](auto x) {
                                             return std::distance(x, start);
                                         }));
Paul's avatar
Paul committed
253
254
            }
        })(start);
Paul's avatar
Paul committed
255
256
        std::transform(
            m.begin(), m.end(), std::back_inserter(result), [](auto&& p) { return p.second; });
Paul's avatar
Paul committed
257
258
259
        return result;
    }

Paul's avatar
Paul committed
260
261
    std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
    find_concurrent_instructions(program& p)
Paul's avatar
Paul committed
262
    {
Paul's avatar
Paul committed
263
        std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
Paul's avatar
Paul committed
264
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
Paul's avatar
Paul committed
265
266
        result.reserve(p.size());
        merge_from.reserve(p.size());
Paul's avatar
Paul committed
267
        for(auto ins : reverse_iterator_for(p))
Paul's avatar
Paul committed
268
        {
Paul's avatar
Paul committed
269
            for(auto&& arg : ins->outputs())
Paul's avatar
Paul committed
270
            {
Paul's avatar
Paul committed
271
272
273
                if(is_merge_point(arg))
                    merge_from[ins].insert(arg);
                merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
Paul's avatar
Paul committed
274
275
            }

Paul's avatar
Paul committed
276
            auto streams = this->get_streams(ins);
Paul's avatar
Paul committed
277
278
279

            // Collect concur instructions for each merge point.
            for(auto& merge : merge_from[ins])
Paul's avatar
Paul committed
280
            {
Paul's avatar
Paul committed
281
                for(auto stream : streams)
Paul's avatar
Paul committed
282
283
284
                {
                    if(result[merge].size() <= stream)
                        result[merge].resize(stream + 1);
Paul's avatar
Paul committed
285
286
                    auto&& r = result[merge][stream];
                    r.push_back(ins);
Paul's avatar
Paul committed
287
                    // Copy inputs if they dont have a stream(and are not a builtin and context
Paul's avatar
Paul committed
288
                    // free). Inputs without a stream can have a implicit dependency
Paul's avatar
Paul committed
289
290
291
292
293
294
295
296
                    std::copy_if(ins->inputs().begin(),
                                 ins->inputs().end(),
                                 std::back_inserter(r),
                                 [&](auto x) {
                                     return not this->has_stream(x) and
                                            not is_context_free(x->get_operator()) and
                                            x->name().front() != '@';
                                 });
Paul's avatar
Paul committed
297
                }
Paul's avatar
Paul committed
298
299
            }
        }
Paul's avatar
Paul committed
300
        return result;
Paul's avatar
Paul committed
301
    }
302
303
};

Paul's avatar
Paul committed
304
305
void schedule::apply(program& p) const
{
306
    stream_info si;
Paul's avatar
Paul committed
307
308
309
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());
310

Paul's avatar
Paul committed
311
312
313
314
315
    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
    {
        p.annotate(std::cout, [&](auto ins) {
            std::cout << ":";
            std::cout << " weight=" << si.weights.at(ins);
Paul's avatar
Paul committed
316
            std::cout << " input={";
Paul's avatar
Paul committed
317
            si.get_streams_from(ins, get_inputs())([&](auto s) {
Paul's avatar
Paul committed
318
319
320
321
                std::cout << s << ",";
                return true;
            });
            std::cout << "}";
Paul's avatar
Paul committed
322
            if(si.has_stream(ins))
Paul's avatar
Paul committed
323
324
325
326
327
                std::cout << " stream=" << si.get_stream(ins);
        });
        std::cout << std::endl;
    }

328
    // Schedule instructions
Paul's avatar
Paul committed
329
    std::size_t wait_id = 0;
Paul's avatar
Paul committed
330
    std::unordered_map<instruction_ref, std::size_t> ins2wait;
Paul's avatar
Paul committed
331
332
    std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
    std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
Paul's avatar
Paul committed
333
334
    ins2wait.reserve(p.size());
    ins2waited.reserve(p.size());
Paul's avatar
Paul committed
335
    for(auto ins : iterator_for(p))
336
    {
Paul's avatar
Paul committed
337
        // Only schedule instructions that have a stream
Paul's avatar
Paul committed
338
        if(not si.has_stream(ins))
Paul's avatar
Paul committed
339
            continue;
340
        assert(si.weights[ins] > 0);
Paul's avatar
Paul committed
341
        // Schedule instruction on the stream
Paul's avatar
Paul committed
342
        auto stream = si.get_stream(ins);
Paul's avatar
Paul committed
343
        assert(stream < model.concurrency());
Paul's avatar
Paul committed
344
345
        model.sched(p, ins, stream);
        // Insert wait instructions
Paul's avatar
Paul committed
346
        if(si.is_merge_point(ins, stream))
Paul's avatar
Paul committed
347
        {
Paul's avatar
Paul committed
348
            for(auto i : si.get_recorded_instructions(ins))
Paul's avatar
Paul committed
349
            {
Paul's avatar
Paul committed
350
                if(not si.has_stream(i))
Paul's avatar
Paul committed
351
                    continue;
Paul's avatar
Paul committed
352
353
                auto istream = si.get_stream(i);
                if(stream == istream)
Paul's avatar
Paul committed
354
355
                    continue;
                // Create a new event if it hasn't been recorded
Paul's avatar
Paul committed
356
                if(not contains(ins2wait, i))
Paul's avatar
Paul committed
357
358
359
360
361
                {
                    ins2wait[i] = wait_id;
                    model.record(p, i, wait_id);
                    wait_id++;
                }
Paul's avatar
Paul committed
362
363
364
                auto w = ins2wait.at(i);
                // If we already waited for the event on this stream then dont
                // insert another wait event
Paul's avatar
Paul committed
365
                if(not contains(waited_for[stream], w))
Paul's avatar
Paul committed
366
367
368
369
370
                    model.wait(p, ins, w);
                // Store the event as waited
                waited_for[stream].insert(w);
                // Store all wait events that have been waited on prior to the recorded instruction
                waited_for[stream].insert(ins2waited[i].begin(), ins2waited[i].end());
Paul's avatar
Paul committed
371
            }
Paul's avatar
Paul committed
372
        }
Paul's avatar
Paul committed
373
374
375
376
377
        // Store wait events that have already been waited on
        if(si.is_split_point(ins, stream))
        {
            ins2waited[ins] = waited_for[stream];
        }
378
    }
Paul's avatar
Paul committed
379

Paul's avatar
Paul committed
380
381
    // Add memory conflicts
    auto concur_ins = si.find_concurrent_instructions(p);
Paul's avatar
Paul committed
382
    std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table;
Paul's avatar
Paul committed
383
    for(auto&& merge : concur_ins)
Paul's avatar
Paul committed
384
    {
Paul's avatar
Paul committed
385
        dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
Paul's avatar
Paul committed
386
            if(i == j)
Paul's avatar
Paul committed
387
                return;
Paul's avatar
Paul committed
388
            for(auto ins1 : merge.second[i])
Paul's avatar
Paul committed
389
            {
Paul's avatar
Paul committed
390
391
392
                auto p1 = std::distance(ins1, merge.first);
                for(auto ins2 : merge.second[j])
                {
Paul's avatar
Paul committed
393
                    if(ins1 == ins2)
Paul's avatar
Paul committed
394
395
396
                        continue;
                    auto p2 = std::distance(ins2, merge.first);
                    // The smaller distance means the instruction occurs later
Paul's avatar
Paul committed
397
                    if(p1 > p2)
Paul's avatar
Paul committed
398
399
400
401
                        conflict_table[ins2].insert(ins1);
                    else
                        conflict_table[ins1].insert(ins2);
                }
Paul's avatar
Paul committed
402
403
404
            }
        });
    }
Paul's avatar
Paul committed
405
    // Remove duplicates
Paul's avatar
Paul committed
406
    for(auto&& ip : conflict_table)
Paul's avatar
Paul committed
407
408
    {
        auto ins1 = ip.first;
Paul's avatar
Paul committed
409
410
        for(auto ins2 : ip.second)
            if(contains(conflict_table[ins2], ins1))
Paul's avatar
Paul committed
411
412
                conflict_table[ins2].erase(ins1);
    }
Paul's avatar
Paul committed
413
    for(auto&& ip : conflict_table)
Paul's avatar
Paul committed
414
    {
Paul's avatar
Paul committed
415
        if(ip.second.empty())
Paul's avatar
Paul committed
416
417
418
419
420
421
            continue;
        std::vector<instruction_ref> args;
        args.push_back(ip.first);
        args.insert(args.end(), ip.second.begin(), ip.second.end());
        p.insert_instruction(std::next(ip.first), op::identity{}, args);
    }
Paul's avatar
Paul committed
422
423
424
425
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx