// schedule.cpp
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <set>
#include <unordered_map>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

13
14
15
16
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;

Paul's avatar
Paul committed
17
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
18

Paul's avatar
Paul committed
19
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
20

Paul's avatar
Paul committed
21
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
22
23
24

    bool different(const std::vector<instruction_ref>& v) const
    {
Paul's avatar
Paul committed
25
        if(v.size() < 2)
26
27
            return false;
        auto stream = get_stream(v.front());
Paul's avatar
Paul committed
28
29
        return not std::all_of(
            v.begin(), v.end(), [&](instruction_ref x) { return get_stream(x) == stream; });
30
31
    }

Paul's avatar
Paul committed
32
    bool is_split_point(instruction_ref ins) const { return different(ins->outputs()); }
33

Paul's avatar
Paul committed
34
    bool is_merge_point(instruction_ref ins) const { return different(ins->inputs()); }
35
36
37
38
39

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
        std::set<std::size_t> result;
        auto s = get_stream(ins);
Paul's avatar
Paul committed
40
        for(auto i : ins->inputs())
41
42
        {
            auto stream = get_stream(i);
Paul's avatar
Paul committed
43
            if(stream != s)
44
45
46
47
48
49
                result.insert(stream);
        }
        return {result.begin(), result.end()};
    }
};

Paul's avatar
Paul committed
50
51
void schedule::apply(program& p) const
{
52
53
    const std::size_t min_partition_threshold = 2;

Paul's avatar
Paul committed
54
55
56
57
    // Compute accumulated weights
    std::unordered_map<instruction_ref, std::size_t> weights;
    auto last = std::prev(p.end());
    fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
Paul's avatar
Paul committed
58
        if(weights.count(ins) == 0)
Paul's avatar
Paul committed
59
        {
Paul's avatar
Paul committed
60
61
62
63
64
            weights[ins] =
                std::accumulate(ins->inputs().begin(),
                                ins->inputs().end(),
                                model.weight(ins->get_operator()),
                                [&](std::size_t w, instruction_ref i) { return w + self(i); });
Paul's avatar
Paul committed
65
66
67
68
        }
        return weights[ins];
    })(last);

69
70
71
    // Assign streams
    auto streams = model.concurrency();
    stream_info si;
Paul's avatar
Paul committed
72
    for(std::size_t stream = 0; stream < streams; stream++)
73
74
75
    {
        fix([&](auto self, auto ins) {
            // Only assign streams fi not already assigned
Paul's avatar
Paul committed
76
            if(not si.has_stream(ins))
77
78
                si.set_stream(ins, stream);
            instruction_ref child = p.end();
Paul's avatar
Paul committed
79
80
            std::size_t w         = 0;
            for(auto i : ins->inputs())
81
82
83
            {
                const auto weight = weights[i];
                // Skip instruction that already have stream assignment or too low of weights
Paul's avatar
Paul committed
84
                if(si.has_stream(i) or weight <= min_partition_threshold)
85
86
87
88
                {
                    self(i);
                }
                // Accumulate the max weight
Paul's avatar
Paul committed
89
                else if(weight > w)
90
91
                {
                    child = i;
Paul's avatar
Paul committed
92
                    w     = weight;
93
94
                }
            }
Paul's avatar
Paul committed
95
            if(child != p.end())
96
97
98
99
                self(child);
        })(last);
    }
    // Assign remaining instructions
Paul's avatar
Paul committed
100
    for(auto ins : iterator_for(p))
101
    {
Paul's avatar
Paul committed
102
        if(si.has_stream(ins))
103
            continue;
Paul's avatar
Paul committed
104
        si.set_stream(ins, streams - 1);
105
106
    }

Paul's avatar
Paul committed
107
108
    // Topo sort
    fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
109
        for(auto i : ins->inputs())
Paul's avatar
Paul committed
110
            p.move_instruction(i, p.begin());
Paul's avatar
Paul committed
111
        for(auto i : ins->inputs())
Paul's avatar
Paul committed
112
113
            self(i);
    })(last);
Paul's avatar
Paul committed
114

115
    // Schedule instructions
Paul's avatar
Paul committed
116
    for(auto ins : iterator_for(p))
117
    {
Paul's avatar
Paul committed
118
        if(si.is_merge_point(ins))
119
120
121
122
123
124
        {
            assert(not si.wait_for(ins).empty());
            model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
            continue;
        }
        // Skip scheduling instructions with no context
Paul's avatar
Paul committed
125
        if(is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@')
126
127
128
            continue;
        model.schedule_instruction(p, ins, si.get_stream(ins));
    }
Paul's avatar
Paul committed
129
130
131
132
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx