schedule.cpp 4.82 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <set>
#include <unordered_map>
#include <vector>
Paul's avatar
Paul committed
9
10
11
12

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
13
14
15
16
17
// Returns true when the operator of `ins` is context free (as reported by
// is_context_free) or when its name begins with '@'.
//
// NOTE(review): presumably this identifies instructions that do not need to
// be placed on a stream; it is not referenced in the visible part of this
// file, so confirm against its callers.
bool stream_free(instruction_ref ins)
{
    if(is_context_free(ins->get_operator()))
        return true;
    // front() on an empty string is undefined behaviour, so check for
    // emptiness before looking at the first character.
    const auto name = ins->get_operator().name();
    return not name.empty() and name.front() == '@';
}

18
19
20
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
                                    model.weight(ins->get_operator()),
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
44
                // If weight is zero then stop
Paul's avatar
Paul committed
45
                if(weights[ins] == 0)
Paul's avatar
Paul committed
46
47
                    return;
                // Only assign streams if not already assigned
Paul's avatar
Paul committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
                if(not has_stream(ins))
                    set_stream(ins, stream);
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
                    const auto weight = weights[i];
                    // Skip instruction that already have stream assignment or too low of weights
                    if(has_stream(i) or weight <= min_partition_threshold)
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
Paul's avatar
Paul committed
76
            if(weights[ins] == 0)
Paul's avatar
Paul committed
77
                continue;
Paul's avatar
Paul committed
78
79
80
            set_stream(ins, streams - 1);
        }
    }
81

Paul's avatar
Paul committed
82
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
83

Paul's avatar
Paul committed
84
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
85

Paul's avatar
Paul committed
86
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
87

Paul's avatar
Paul committed
88
    bool different(const std::vector<std::size_t>& v) const
89
    {
Paul's avatar
Paul committed
90
        if(v.size() < 2)
91
            return false;
Paul's avatar
Paul committed
92
        return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
93
94
    }

Paul's avatar
Paul committed
95
96
97
    std::vector<std::size_t> get_input_streams(instruction_ref ins) const
    {
        std::vector<std::size_t> result;
Paul's avatar
Paul committed
98
        for(auto i : ins->inputs())
Paul's avatar
Paul committed
99
        {
Paul's avatar
Paul committed
100
            if(weights.at(i) == 0)
Paul's avatar
Paul committed
101
102
103
104
105
106
107
108
109
110
111
            {
                auto vv = get_input_streams(i);
                result.insert(result.end(), vv.begin(), vv.end());
            }
            else
            {
                result.emplace_back(get_stream(i));
            }
        }
        return result;
    }
112

Paul's avatar
Paul committed
113
    bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
114
115
116

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
Paul's avatar
Paul committed
117
        std::vector<std::size_t> result = get_input_streams(ins);
Paul's avatar
Paul committed
118
119
        std::sort(result.begin(), result.end());
        result.erase(std::unique(result.begin(), result.end()), result.end());
Paul's avatar
Paul committed
120
        return result;
121
122
123
    }
};

Paul's avatar
Paul committed
124
125
// Partition the instructions of `p` across model.concurrency() streams and
// hand every instruction that received a stream to the schedule model.
//
// Steps:
//  1. Weigh every instruction reachable from the last instruction.
//  2. Assign streams based on those weights.
//  3. Reorder the program by moving inputs in front of their users.
//  4. For each instruction with a stream, call model.wait() when its inputs
//     come from more than one stream, otherwise model.schedule_instruction().
//
// An empty program is left untouched.
void schedule::apply(program& p) const
{
    // std::prev(p.end()) is undefined when the program has no instructions
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort: move all inputs of an instruction to the front of the
    // program first, then repeat for each input in turn.
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
            p.move_instruction(i, p.begin());
        for(auto i : ins->inputs())
            self(i);
    })(last);

    // Schedule instructions
    for(auto ins : iterator_for(p))
    {
        // Only schedule instructions that have a stream
        if(not si.has_stream(ins))
            continue;
        if(si.is_merge_point(ins))
            model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
        else
            model.schedule_instruction(p, ins, si.get_stream(ins));
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx