Commit da8221c9 authored by Paul's avatar Paul
Browse files

Formatting

parent 65cf0713
...@@ -49,7 +49,8 @@ struct nary_op ...@@ -49,7 +49,8 @@ struct nary_op
struct wait_event struct wait_event
{ {
std::shared_ptr<std::vector<std::size_t>> wait_for = std::make_shared<std::vector<std::size_t>>(); std::shared_ptr<std::vector<std::size_t>> wait_for =
std::make_shared<std::vector<std::size_t>>();
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -69,22 +70,23 @@ struct wait_event ...@@ -69,22 +70,23 @@ struct wait_event
}; };
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>; using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
using wait_map = std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>; using wait_map =
std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
struct schedule_model_test struct schedule_model_test
{ {
std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>(); std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream = std::make_shared<std::unordered_map<std::size_t, std::size_t>>(); std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream =
std::make_shared<std::unordered_map<std::size_t, std::size_t>>();
std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>(); std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
std::size_t concurrency() const { return 4; } std::size_t concurrency() const { return 4; }
void void sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
{ {
(*ins2stream)[ins] = n; (*ins2stream)[ins] = n;
} }
void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
{ {
if (ins2wait_for->count(ins) == 0) if(ins2wait_for->count(ins) == 0)
{ {
auto event = wait_event{}; auto event = wait_event{};
p.insert_instruction(ins, event); p.insert_instruction(ins, event);
...@@ -115,15 +117,9 @@ struct schedule_target ...@@ -115,15 +117,9 @@ struct schedule_target
} }
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins) std::size_t get_stream(migraphx::instruction_ref ins) { return model.ins2stream->at(ins); }
{
return model.ins2stream->at(ins);
}
bool has_stream(migraphx::instruction_ref ins) bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
{
return model.ins2stream->count(ins) > 0;
}
}; };
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y) bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
...@@ -216,7 +212,8 @@ TEST_CASE(single_entry) ...@@ -216,7 +212,8 @@ TEST_CASE(single_entry)
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)})); EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2)); EXPECT(check_conflicts(p, onep1, onep2));
} }
...@@ -272,7 +269,8 @@ TEST_CASE(double_entry) ...@@ -272,7 +269,8 @@ TEST_CASE(double_entry)
EXPECT(not t.has_stream(two)); EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(onep) != t.get_stream(twop)); EXPECT(t.get_stream(onep) != t.get_stream(twop));
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)})); EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
// EXPECT(check_conflicts(p, onep, twop)); // EXPECT(check_conflicts(p, onep, twop));
} }
...@@ -290,7 +288,8 @@ TEST_CASE(two_branches) ...@@ -290,7 +288,8 @@ TEST_CASE(two_branches)
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)})); EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
} }
...@@ -314,9 +313,11 @@ TEST_CASE(four_branches) ...@@ -314,9 +313,11 @@ TEST_CASE(four_branches)
for(auto ins : c3) for(auto ins : c3)
EXPECT(t.get_stream(ins) == 2); EXPECT(t.get_stream(ins) == 2);
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()),
{t.get_stream(c1.back()), t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)})); t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, {i1}}); check_conflicts(p, {c1, c2, c3, {i1}});
} }
...@@ -343,9 +344,11 @@ TEST_CASE(five_branches) ...@@ -343,9 +344,11 @@ TEST_CASE(five_branches)
for(auto ins : c4) for(auto ins : c4)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary),
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()),
{t.get_stream(c1.back()), t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)})); t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, c4}); check_conflicts(p, {c1, c2, c3, c4});
check_conflicts(p, {c1, c2, c3, {i1}}); check_conflicts(p, {c1, c2, c3, {i1}});
} }
...@@ -362,14 +365,17 @@ TEST_CASE(four_branches_eq) ...@@ -362,14 +365,17 @@ TEST_CASE(four_branches_eq)
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4); auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(t); p.compile(t);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(sorted<std::size_t>( EXPECT(
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) == sorted<std::size_t>(
unique<std::size_t>( {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) ==
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)})); unique<std::size_t>(
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT( EXPECT(
get_wait_for(binary) == get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)})); get_wait_for(
t.get_stream(binary),
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}}); check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
} }
...@@ -393,14 +399,16 @@ TEST_CASE(seq_merge) ...@@ -393,14 +399,16 @@ TEST_CASE(seq_merge)
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary1) == 3); EXPECT(t.get_stream(binary1) == 3);
EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) == 3); EXPECT(t.get_stream(i2) == 3);
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary2) == 0); EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); check_conflicts(p, {c2, {i2}});
} }
...@@ -429,14 +437,16 @@ TEST_CASE(par_merge) ...@@ -429,14 +437,16 @@ TEST_CASE(par_merge)
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0); EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) == 1); EXPECT(t.get_stream(i2) == 1);
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3); EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2)); EXPECT(check_conflicts(p, binary1, binary2));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment