schedule.cpp 9.04 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
6
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
7
8
9
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
Paul's avatar
Paul committed
10
#include <unordered_set>
11
#include <set>
Paul's avatar
Paul committed
12
13
14
15

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
16
17
18
19
20
// Returns true when `ins` does not need a stream: either is_context_free()
// reports its operator as context free, or the operator's name starts with '@'.
//
// An empty operator name is treated as "does not start with '@'" instead of
// calling front() on an empty string, which would be undefined behaviour.
bool stream_free(instruction_ref ins)
{
    auto&& op = ins->get_operator();
    if(is_context_free(op))
        return true;
    const auto name = op.name();
    return not name.empty() and name.front() == '@';
}

Paul's avatar
Paul committed
21
22
23
24
25
26
27
28
29
30
// Builds a selector: given an instruction handle (anything supporting `->`),
// it returns by value whatever the handle's inputs() member yields.
auto get_inputs()
{
    const auto select_inputs = [](auto ins) {
        auto args = ins->inputs();
        return args;
    };
    return select_inputs;
}

// Builds a selector: given an instruction handle (anything supporting `->`),
// it returns by value whatever the handle's outputs() member yields.
auto get_outputs()
{
    const auto select_outputs = [](auto ins) {
        auto users = ins->outputs();
        return users;
    };
    return select_outputs;
}

31
32
33
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
34
35
36
37
38
39
40
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
Paul's avatar
Paul committed
41
                std::size_t weight = 0;
Paul's avatar
Paul committed
42
                auto&& op          = ins->get_operator();
Paul's avatar
Paul committed
43
44
                if(not is_context_free(op) and op.name()[0] != '@')
                    weight = model.weight(op);
Paul's avatar
Paul committed
45
46
47
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
Paul's avatar
Paul committed
48
                                    weight,
Paul's avatar
Paul committed
49
50
51
52
53
54
55
56
57
58
59
60
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
61
                // If weight is zero then stop
Paul's avatar
Paul committed
62
                if(this->weights[ins] == 0)
Paul's avatar
Paul committed
63
64
                    return;
                // Only assign streams if not already assigned
Paul's avatar
Paul committed
65
66
                if(not this->has_stream(ins))
                    this->set_stream(ins, stream);
Paul's avatar
Paul committed
67
68
69
70
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
Paul's avatar
Paul committed
71
                    const auto weight = this->weights[i];
Paul's avatar
Paul committed
72
                    // Skip instruction that already have stream assignment or too low of weights
Paul's avatar
Paul committed
73
                    if(this->has_stream(i) or weight <= min_partition_threshold)
Paul's avatar
Paul committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
Paul's avatar
Paul committed
93
            if(weights[ins] == 0)
Paul's avatar
Paul committed
94
                continue;
Paul's avatar
Paul committed
95
96
97
            set_stream(ins, streams - 1);
        }
    }
98

Paul's avatar
Paul committed
99
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
100

Paul's avatar
Paul committed
101
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
102

Paul's avatar
Paul committed
103
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
104

Paul's avatar
Paul committed
105
    bool different(const std::vector<std::size_t>& v) const
106
    {
Paul's avatar
Paul committed
107
        if(v.size() < 2)
108
            return false;
Paul's avatar
Paul committed
109
        return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
110
111
    }

Paul's avatar
Paul committed
112
    template <class F>
Paul's avatar
Paul committed
113
    bool different(F f, std::size_t stream) const
Paul's avatar
Paul committed
114
    {
Paul's avatar
Paul committed
115
        bool result = false;
Paul's avatar
Paul committed
116
        f([&](auto s) {
Paul's avatar
Paul committed
117
            if(s != stream)
Paul's avatar
Paul committed
118
            {
Paul's avatar
Paul committed
119
120
                result = true;
                return false;
Paul's avatar
Paul committed
121
            }
Paul's avatar
Paul committed
122
123
124
            stream = s;
            return true;
        });
Paul's avatar
Paul committed
125
126
        return result;
    }
127

Paul's avatar
Paul committed
128
129
130
131
132
133
134
135
136
137
138
    template <class F>
    bool different(F f) const
    {
        bool result = false;
        f([&](auto s) {
            result = different(f, s);
            return false;
        });
        return result;
    }

Paul's avatar
Paul committed
139
140
141
142
143
144
145
146
147
    template <class Selector>
    auto get_streams(instruction_ref start, Selector select) const
    {
        return [=](auto f) {
            return fix<bool>([&](auto self, auto ins) {
                for(auto i : select(ins))
                {
                    if(weights.at(i) == 0)
                    {
Paul's avatar
Paul committed
148
                        if(not self(i))
Paul's avatar
Paul committed
149
150
151
152
                            return false;
                    }
                    else
                    {
Paul's avatar
Paul committed
153
                        if(not f(get_stream(i)))
Paul's avatar
Paul committed
154
155
156
157
158
159
160
161
                            return false;
                    }
                }
                return true;
            })(start);
        };
    }

Paul's avatar
Paul committed
162
163
    template<class... Ts>
    bool is_merge_point(instruction_ref ins, Ts... xs) const { return different(get_streams(ins, get_inputs()), xs...); }
Paul's avatar
Paul committed
164

Paul's avatar
Paul committed
165
166
    template<class... Ts>
    bool is_split_point(instruction_ref ins, Ts... xs) const { return different(get_streams(ins, get_outputs()), xs...); }
167
168
169

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
Paul's avatar
Paul committed
170
        std::vector<std::size_t> result;
Paul's avatar
Paul committed
171
        get_streams(ins, get_inputs())([&](auto s) {
Paul's avatar
Paul committed
172
173
174
175
            result.push_back(s);
            return true;
        });
        // Remove duplicates
Paul's avatar
Paul committed
176
177
        std::sort(result.begin(), result.end());
        result.erase(std::unique(result.begin(), result.end()), result.end());
Paul's avatar
Paul committed
178
        // Remove the merged stream
Paul's avatar
Paul committed
179
180
181
        auto it = std::find(result.begin(), result.end(), get_stream(ins));
        if (it != result.end())
            result.erase(it);
Paul's avatar
Paul committed
182
        return result;
183
    }
Paul's avatar
Paul committed
184

Paul's avatar
Paul committed
185
186
    std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
    find_concurrent_instructions(program& p)
Paul's avatar
Paul committed
187
    {
Paul's avatar
Paul committed
188
        std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
Paul's avatar
Paul committed
189
190
191
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
        for(auto ins : iterator_for(p))
        {
Paul's avatar
Paul committed
192
            if(weights[ins] == 0)
Paul's avatar
Paul committed
193
194
195
                continue;
            for(auto&& arg : ins->inputs())
            {
Paul's avatar
Paul committed
196
                if(is_split_point(arg))
Paul's avatar
Paul committed
197
198
199
200
                    split_from[ins].insert(arg);
                split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
            }

Paul's avatar
Paul committed
201
            auto stream = get_stream(ins);
Paul's avatar
Paul committed
202
203
204
205
206
207
208
209
210
211
            // if (is_merge_point(ins))
            // {
            //     // post-dominator kills split point.
            //     for(auto& split : split_from[ins])
            //     {
            //         if(strictly_post_dominates(ins, split))
            //             split_from[ins].erase(split);
            //     }
            // }

Paul's avatar
Paul committed
212
213
214
            // Collect concur instructions for each split point.
            for(auto& split : split_from[ins])
            {
Paul's avatar
Paul committed
215
216
                if(result[split].size() <= stream)
                    result[split].resize(stream + 1);
Paul's avatar
Paul committed
217
                result[split][stream].push_back(ins);
Paul's avatar
Paul committed
218
219
            }
        }
Paul's avatar
Paul committed
220
        return result;
Paul's avatar
Paul committed
221
    }
222
223
};

Paul's avatar
Paul committed
224
225
// Runs the scheduling pass over `p`:
//  1. accumulates instruction weights and assigns streams (stream_info),
//  2. reorders the program by repeatedly moving each instruction's inputs to
//     the front, lightest input first,
//  3. optionally prints weights/streams when MIGRAPHX_TRACE_COMPILE is enabled,
//  4. asks the model to insert a wait at merge points and to schedule every
//     other instruction that has a stream,
//  5. inserts op::identity instructions tying together instructions that
//     follow the same split point on different streams.
//
// An empty program is left untouched.
void schedule::apply(program& p) const
{
    // There is no last instruction to start from; std::prev(p.end()) would be undefined.
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort
    fix([&](auto self, auto ins) {
        auto args = ins->inputs();
        std::sort(args.begin(), args.end(), [&](auto x, auto y) {
            return si.weights[x] < si.weights[y];
        });
        // Move all inputs first, then recurse, so deeper inputs end up earlier.
        for(auto i : args)
            p.move_instruction(i, p.begin());
        for(auto i : args)
            self(i);
    })(last);

    if(enabled(MIGRAPHX_TRACE_COMPILE{}))
    {
        p.annotate(std::cout, [&](auto ins) {
            std::cout << ":";
            std::cout << " weight=" << si.weights.at(ins);
            if(si.has_stream(ins))
                std::cout << " stream=" << si.get_stream(ins);
        });
        std::cout << std::endl;
    }

    // Schedule instructions
    for(auto ins : iterator_for(p))
    {
        // Only schedule instructions that have a stream
        if(not si.has_stream(ins))
            continue;
        auto stream = si.get_stream(ins);
        // An instruction fed from another stream must wait for those streams.
        if(si.is_merge_point(ins, stream))
            model.wait(p, ins, stream, si.wait_for(ins));
        else
            model.schedule_instruction(p, ins, stream);
    }

    // Add memory conflicts
    auto concur_ins = si.find_concurrent_instructions(p);
    for(auto&& split : concur_ins)
    {
        // Every ordered pair of distinct streams below this split point.
        dfor(split.second.size(), split.second.size())([&](auto i, auto j) {
            if(i == j)
                return;
            for(auto ins1 : split.second[i])
            {
                // ins1 followed by everything on the other stream.
                auto args = split.second[j];
                args.insert(args.begin(), ins1);

                // The argument furthest from the split point in program order.
                auto point = std::max_element(args.begin(), args.end(), [&](auto x, auto y) {
                    return std::distance(split.first, x) < std::distance(split.first, y);
                });
                p.insert_instruction(std::next(*point), op::identity{}, args);
            }
        });
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx