Commit fbaec470 authored by Paul's avatar Paul
Browse files

Reduce the number identity ops added

parent b211af48
...@@ -379,23 +379,46 @@ void schedule::apply(program& p) const ...@@ -379,23 +379,46 @@ void schedule::apply(program& p) const
// Add memory conflicts // Add memory conflicts
auto concur_ins = si.find_concurrent_instructions(p); auto concur_ins = si.find_concurrent_instructions(p);
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table;
for(auto&& merge : concur_ins) for(auto&& merge : concur_ins)
{ {
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) { dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j) if(i == j)
return; return;
if(merge.second[i].empty())
return;
if(merge.second[j].empty())
return;
for(auto ins1 : merge.second[i]) for(auto ins1 : merge.second[i])
{ {
auto args = merge.second[j]; auto p1 = std::distance(ins1, merge.first);
args.insert(args.begin(), ins1); for(auto ins2 : merge.second[j])
p.insert_instruction(merge.first, op::identity{}, args); {
if (ins1 == ins2)
continue;
auto p2 = std::distance(ins2, merge.first);
// The smaller distance means the instruction occurs later
if (p1 > p2)
conflict_table[ins2].insert(ins1);
else
conflict_table[ins1].insert(ins2);
}
} }
}); });
} }
// Remove duplicates
for(auto&& ip:conflict_table)
{
auto ins1 = ip.first;
for(auto ins2:ip.second)
if (contains(conflict_table[ins2], ins1))
conflict_table[ins2].erase(ins1);
}
for(auto&& ip:conflict_table)
{
if (ip.second.empty())
continue;
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), op::identity{}, args);
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -140,6 +140,21 @@ struct schedule_model_test ...@@ -140,6 +140,21 @@ struct schedule_model_test
} }
}; };
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
for(auto ins : migraphx::iterator_for(p))
{
if(ins->name() != "identity")
continue;
if(not migraphx::contains(ins->inputs(), x))
continue;
if(not migraphx::contains(ins->inputs(), y))
continue;
return true;
}
return false;
}
struct schedule_target struct schedule_target
{ {
schedule_model_test model{}; schedule_model_test model{};
...@@ -162,35 +177,29 @@ struct schedule_target ...@@ -162,35 +177,29 @@ struct schedule_target
} }
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; } bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
};
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y) void check_conflicts(migraphx::program& p,
{
for(auto ins : migraphx::iterator_for(p))
{
if(ins->name() != "identity")
continue;
if(not migraphx::contains(ins->inputs(), x))
continue;
if(not migraphx::contains(ins->inputs(), y))
continue;
return true;
}
return false;
}
void check_conflicts(migraphx::program& p,
std::vector<std::vector<migraphx::instruction_ref>> conflicts, std::vector<std::vector<migraphx::instruction_ref>> conflicts,
bool result = true) bool result = true)
{ {
migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) { migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
if(i == j) if(i == j)
return; return;
for(auto ins1 : conflicts[i]) for(auto ins1 : conflicts[i])
{
for(auto ins2 : conflicts[j]) for(auto ins2 : conflicts[j])
CHECK(check_conflicts(p, ins1, ins2) == result); {
// If both instructions are on the same stream then dont check for a conflict
if (has_stream(ins1) and has_stream(ins2) and get_stream(ins1) == get_stream(ins2))
continue;
CHECK(::check_conflicts(p, ins1, ins2) == result);
}
}
}); });
} }
};
template <class T> template <class T>
std::vector<T> sorted(std::vector<T> x) std::vector<T> sorted(std::vector<T> x)
...@@ -292,7 +301,7 @@ TEST_CASE(zero_record) ...@@ -292,7 +301,7 @@ TEST_CASE(zero_record)
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)})); get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2)); EXPECT(check_conflicts(p, onep1, onep2));
check_conflicts(p, {{onep1, onei1}, {onep2, onei2}}); t.check_conflicts(p, {{onep1, onei1}, {onep2, onei2}});
} }
TEST_CASE(zero_merge1) TEST_CASE(zero_merge1)
...@@ -397,7 +406,7 @@ TEST_CASE(double_entry) ...@@ -397,7 +406,7 @@ TEST_CASE(double_entry)
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)})); get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
check_conflicts(p, {{onep, one}, {twop, two}}); t.check_conflicts(p, {{onep, one}, {twop, two}});
} }
TEST_CASE(two_branches) TEST_CASE(two_branches)
...@@ -416,7 +425,7 @@ TEST_CASE(two_branches) ...@@ -416,7 +425,7 @@ TEST_CASE(two_branches)
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); t.check_conflicts(p, {c1, {i1}});
} }
TEST_CASE(four_branches) TEST_CASE(four_branches)
...@@ -444,7 +453,7 @@ TEST_CASE(four_branches) ...@@ -444,7 +453,7 @@ TEST_CASE(four_branches)
t.get_stream(c2.back()), t.get_stream(c2.back()),
t.get_stream(c3.back()), t.get_stream(c3.back()),
t.get_stream(i1)})); t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, {i1}}); t.check_conflicts(p, {c1, c2, c3, {i1}});
} }
TEST_CASE(five_branches) TEST_CASE(five_branches)
...@@ -475,8 +484,8 @@ TEST_CASE(five_branches) ...@@ -475,8 +484,8 @@ TEST_CASE(five_branches)
t.get_stream(c2.back()), t.get_stream(c2.back()),
t.get_stream(c3.back()), t.get_stream(c3.back()),
t.get_stream(i1)})); t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, c4}); t.check_conflicts(p, {c1, c2, c3, c4});
check_conflicts(p, {c1, c2, c3, {i1}}); t.check_conflicts(p, {c1, c2, c3, {i1}});
} }
TEST_CASE(four_branches_eq) TEST_CASE(four_branches_eq)
...@@ -502,7 +511,7 @@ TEST_CASE(four_branches_eq) ...@@ -502,7 +511,7 @@ TEST_CASE(four_branches_eq)
get_wait_for( get_wait_for(
t.get_stream(binary), t.get_stream(binary),
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)})); {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}}); t.check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
} }
TEST_CASE(seq_merge) TEST_CASE(seq_merge)
...@@ -527,7 +536,7 @@ TEST_CASE(seq_merge) ...@@ -527,7 +536,7 @@ TEST_CASE(seq_merge)
EXPECT(t.get_stream(binary1) == t.get_stream(c1.back())); EXPECT(t.get_stream(binary1) == t.get_stream(c1.back()));
EXPECT(get_wait_for(binary1) == EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); t.check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) != t.get_stream(c2.back())); EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2) for(auto ins : c2)
...@@ -535,7 +544,7 @@ TEST_CASE(seq_merge) ...@@ -535,7 +544,7 @@ TEST_CASE(seq_merge)
EXPECT(t.get_stream(binary2) == 0); EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) == EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); t.check_conflicts(p, {c2, {i2}});
} }
TEST_CASE(par_merge) TEST_CASE(par_merge)
...@@ -565,17 +574,17 @@ TEST_CASE(par_merge) ...@@ -565,17 +574,17 @@ TEST_CASE(par_merge)
EXPECT(t.get_stream(binary1) == 0); EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) == EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3); EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) == EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2)); EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}}); t.check_conflicts(p, {c1, {i1}, c2, {i2}});
} }
TEST_CASE(inner_par_merge) TEST_CASE(inner_par_merge)
...@@ -616,17 +625,17 @@ TEST_CASE(inner_par_merge) ...@@ -616,17 +625,17 @@ TEST_CASE(inner_par_merge)
EXPECT(t.get_stream(binary1) == 0); EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) == EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3); EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) == EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2)); EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}}); t.check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}});
} }
TEST_CASE(par_merge_multi_entry) TEST_CASE(par_merge_multi_entry)
...@@ -658,17 +667,17 @@ TEST_CASE(par_merge_multi_entry) ...@@ -658,17 +667,17 @@ TEST_CASE(par_merge_multi_entry)
EXPECT(t.get_stream(binary1) == 0); EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) == EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3); EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) == EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2)); EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}}); t.check_conflicts(p, {c1, {i1}, c2, {i2}});
} }
TEST_CASE(inner_split1) TEST_CASE(inner_split1)
...@@ -696,7 +705,7 @@ TEST_CASE(inner_split1) ...@@ -696,7 +705,7 @@ TEST_CASE(inner_split1)
EXPECT(get_wait_for(s1).empty()); EXPECT(get_wait_for(s1).empty());
// TODO: Remove the extra wait here // TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty()); // EXPECT(get_wait_for(s2).empty());
check_conflicts(p, {c1, {i1}, {s1}, {s2}}); t.check_conflicts(p, {c1, {i1}, {s1}, {s2}});
} }
TEST_CASE(inner_split2) TEST_CASE(inner_split2)
...@@ -722,7 +731,7 @@ TEST_CASE(inner_split2) ...@@ -722,7 +731,7 @@ TEST_CASE(inner_split2)
get_wait_for(t.get_stream(output), get_wait_for(t.get_stream(output),
{t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())})); {t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())}));
EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())})); EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())}));
check_conflicts(p, {c1, {i1}, s1, s2}); t.check_conflicts(p, {c1, {i1}, s1, s2});
} }
TEST_CASE(inception_resnet) TEST_CASE(inception_resnet)
...@@ -745,7 +754,7 @@ TEST_CASE(inception_resnet) ...@@ -745,7 +754,7 @@ TEST_CASE(inception_resnet)
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
EXPECT(t.get_stream(output) == 0); EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output).empty()); EXPECT(get_wait_for(output).empty());
check_conflicts(p, {c1, {i1}}); t.check_conflicts(p, {c1, {i1}});
} }
TEST_CASE(inception1) TEST_CASE(inception1)
...@@ -866,7 +875,7 @@ TEST_CASE(inception1) ...@@ -866,7 +875,7 @@ TEST_CASE(inception1)
get_wait_for(t.get_stream(output), get_wait_for(t.get_stream(output),
{t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)})); {t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)}));
check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}}); t.check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment