schedule.cpp 6.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
Paul's avatar
Paul committed
9
#include <unordered_set>
10
#include <set>
Paul's avatar
Paul committed
11
12
13
14

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
15
16
17
18
19
// Returns true when `ins` is treated as not needing a stream: either its
// operator is context free (as reported by is_context_free), or the operator
// name starts with '@'.
// NOTE(review): '@'-prefixed names are presumably builtin instructions such as
// parameters/literals - confirm against the operator definitions.
bool stream_free(instruction_ref ins)
{
    const auto& op = ins->get_operator();
    if(is_context_free(op))
        return true;
    return op.name().front() == '@';
}

20
21
22
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
                                    model.weight(ins->get_operator()),
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
46
                // If weight is zero then stop
Paul's avatar
Paul committed
47
                if(weights[ins] == 0)
Paul's avatar
Paul committed
48
49
                    return;
                // Only assign streams if not already assigned
Paul's avatar
Paul committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
                if(not has_stream(ins))
                    set_stream(ins, stream);
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
                    const auto weight = weights[i];
                    // Skip instruction that already have stream assignment or too low of weights
                    if(has_stream(i) or weight <= min_partition_threshold)
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
Paul's avatar
Paul committed
78
            if(weights[ins] == 0)
Paul's avatar
Paul committed
79
                continue;
Paul's avatar
Paul committed
80
81
82
            set_stream(ins, streams - 1);
        }
    }
83

Paul's avatar
Paul committed
84
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
85

Paul's avatar
Paul committed
86
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
87

Paul's avatar
Paul committed
88
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
89

Paul's avatar
Paul committed
90
    bool different(const std::vector<std::size_t>& v) const
91
    {
Paul's avatar
Paul committed
92
        if(v.size() < 2)
93
            return false;
Paul's avatar
Paul committed
94
        return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
95
96
    }

Paul's avatar
Paul committed
97
98
    template<class Selector>
    std::vector<std::size_t> get_streams(instruction_ref ins, Selector select) const
Paul's avatar
Paul committed
99
100
    {
        std::vector<std::size_t> result;
Paul's avatar
Paul committed
101
        for(auto i : select(ins))
Paul's avatar
Paul committed
102
        {
Paul's avatar
Paul committed
103
            if(weights.at(i) == 0)
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
112
113
114
            {
                auto vv = get_input_streams(i);
                result.insert(result.end(), vv.begin(), vv.end());
            }
            else
            {
                result.emplace_back(get_stream(i));
            }
        }
        return result;
    }
115

Paul's avatar
Paul committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    std::vector<std::size_t> get_input_streams(instruction_ref ins) const
    {
        return get_streams(ins, [](auto i) {
            return i->inputs();
        });
    }

    std::vector<std::size_t> get_output_streams(instruction_ref ins) const
    {
        return get_streams(ins, [](auto i) {
            return i->outputs();
        });
    }

Paul's avatar
Paul committed
130
    bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
Paul's avatar
Paul committed
131
132
    
    bool is_split_point(instruction_ref ins) const { return different(get_output_streams(ins)); }
133
134
135

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
Paul's avatar
Paul committed
136
        std::vector<std::size_t> result = get_input_streams(ins);
Paul's avatar
Paul committed
137
138
        std::sort(result.begin(), result.end());
        result.erase(std::unique(result.begin(), result.end()), result.end());
Paul's avatar
Paul committed
139
        return result;
140
    }
Paul's avatar
Paul committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163

    template<class F>
    void find_concurrent_instructions(program& p, F f)
    {
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
        for(auto ins : iterator_for(p))
        {
            if (weights[ins] == 0)
                continue;
            for(auto&& arg : ins->inputs())
            {
                if (is_split_point(arg))
                    split_from[ins].insert(arg);
                split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
            }

            // Collect concur instructions for each split point.
            for(auto& split : split_from[ins])
            {
                f(ins, split);
            }
        }
    }
164
165
};

Paul's avatar
Paul committed
166
167
// Runs the schedule pass over `p`:
//   1. accumulate weights starting from the last instruction,
//   2. assign streams using model.concurrency() as the stream count,
//   3. reorder the program by repeatedly moving inputs to the front,
//   4. for every instruction that has a stream, call model.wait() when it is a
//      merge point and model.schedule_instruction() otherwise,
//   5. insert an op::identity after each instruction for every split point it
//      descends from, taking the instruction and the split point as inputs.
void schedule::apply(program& p) const
{
    // An empty program has no last instruction; std::prev(p.end()) would be invalid.
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort
    // All inputs are moved to the front before any of them is visited; the
    // two loops must stay separate because the order of moves matters.
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
            p.move_instruction(i, p.begin());
        for(auto i : ins->inputs())
            self(i);
    })(last);

    // Schedule instructions
    for(auto ins : iterator_for(p))
    {
        // Only schedule instructions that have a stream
        if(not si.has_stream(ins))
            continue;
        if(si.is_merge_point(ins))
            model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
        else
            model.schedule_instruction(p, ins, si.get_stream(ins));
    }

    // The identity is inserted right after `x`, with `x` and the split point
    // `y` as its inputs.
    si.find_concurrent_instructions(p, [&](auto x, auto y) {
        p.insert_instruction(std::next(x), op::identity{}, x, y);
    });
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx