schedule.cpp 4.86 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <iterator>
#include <numeric>
#include <unordered_map>
#include <vector>
8
#include <set>
Paul's avatar
Paul committed
9
10
11
12

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Paul's avatar
Paul committed
13
14
15
16
17
// Reports whether `ins` is "stream free": either is_context_free() accepts its
// operator, or the operator's name starts with '@'.
bool stream_free(instruction_ref ins)
{
    const auto& op = ins->get_operator();
    if(is_context_free(op))
        return true;
    // front() on an empty string is undefined, so check for emptiness first.
    const auto name = op.name();
    return not name.empty() and name.front() == '@';
}

18
19
20
struct stream_info
{
    std::unordered_map<instruction_ref, std::size_t> ins2stream;
Paul's avatar
Paul committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    std::unordered_map<instruction_ref, std::size_t> weights;

    void accumulate_weights(instruction_ref last, const schedule_model& model)
    {
        fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
            if(weights.count(ins) == 0)
            {
                weights[ins] =
                    std::accumulate(ins->inputs().begin(),
                                    ins->inputs().end(),
                                    model.weight(ins->get_operator()),
                                    [&](std::size_t w, instruction_ref i) { return w + self(i); });
            }
            return weights[ins];
        })(last);
    }

    void assign_streams(program& p, std::size_t streams)
    {
        const std::size_t min_partition_threshold = 2;
        for(std::size_t stream = 0; stream < streams; stream++)
        {
            fix([&](auto self, auto ins) {
Paul's avatar
Paul committed
44
45
46
47
                // If weight is zero then stop
                if (weights[ins] == 0)
                    return;
                // Only assign streams if not already assigned
Paul's avatar
Paul committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
                if(not has_stream(ins))
                    set_stream(ins, stream);
                instruction_ref child = p.end();
                std::size_t w         = 0;
                for(auto i : ins->inputs())
                {
                    const auto weight = weights[i];
                    // Skip instruction that already have stream assignment or too low of weights
                    if(has_stream(i) or weight <= min_partition_threshold)
                    {
                        self(i);
                    }
                    // Accumulate the max weight
                    else if(weight > w)
                    {
                        child = i;
                        w     = weight;
                    }
                }
                if(child != p.end())
                    self(child);
            })(std::prev(p.end()));
        }
        // Assign remaining instructions
        for(auto ins : iterator_for(p))
        {
            if(has_stream(ins))
                continue;
Paul's avatar
Paul committed
76
77
            if (weights[ins] == 0)
                continue;
Paul's avatar
Paul committed
78
79
80
            set_stream(ins, streams - 1);
        }
    }
81

Paul's avatar
Paul committed
82
    void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
83

Paul's avatar
Paul committed
84
    std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
85

Paul's avatar
Paul committed
86
    bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
87

Paul's avatar
Paul committed
88
    bool different(const std::vector<std::size_t>& v) const
89
    {
Paul's avatar
Paul committed
90
        if(v.size() < 2)
91
            return false;
Paul's avatar
Paul committed
92
        return not std::all_of(
Paul's avatar
Paul committed
93
            v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
94
95
    }

Paul's avatar
Paul committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    std::vector<std::size_t> get_input_streams(instruction_ref ins) const
    {
        std::vector<std::size_t> result;
        for (auto i:ins->inputs())
        {
            if (weights.at(i) == 0)
            {
                auto vv = get_input_streams(i);
                result.insert(result.end(), vv.begin(), vv.end());
            }
            else
            {
                result.emplace_back(get_stream(i));
            }
        }
        return result;
    }
113

Paul's avatar
Paul committed
114
115
116
117
    bool is_merge_point(instruction_ref ins) const
    { 
        return different(get_input_streams(ins)); 
    }
118
119
120

    std::vector<std::size_t> wait_for(instruction_ref ins) const
    {
Paul's avatar
Paul committed
121
122
123
124
        std::vector<std::size_t> result = get_input_streams(ins);
        std::sort( result.begin(), result.end() );
        result.erase( std::unique( result.begin(), result.end() ), result.end() );
        return result;
125
126
127
    }
};

Paul's avatar
Paul committed
128
129
// Runs the scheduling pass on `p`: computes accumulated weights, assigns a
// stream to each weighted instruction, reorders the program, and then hands
// every instruction that has a stream to the schedule model.
void schedule::apply(program& p) const
{
    // An empty program has nothing to schedule, and std::prev(p.end()) below
    // would be undefined.
    if(p.begin() == p.end())
        return;

    stream_info si;
    auto last = std::prev(p.end());
    si.accumulate_weights(last, model);
    si.assign_streams(p, model.concurrency());

    // Topo sort: starting from the last instruction, move the inputs of each
    // visited instruction to the front of the program, then recurse into
    // them. No visited set is kept, so an input shared by several users is
    // visited (and moved) once per path that reaches it.
    fix([&](auto self, auto ins) {
        for(auto i : ins->inputs())
            p.move_instruction(i, p.begin());
        for(auto i : ins->inputs())
            self(i);
    })(last);

    // Schedule instructions
    for(auto ins : iterator_for(p))
    {
        // Only schedule instructions that have a stream
        if(not si.has_stream(ins))
            continue;
        // Inputs arriving from more than one stream need a wait on those
        // streams; everything else is scheduled directly on its own stream.
        if(si.is_merge_point(ins))
            model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
        else
            model.schedule_instruction(p, ins, si.get_stream(ins));
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx