"src/vscode:/vscode.git/clone" did not exist on "b3086ac2606d4b6999788f7faf06afa30406e44e"
Commit fbaec470 authored by Paul's avatar Paul
Browse files

Reduce the number of identity ops added

parent b211af48
......@@ -379,23 +379,46 @@ void schedule::apply(program& p) const
// Add memory conflicts
auto concur_ins = si.find_concurrent_instructions(p);
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table;
for(auto&& merge : concur_ins)
{
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j)
return;
if(merge.second[i].empty())
return;
if(merge.second[j].empty())
return;
for(auto ins1 : merge.second[i])
{
auto args = merge.second[j];
args.insert(args.begin(), ins1);
p.insert_instruction(merge.first, op::identity{}, args);
auto p1 = std::distance(ins1, merge.first);
for(auto ins2 : merge.second[j])
{
if (ins1 == ins2)
continue;
auto p2 = std::distance(ins2, merge.first);
// The smaller distance means the instruction occurs later
if (p1 > p2)
conflict_table[ins2].insert(ins1);
else
conflict_table[ins1].insert(ins2);
}
}
});
}
// Remove duplicates
for(auto&& ip:conflict_table)
{
auto ins1 = ip.first;
for(auto ins2:ip.second)
if (contains(conflict_table[ins2], ins1))
conflict_table[ins2].erase(ins1);
}
for(auto&& ip:conflict_table)
{
if (ip.second.empty())
continue;
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), op::identity{}, args);
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -140,6 +140,21 @@ struct schedule_model_test
}
};
// Reports whether the program contains an instruction named "identity" whose
// inputs include both x and y. Returns false when no such instruction exists.
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
    for(auto candidate : migraphx::iterator_for(p))
    {
        // Short-circuit order mirrors the checks: name first, then x, then y.
        const bool links_both = candidate->name() == "identity" and
                                migraphx::contains(candidate->inputs(), x) and
                                migraphx::contains(candidate->inputs(), y);
        if(links_both)
            return true;
    }
    return false;
}
struct schedule_target
{
schedule_model_test model{};
......@@ -162,35 +177,29 @@ struct schedule_target
}
bool has_stream(migraphx::instruction_ref ins) { return model.ins2stream->count(ins) > 0; }
};
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
{
for(auto ins : migraphx::iterator_for(p))
void check_conflicts(migraphx::program& p,
std::vector<std::vector<migraphx::instruction_ref>> conflicts,
bool result = true)
{
if(ins->name() != "identity")
continue;
if(not migraphx::contains(ins->inputs(), x))
continue;
if(not migraphx::contains(ins->inputs(), y))
continue;
return true;
migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
if(i == j)
return;
for(auto ins1 : conflicts[i])
{
for(auto ins2 : conflicts[j])
{
// If both instructions are on the same stream then don't check for a conflict
if (has_stream(ins1) and has_stream(ins2) and get_stream(ins1) == get_stream(ins2))
continue;
CHECK(::check_conflicts(p, ins1, ins2) == result);
}
}
});
}
return false;
}
};
// Group-wise conflict check. dfor drives the lambda with index pairs over
// conflicts.size() x conflicts.size() (presumably every (i, j) combination --
// confirm against dfor). For each pair of different indices, every instruction
// of one group is paired with every instruction of the other, and the pairwise
// check_conflicts(p, a, b) result is CHECKed against `result`.
void check_conflicts(migraphx::program& p,
                     std::vector<std::vector<migraphx::instruction_ref>> conflicts,
                     bool result = true)
{
    migraphx::dfor(conflicts.size(), conflicts.size())([&](auto lhs, auto rhs) {
        // A group is never compared against itself.
        if(lhs != rhs)
        {
            for(auto first : conflicts[lhs])
            {
                for(auto second : conflicts[rhs])
                {
                    CHECK(check_conflicts(p, first, second) == result);
                }
            }
        }
    });
}
template <class T>
std::vector<T> sorted(std::vector<T> x)
......@@ -292,7 +301,7 @@ TEST_CASE(zero_record)
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
check_conflicts(p, {{onep1, onei1}, {onep2, onei2}});
t.check_conflicts(p, {{onep1, onei1}, {onep2, onei2}});
}
TEST_CASE(zero_merge1)
......@@ -397,7 +406,7 @@ TEST_CASE(double_entry)
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
check_conflicts(p, {{onep, one}, {twop, two}});
t.check_conflicts(p, {{onep, one}, {twop, two}});
}
TEST_CASE(two_branches)
......@@ -416,7 +425,7 @@ TEST_CASE(two_branches)
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
t.check_conflicts(p, {c1, {i1}});
}
TEST_CASE(four_branches)
......@@ -444,7 +453,7 @@ TEST_CASE(four_branches)
t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, {i1}});
t.check_conflicts(p, {c1, c2, c3, {i1}});
}
TEST_CASE(five_branches)
......@@ -475,8 +484,8 @@ TEST_CASE(five_branches)
t.get_stream(c2.back()),
t.get_stream(c3.back()),
t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, c4});
check_conflicts(p, {c1, c2, c3, {i1}});
t.check_conflicts(p, {c1, c2, c3, c4});
t.check_conflicts(p, {c1, c2, c3, {i1}});
}
TEST_CASE(four_branches_eq)
......@@ -502,7 +511,7 @@ TEST_CASE(four_branches_eq)
get_wait_for(
t.get_stream(binary),
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
t.check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
}
TEST_CASE(seq_merge)
......@@ -527,7 +536,7 @@ TEST_CASE(seq_merge)
EXPECT(t.get_stream(binary1) == t.get_stream(c1.back()));
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
t.check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2)
......@@ -535,7 +544,7 @@ TEST_CASE(seq_merge)
EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}});
t.check_conflicts(p, {c2, {i2}});
}
TEST_CASE(par_merge)
......@@ -565,17 +574,17 @@ TEST_CASE(par_merge)
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}});
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}});
t.check_conflicts(p, {c1, {i1}, c2, {i2}});
}
TEST_CASE(inner_par_merge)
......@@ -616,17 +625,17 @@ TEST_CASE(inner_par_merge)
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}});
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}});
t.check_conflicts(p, {c1, {i1}, c2, {i2}, {outer1}, {outer2}});
}
TEST_CASE(par_merge_multi_entry)
......@@ -658,17 +667,17 @@ TEST_CASE(par_merge_multi_entry)
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
t.check_conflicts(p, {c1, {i1}});
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}});
t.check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
check_conflicts(p, {c1, {i1}, c2, {i2}});
t.check_conflicts(p, {c1, {i1}, c2, {i2}});
}
TEST_CASE(inner_split1)
......@@ -696,7 +705,7 @@ TEST_CASE(inner_split1)
EXPECT(get_wait_for(s1).empty());
// TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty());
check_conflicts(p, {c1, {i1}, {s1}, {s2}});
t.check_conflicts(p, {c1, {i1}, {s1}, {s2}});
}
TEST_CASE(inner_split2)
......@@ -722,7 +731,7 @@ TEST_CASE(inner_split2)
get_wait_for(t.get_stream(output),
{t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())}));
EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())}));
check_conflicts(p, {c1, {i1}, s1, s2});
t.check_conflicts(p, {c1, {i1}, s1, s2});
}
TEST_CASE(inception_resnet)
......@@ -745,7 +754,7 @@ TEST_CASE(inception_resnet)
get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output).empty());
check_conflicts(p, {c1, {i1}});
t.check_conflicts(p, {c1, {i1}});
}
TEST_CASE(inception1)
......@@ -866,7 +875,7 @@ TEST_CASE(inception1)
get_wait_for(t.get_stream(output),
{t.get_stream(i94), t.get_stream(i75), t.get_stream(i61), t.get_stream(i86)}));
check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}});
t.check_conflicts(p, {{i80, i86}, {i69, i75}, {i48, i54, i61}, {i94}});
}
// Test executable entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment