schedule.cpp 6.15 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
#include <migraphx/operators.hpp>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
Paul's avatar
Paul committed
9
#include <unordered_set>
10
#include <set>
Paul's avatar
Paul committed
11
12
13
14

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
15
16
17
18
19
// Returns true when the instruction's operator is context free, or when the
// operator name begins with '@'.
// NOTE(review): '@'-prefixed names presumably denote builtin instructions
// (parameters, literals, ...) -- confirm against the builtin operator list.
bool stream_free(instruction_ref ins)
{
    const auto& op = ins->get_operator();
    if(is_context_free(op))
        return true;
    return op.name().front() == '@';
}

20
21
22
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
                                    model.weight(ins->get_operator()),
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
46
                // If weight is zero then stop
Paul's avatar
Paul committed
47
                if(weights[ins] == 0)
Paul's avatar
Paul committed
48
49
                    return;
                // Only assign streams if not already assigned
Paul's avatar
Paul committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
                if(not has_stream(ins))
                    set_stream(ins, stream);
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
                    const auto weight = weights[i];
                    // Skip instruction that already have stream assignment or too low of weights
                    if(has_stream(i) or weight <= min_partition_threshold)
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
Paul's avatar
Paul committed
78
            if(weights[ins] == 0)
Paul's avatar
Paul committed
79
                continue;
Paul's avatar
Paul committed
80
81
82
            set_stream(ins, streams - 1);
        }
    }
83

Paul's avatar
Paul committed
84
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
85

Paul's avatar
Paul committed
86
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
87

Paul's avatar
Paul committed
88
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
89

Paul's avatar
Paul committed
90
    bool different(const std::vector<std::size_t>& v) const
91
    {
Paul's avatar
Paul committed
92
        if(v.size() < 2)
93
            return false;
Paul's avatar
Paul committed
94
        return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
95
96
    }

Paul's avatar
Paul committed
97
    template <class Selector>
Paul's avatar
Paul committed
98
    std::vector<std::size_t> get_streams(instruction_ref ins, Selector select) const
Paul's avatar
Paul committed
99
100
    {
        std::vector<std::size_t> result;
Paul's avatar
Paul committed
101
        for(auto i : select(ins))
Paul's avatar
Paul committed
102
        {
Paul's avatar
Paul committed
103
            if(weights.at(i) == 0)
Paul's avatar
Paul committed
104
105
106
107
108
109
110
111
112
113
114
            {
                auto vv = get_input_streams(i);
                result.insert(result.end(), vv.begin(), vv.end());
            }
            else
            {
                result.emplace_back(get_stream(i));
            }
        }
        return result;
    }
115

Paul's avatar
Paul committed
116
117
    std::vector<std::size_t> get_input_streams(instruction_ref ins) const
    {
Paul's avatar
Paul committed
118
        return get_streams(ins, [](auto i) { return i->inputs(); });
Paul's avatar
Paul committed
119
120
121
122
    }

    std::vector<std::size_t> get_output_streams(instruction_ref ins) const
    {
Paul's avatar
Paul committed
123
        return get_streams(ins, [](auto i) { return i->outputs(); });
Paul's avatar
Paul committed
124
125
    }

Paul's avatar
Paul committed
126
    bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
Paul's avatar
Paul committed
127

Paul's avatar
Paul committed
128
    bool is_split_point(instruction_ref ins) const { return different(get_output_streams(ins)); }
129
130
131

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
Paul's avatar
Paul committed
132
        std::vector<std::size_t> result = get_input_streams(ins);
Paul's avatar
Paul committed
133
134
        std::sort(result.begin(), result.end());
        result.erase(std::unique(result.begin(), result.end()), result.end());
Paul's avatar
Paul committed
135
        return result;
136
    }
Paul's avatar
Paul committed
137

Paul's avatar
Paul committed
138
    template <class F>
Paul's avatar
Paul committed
139
140
141
142
143
    void find_concurrent_instructions(program& p, F f)
    {
        std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
        for(auto ins : iterator_for(p))
        {
Paul's avatar
Paul committed
144
            if(weights[ins] == 0)
Paul's avatar
Paul committed
145
146
147
                continue;
            for(auto&& arg : ins->inputs())
            {
Paul's avatar
Paul committed
148
                if(is_split_point(arg))
Paul's avatar
Paul committed
149
150
151
152
153
154
155
156
157
158
159
                    split_from[ins].insert(arg);
                split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
            }

            // Collect concur instructions for each split point.
            for(auto& split : split_from[ins])
            {
                f(ins, split);
            }
        }
    }
160
161
};

Paul's avatar
Paul committed
162
163
// Runs the scheduling pass over `p`:
//   1. accumulate weights from the last instruction and assign streams,
//   2. reorder the instructions by repeatedly moving inputs to the front,
//   3. for every instruction with a stream, call model.wait() at merge points
//      and model.schedule_instruction() otherwise,
//   4. insert an identity instruction after each instruction for every split
//      point it descends from.
// An empty program is left untouched.
void schedule::apply(program& p) const
{
    // std::prev(p.end()) is undefined on an empty program, so bail out early.
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort
    // NOTE(review): there is no visited set, so an instruction reachable along
    // several paths is moved and revisited once per path.
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
            p.move_instruction(i, p.begin());
        for(auto i : ins->inputs())
            self(i);
    })(last);

    // Schedule instructions
    for(auto ins : iterator_for(p))
    {
        // Only schedule instructions that have a stream
        if(not si.has_stream(ins))
            continue;
        if(si.is_merge_point(ins))
            model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
        else
            model.schedule_instruction(p, ins, si.get_stream(ins));
    }

    // Tie each instruction to the split points it descends from by inserting
    // an identity instruction that takes both as inputs.
    si.find_concurrent_instructions(
        p, [&](auto x, auto y) { p.insert_instruction(std::next(x), op::identity{}, x, y); });
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx