// schedule_test.cpp: unit tests for the migraphx::schedule pass.
#include <migraphx/schedule.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

#include <algorithm>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

#include <basic_ops.hpp>
#include <test.hpp>

/// Single-input test operator named "unary": it forwards its first argument
/// (and first input shape) unchanged. output_alias() returns 0, presumably
/// declaring that the result aliases input 0 — confirm against migraphx docs.
struct unary_op
{
    std::string name() const { return "unary"; }

    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        // No inputs -> default-constructed argument; otherwise pass input 0 through.
        return args.empty() ? migraphx::argument{} : args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        // Same pass-through rule as compute(), applied to shapes.
        return inputs.empty() ? migraphx::shape{} : inputs.front();
    }

    int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
};

Paul's avatar
Paul committed
30
struct nary_op
Paul's avatar
Paul committed
31
{
Paul's avatar
Paul committed
32
    std::string name() const { return "nary"; }
Paul's avatar
Paul committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
    migraphx::argument
    compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
    {
        if(args.empty())
            return {};
        return args.front();
    }

    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        if(inputs.empty())
            return {};
        return inputs.front();
    }
};

Paul's avatar
Paul committed
49
50
51
52
53
54
55
56
57
58
59
struct wait_event
{
    std::vector<std::size_t> wait_for;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.wait_for, "wait_for"));
    }
    std::string name() const { return "wait_event"; }
    migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }

Paul's avatar
Paul committed
60
61
62
    migraphx::argument compute(migraphx::context&,
                               const migraphx::shape&,
                               const std::vector<migraphx::argument>&) const
Paul's avatar
Paul committed
63
64
65
66
67
    {
        return {};
    }
};

Paul's avatar
Paul committed
68
69
70
71
72
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;

struct schedule_model_test
{
    instruction_map* ins2stream;
Paul's avatar
Paul committed
73
74
    std::size_t concurrency() const { return 4; }
    void
Paul's avatar
Paul committed
75
    schedule_instruction(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
Paul's avatar
Paul committed
76
77
78
79
80
81
82
83
    {
        (*ins2stream)[ins] = n;
    }
    void wait(migraphx::program& p,
              migraphx::instruction_ref ins,
              std::size_t wait_on,
              const std::vector<std::size_t>& wait_for) const
    {
Paul's avatar
Paul committed
84
85
        (*ins2stream)[ins] = wait_on;
        p.insert_instruction(ins, wait_event{wait_for});
Paul's avatar
Paul committed
86
87
88
    }
    std::size_t weight(const migraphx::operation& op) const
    {
Paul's avatar
Paul committed
89
        if(op.name() == "binary" or op.name() == "unary")
Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
            return 4;
        else
            return 1;
    }
};

struct schedule_target
{
    instruction_map* ins2stream;
    std::string name() const { return "schedule"; }
    std::vector<migraphx::pass> get_passes(migraphx::context&) const
    {
        return {migraphx::schedule{schedule_model_test{ins2stream}}};
    }
    migraphx::context get_context() const { return {}; }
};

/// Returns true when `p` contains an "identity" instruction whose inputs are
/// exactly {x, y}, in either order. Such an instruction is presumably how the
/// schedule pass records that x and y may run concurrently — confirm.
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto ins : migraphx::iterator_for(p))
    {
        if(ins->name() != "identity")
            continue;
        // Compare elements directly rather than building two temporary vectors
        // for every instruction visited.
        const auto& inputs = ins->inputs();
        if(inputs.size() != 2)
            continue;
        if((inputs[0] == x and inputs[1] == y) or (inputs[0] == y and inputs[1] == x))
            return true;
    }
    return false;
}

Paul's avatar
Paul committed
123
124
void check_conflicts(migraphx::program& p,
                     std::vector<std::vector<migraphx::instruction_ref>> conflicts)
Paul's avatar
Paul committed
125
126
{
    migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
Paul's avatar
Paul committed
127
        if(i == j)
Paul's avatar
Paul committed
128
            return;
Paul's avatar
Paul committed
129
130
        for(auto ins1 : conflicts[i])
            for(auto ins2 : conflicts[j])
Paul's avatar
Paul committed
131
132
133
134
                CHECK(check_conflicts(p, ins1, ins2));
    });
}

/// Builds the expected wait list for an instruction on stream `wait_on` whose
/// inputs ran on streams `wait_for`. One occurrence of `wait_on` is removed
/// (presumably an instruction never waits on its own stream), and the rest is
/// sorted so it compares equal to the list read back from a wait_event.
///
/// @param wait_on   stream the instruction itself was scheduled on
/// @param wait_for  streams of the instruction's inputs (taken by value, modified)
/// @return sorted stream ids, excluding one occurrence of `wait_on` if present
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
    // Guard the erase: vector::erase(end()) is undefined behaviour, which the
    // original hit whenever wait_on was not among the input streams.
    auto it = std::find(wait_for.begin(), wait_for.end(), wait_on);
    if(it != wait_for.end())
        wait_for.erase(it);
    std::sort(wait_for.begin(), wait_for.end());
    return wait_for;
}

/// Returns the sorted stream list carried by the wait_event placed directly
/// before `ins`. Returns an empty list when the preceding instruction is not a
/// wait_event.
/// NOTE(review): assumes `ins` is not the first instruction of its program;
/// std::prev on the first iterator would be undefined behaviour.
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
    auto wait_ins = std::prev(ins);
    if(wait_ins->name() != "wait_event")
        return {};
    auto wf = migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
    // Sort so the result is independent of the order the scheduler listed streams.
    std::sort(wf.begin(), wf.end());
    return wf;
}

Paul's avatar
Paul committed
152
153
154
template <class T>
std::vector<migraphx::instruction_ref>
chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
Paul's avatar
Paul committed
155
156
{
    std::vector<migraphx::instruction_ref> result;
Paul's avatar
Paul committed
157
    for(std::size_t i = 0; i < n; i++)
Paul's avatar
Paul committed
158
159
160
161
162
163
164
    {
        result.push_back(p.add_instruction(x, input));
        input = result.back();
    }
    return result;
}

Paul's avatar
Paul committed
165
166
167
168
169
TEST_CASE(single_entry)
{
    instruction_map stream;
    migraphx::program p;
    auto one    = p.add_literal(1);
Paul's avatar
Paul committed
170
171
    auto onep1  = p.add_instruction(unary_op{}, one);
    auto onep2  = p.add_instruction(unary_op{}, one);
Paul's avatar
Paul committed
172
    auto binary = p.add_instruction(nary_op{}, onep1, onep2);
Paul's avatar
Paul committed
173
    p.compile(schedule_target{&stream});
Paul's avatar
Paul committed
174
    EXPECT(stream.count(one) == 0);
Paul's avatar
Paul committed
175
176
177
178
179
180
181
    EXPECT(stream.at(onep1) != stream.at(onep2));
    EXPECT(stream.at(binary) == 0);
    EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep1], stream[onep2]}));
    EXPECT(check_conflicts(p, onep1, onep2));
}

// Two literals each feed one unary op, and the two branches join in an nary op.
// Expectations:
// - neither literal is scheduled;
// - the branches land on different streams;
// - the join runs on stream 0 and waits for the other branch's stream.
TEST_CASE(double_entry)
{
    instruction_map stream;
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto two    = p.add_literal(2);
    auto onep   = p.add_instruction(unary_op{}, one);
    auto twop   = p.add_instruction(unary_op{}, two);
    auto binary = p.add_instruction(nary_op{}, onep, twop);
    p.compile(schedule_target{&stream});
    EXPECT(stream.count(one) == 0);
    EXPECT(stream.count(two) == 0);
    EXPECT(stream.at(onep) != stream.at(twop));
    EXPECT(stream.at(binary) == 0);
    EXPECT(get_wait_for(binary) ==
           get_wait_for(stream.at(binary), {stream.at(onep), stream.at(twop)}));
    // NOTE(review): this conflict check was disabled in the original; the
    // reason is not visible here.
    // EXPECT(check_conflicts(p, onep, twop));
}

// A 2-long unary chain and a single unary op, both fed by one literal, join in
// an nary op. Expectations:
// - the chain is on stream 0 and the single op on stream 1;
// - the join runs on stream 0 and waits for stream 1;
// - chain and single op conflict pairwise.
TEST_CASE(two_weights)
{
    instruction_map stream;
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto c1     = chain(p, 2, unary_op{}, one);
    auto i1     = p.add_instruction(unary_op{}, one);
    auto binary = p.add_instruction(nary_op{}, i1, c1.back());
    p.compile(schedule_target{&stream});
    EXPECT(stream.count(one) == 0);
    EXPECT(stream.at(i1) == 1);
    for(auto ins : c1)
    {
        EXPECT(stream.at(ins) == 0);
    }
    EXPECT(stream.at(binary) == 0);
    EXPECT(get_wait_for(binary) ==
           get_wait_for(stream.at(binary), {stream.at(c1.back()), stream.at(i1)}));
    check_conflicts(p, {c1, {i1}});
}

// Unary chains of length 4, 3, 2 and a single unary op, all fed by one literal,
// join in an nary op. Expectations:
// - branches are placed on streams 0, 1, 2, 3 in order of decreasing length;
// - the join runs on stream 0 and waits for the other three streams;
// - every pair of branches conflicts.
TEST_CASE(four_weights)
{
    instruction_map stream;
    migraphx::program p;
    auto one    = p.add_literal(1);
    auto c1     = chain(p, 4, unary_op{}, one);
    auto c2     = chain(p, 3, unary_op{}, one);
    auto c3     = chain(p, 2, unary_op{}, one);
    auto i1     = p.add_instruction(unary_op{}, one);
    auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
    p.compile(schedule_target{&stream});
    EXPECT(stream.count(one) == 0);
    EXPECT(stream.at(i1) == 3);
    for(auto ins : c1)
    {
        EXPECT(stream.at(ins) == 0);
    }
    for(auto ins : c2)
    {
        EXPECT(stream.at(ins) == 1);
    }
    for(auto ins : c3)
    {
        EXPECT(stream.at(ins) == 2);
    }
    EXPECT(stream.at(binary) == 0);
    EXPECT(get_wait_for(binary) ==
           get_wait_for(stream.at(binary),
                        {stream.at(c1.back()),
                         stream.at(c2.back()),
                         stream.at(c3.back()),
                         stream.at(i1)}));
    check_conflicts(p, {c1, c2, c3, {i1}});
}

int main(int argc, const char* argv[])
{
    // Delegate entirely to the test harness from test.hpp.
    test::run(argc, argv);
}