schedule.cpp 8.72 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
7
8
9
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
Paul's avatar
Paul committed
10
#include <unordered_set>
11
#include <set>
Paul's avatar
Paul committed
12
13
14
15

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
16
17
18
19
20
/// Returns true when `ins` does not need a stream of its own: either its
/// operator is context free, or the operator's name starts with '@'
/// (presumably a builtin such as a literal or parameter — confirm).
/// An operator with an empty name is treated as not '@'-prefixed instead of
/// calling `front()` on an empty string (undefined behavior).
bool stream_free(instruction_ref ins)
{
    const auto& op = ins->get_operator();
    if(is_context_free(op))
        return true;
    const auto name = op.name();
    return not name.empty() and name.front() == '@';
}

21
22
23
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
24
25
26
27
28
29
30
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
Paul's avatar
Paul committed
31
                std::size_t weight = 0;
Paul's avatar
Paul committed
32
                auto&& op          = ins->get_operator();
Paul's avatar
Paul committed
33
34
                if(not is_context_free(op) and op.name()[0] != '@')
                    weight = model.weight(op);
Paul's avatar
Paul committed
35
36
37
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
Paul's avatar
Paul committed
38
                                    weight,
Paul's avatar
Paul committed
39
40
41
42
43
44
45
46
47
48
49
50
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
51
                // If weight is zero then stop
Paul's avatar
Paul committed
52
                if(this->weights[ins] == 0)
Paul's avatar
Paul committed
53
54
                    return;
                // Only assign streams if not already assigned
Paul's avatar
Paul committed
55
56
                if(not this->has_stream(ins))
                    this->set_stream(ins, stream);
Paul's avatar
Paul committed
57
58
59
60
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
Paul's avatar
Paul committed
61
                    const auto weight = this->weights[i];
Paul's avatar
Paul committed
62
                    // Skip instruction that already have stream assignment or too low of weights
Paul's avatar
Paul committed
63
                    if(this->has_stream(i) or weight <= min_partition_threshold)
Paul's avatar
Paul committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
Paul's avatar
Paul committed
83
            if(weights[ins] == 0)
Paul's avatar
Paul committed
84
                continue;
Paul's avatar
Paul committed
85
86
87
            set_stream(ins, streams - 1);
        }
    }
88

Paul's avatar
Paul committed
89
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
90

Paul's avatar
Paul committed
91
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
92

Paul's avatar
Paul committed
93
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
94

Paul's avatar
Paul committed
95
    bool different(const std::vector<std::size_t>& v) const
96
    {
Paul's avatar
Paul committed
97
        if(v.size() < 2)
98
            return false;
Paul's avatar
Paul committed
99
        return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
100
101
    }

Paul's avatar
Paul committed
102
    template <class F>
Paul's avatar
Paul committed
103
    bool different(F f) const
Paul's avatar
Paul committed
104
    {
Paul's avatar
Paul committed
105
        bool first         = true;
Paul's avatar
Paul committed
106
        std::size_t stream = 0;
Paul's avatar
Paul committed
107
        bool result        = false;
Paul's avatar
Paul committed
108
        f([&](auto s) {
Paul's avatar
Paul committed
109
            if(not first and s != stream)
Paul's avatar
Paul committed
110
            {
Paul's avatar
Paul committed
111
112
                result = true;
                return false;
Paul's avatar
Paul committed
113
            }
Paul's avatar
Paul committed
114
            stream = s;
Paul's avatar
Paul committed
115
            first  = false;
Paul's avatar
Paul committed
116
117
            return true;
        });
Paul's avatar
Paul committed
118
119
        return result;
    }
120

Paul's avatar
Paul committed
121
122
123
124
125
126
127
128
129
    template <class Selector>
    auto get_streams(instruction_ref start, Selector select) const
    {
        return [=](auto f) {
            return fix<bool>([&](auto self, auto ins) {
                for(auto i : select(ins))
                {
                    if(weights.at(i) == 0)
                    {
Paul's avatar
Paul committed
130
                        if(not self(i))
Paul's avatar
Paul committed
131
132
133
134
                            return false;
                    }
                    else
                    {
Paul's avatar
Paul committed
135
                        if(not f(get_stream(i)))
Paul's avatar
Paul committed
136
137
138
139
140
141
142
143
144
                            return false;
                    }
                }
                return true;
            })(start);
        };
    }

    auto get_input_streams(instruction_ref ins) const
Paul's avatar
Paul committed
145
    {
Paul's avatar
Paul committed
146
        return get_streams(ins, [](auto i) { return i->inputs(); });
Paul's avatar
Paul committed
147
148
    }

Paul's avatar
Paul committed
149
    auto get_output_streams(instruction_ref ins) const
Paul's avatar
Paul committed
150
    {
Paul's avatar
Paul committed
151
        return get_streams(ins, [](auto i) { return i->outputs(); });
Paul's avatar
Paul committed
152
153
    }

Paul's avatar
Paul committed
154
    bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
Paul's avatar
Paul committed
155

Paul's avatar
Paul committed
156
    bool is_split_point(instruction_ref ins) const { return different(get_output_streams(ins)); }
157
158
159

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
Paul's avatar
Paul committed
160
161
162
163
164
165
        std::vector<std::size_t> result;
        get_input_streams(ins)([&](auto s) {
            result.push_back(s);
            return true;
        });
        // Remove duplicates
Paul's avatar
Paul committed
166
167
        std::sort(result.begin(), result.end());
        result.erase(std::unique(result.begin(), result.end()), result.end());
Paul's avatar
Paul committed
168
169
        // Remove the merged stream
        result.erase(std::find(result.begin(), result.end(), get_stream(ins)));
Paul's avatar
Paul committed
170
        return result;
171
    }
Paul's avatar
Paul committed
172

Paul's avatar
Paul committed
173
174
    std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
    find_concurrent_instructions(program& p)
Paul's avatar
Paul committed
175
    {
Paul's avatar
Paul committed
176
        std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
Paul's avatar
Paul committed
177
178
179
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
        for(auto ins : iterator_for(p))
        {
Paul's avatar
Paul committed
180
            if(weights[ins] == 0)
Paul's avatar
Paul committed
181
182
183
                continue;
            for(auto&& arg : ins->inputs())
            {
Paul's avatar
Paul committed
184
                if(is_split_point(arg))
Paul's avatar
Paul committed
185
186
187
188
                    split_from[ins].insert(arg);
                split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
            }

Paul's avatar
Paul committed
189
190
191
192
193
194
195
196
197
198
            // if (is_merge_point(ins))
            // {
            //     // post-dominator kills split point.
            //     for(auto& split : split_from[ins])
            //     {
            //         if(strictly_post_dominates(ins, split))
            //             split_from[ins].erase(split);
            //     }
            // }

Paul's avatar
Paul committed
199
200
201
            // Collect concur instructions for each split point.
            for(auto& split : split_from[ins])
            {
Paul's avatar
Paul committed
202
                auto stream = get_stream(ins);
Paul's avatar
Paul committed
203
204
                if(result[split].size() <= stream)
                    result[split].resize(stream + 1);
Paul's avatar
Paul committed
205
                result[split][stream].push_back(ins);
Paul's avatar
Paul committed
206
207
            }
        }
Paul's avatar
Paul committed
208
        return result;
Paul's avatar
Paul committed
209
    }
210
211
};

Paul's avatar
Paul committed
212
213
/// Runs the scheduling pass over `p`:
///  1. computes accumulated weights and assigns streams (stream_info),
///  2. reorders instructions by walking back from the last instruction,
///  3. asks the model to schedule each streamed instruction (or to wait on
///     other streams at merge points),
///  4. inserts identity instructions tying together instructions that run
///     concurrently on different streams after the same split point.
/// An empty program is left untouched.
void schedule::apply(program& p) const
{
    // Nothing to schedule; also keeps std::prev(p.end()) below well-defined.
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort: walk back from the last instruction, moving each
    // instruction's inputs to the front of the program before recursing.
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
            p.move_instruction(i, p.begin());
        for(auto i : ins->inputs())
            self(i);
    })(last);

    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
    {
        p.annotate(std::cout, [&](auto ins) {
            std::cout << ":";
            std::cout << " weight=" << si.weights.at(ins);
            if(si.has_stream(ins))
                std::cout << " stream=" << si.get_stream(ins);
        });
        std::cout << std::endl;
    }

    // Schedule instructions. Only instructions that have a stream are
    // scheduled; merge points additionally wait on the other input streams.
    for(auto ins : iterator_for(p))
    {
        if(not si.has_stream(ins))
            continue;
        if(si.is_merge_point(ins))
            model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
        else
            model.schedule_instruction(p, ins, si.get_stream(ins));
    }

    // Add memory conflicts: for every split point and every ordered pair of
    // distinct streams (i, j), each instruction on stream i gets an identity
    // whose arguments are it plus all of stream j's instructions, inserted
    // right after whichever argument is furthest from the split point.
    // NOTE(review): presumably this makes later memory passes treat them as
    // simultaneously live — confirm.
    auto concur_ins = si.find_concurrent_instructions(p);
    for(auto&& split : concur_ins)
    {
        dfor(split.second.size(), split.second.size())([&](auto i, auto j) {
            if(i == j)
                return;
            for(auto ins1 : split.second[i])
            {
                auto args = split.second[j];
                args.insert(args.begin(), ins1);

                auto point = std::max_element(args.begin(), args.end(), [&](auto x, auto y) {
                    return std::distance(split.first, x) < std::distance(split.first, y);
                });
                p.insert_instruction(std::next(*point), op::identity{}, args);
            }
        });
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx