// schedule.cpp
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <set>
#include <unordered_map>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Decide whether `ins` can be left off every stream.
///
/// Returns true when the instruction's operator is context free, or when the
/// operator's name begins with '@' (presumably the builtin/pseudo operators
/// such as parameters and literals — NOTE(review): confirm).
/// Assumes the operator name is non-empty, since `front()` is called on it.
bool stream_free(instruction_ref ins)
{
    const auto& op = ins->get_operator();
    if(is_context_free(op))
        return true;
    return op.name().front() == '@';
}

struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
                                    model.weight(ins->get_operator()),
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
                // Only assign streams fi not already assigned
                if(not has_stream(ins))
                    set_stream(ins, stream);
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
                    const auto weight = weights[i];
                    // Skip instruction that already have stream assignment or too low of weights
                    if(has_stream(i) or weight <= min_partition_threshold)
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
            set_stream(ins, streams - 1);
        }
    }
76

Paul's avatar
Paul committed
77
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
78

Paul's avatar
Paul committed
79
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
80

Paul's avatar
Paul committed
81
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
82
83
84

    bool different(const std::vector<instruction_ref>& v) const
    {
Paul's avatar
Paul committed
85
        if(v.size() < 2)
86
87
            return false;
        auto stream = get_stream(v.front());
Paul's avatar
Paul committed
88
89
        return not std::all_of(
            v.begin(), v.end(), [&](instruction_ref x) { return get_stream(x) == stream; });
90
91
    }

Paul's avatar
Paul committed
92
    bool is_split_point(instruction_ref ins) const { return different(ins->outputs()); }
93

Paul's avatar
Paul committed
94
    bool is_merge_point(instruction_ref ins) const { return different(ins->inputs()); }
95
96
97
98
99

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
        std::set<std::size_t> result;
        auto s = get_stream(ins);
Paul's avatar
Paul committed
100
        for(auto i : ins->inputs())
101
102
        {
            auto stream = get_stream(i);
Paul's avatar
Paul committed
103
            if(stream != s)
104
105
106
107
108
109
                result.insert(stream);
        }
        return {result.begin(), result.end()};
    }
};

/// Run the scheduling pass over `p`.
///
///  1. Accumulate instruction weights, starting from the last instruction.
///  2. Assign every instruction to one of `model.concurrency()` streams.
///  3. Reorder: for each instruction reachable from the last one, move its
///     inputs to the front of the program, then recurse into those inputs.
///  4. Walk the program in order:
///     - at merge points (inputs on different streams), ask the model to
///       insert a wait;
///     - otherwise, schedule every instruction that needs a context on its
///       assigned stream.
///
/// An empty program is left untouched.
void schedule::apply(program& p) const
{
    // std::prev(p.end()) below is undefined on an empty program.
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort
    // NOTE(review): shared inputs are moved and re-walked once per consumer
    // path, which can be expensive on large DAGs.
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
            p.move_instruction(i, p.begin());
        for(auto i : ins->inputs())
            self(i);
    })(last);

    // Schedule instructions
    for(auto ins : iterator_for(p))
    {
        if(si.is_merge_point(ins))
        {
            // Compute once; it is needed by both the check and the call.
            const auto streams_to_wait = si.wait_for(ins);
            assert(not streams_to_wait.empty());
            model.wait(p, ins, si.get_stream(ins), streams_to_wait);
            // NOTE(review): merge points are never passed to
            // schedule_instruction; confirm model.wait is meant to cover them.
            continue;
        }
        // Skip scheduling instructions with no context
        if(stream_free(ins))
            continue;
        model.schedule_instruction(p, ins, si.get_stream(ins));
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx