Unverified Commit 294a0e66 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Ensure the slices are in order before removing them in concat (#528)



* Ensure the slices are in order before removing them in concat

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 557618ba
......@@ -445,6 +445,13 @@ struct find_split_concat
std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); });
if(std::distance(it, args.end()) < splits.size())
return;
// If the slices are not in order then stop
if(not std::is_sorted(it, it + splits.size(), [](instruction_ref x, instruction_ref y) {
auto xop = any_cast<op::slice>(x->get_operator());
auto yop = any_cast<op::slice>(y->get_operator());
return std::tie(xop.starts, xop.ends) < std::tie(yop.starts, yop.ends);
}))
return;
*it = splits.front()->inputs().front();
args.erase(std::next(it), it + splits.size());
......
......@@ -586,6 +586,89 @@ TEST_CASE(simplify_rsqrt_multi_use)
EXPECT(p1 == p2);
}
TEST_CASE(simplify_slice_concat)
{
auto s = migraphx::shape{migraphx::shape::float_type, {256}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto xslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {128}}, x);
auto xslice2 = p1.add_instruction(migraphx::op::slice{{0}, {128}, {256}}, x);
auto yslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {128}}, y);
auto yslice2 = p1.add_instruction(migraphx::op::slice{{0}, {128}, {256}}, y);
auto concat =
p1.add_instruction(migraphx::op::concat{0}, xslice1, xslice2, yslice1, yslice2);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto concat = p2.add_instruction(migraphx::op::concat{0}, x, y);
p2.add_instruction(pass_op{}, concat);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_slice_concat_non_uniform)
{
auto s = migraphx::shape{migraphx::shape::float_type, {256}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto xslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto xslice3 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto yslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto yslice3 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto concat = p1.add_instruction(
migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto concat = p2.add_instruction(migraphx::op::concat{0}, x, y);
p2.add_instruction(pass_op{}, concat);
}
EXPECT(p1 == p2);
}
TEST_CASE(simplify_slice_concat_flipped)
{
auto s = migraphx::shape{migraphx::shape::float_type, {256}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto xslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto xslice3 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto yslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto yslice3 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto concat = p1.add_instruction(
migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
p1.add_instruction(pass_op{}, concat);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1 == p2);
}
TEST_CASE(simplify_split_add_relu)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment