Commit da5c6162 authored by Paul's avatar Paul
Browse files

Fix mathcer and add unit test

parent fb0ade6a
......@@ -871,7 +871,7 @@ struct find_broadcast_reshaper
{
auto broadcast =
match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x")));
return match::name(reshaper_names())(match::args(broadcast.bind("broadcast")));
return match::name(reshaper_names())(match::args(match::skip(match::name("contiguous"))(broadcast.bind("broadcast"))));
}
void apply(module& m, const match::matcher_result& r) const
......
......@@ -1463,4 +1463,27 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(broadcast_transpose_reshape)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto broadcast = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), x);
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), broadcast);
auto contiguous = m1.add_instruction(migraphx::make_op("contiguous"), transpose);
auto reshape = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 320}}}), contiguous);
m1.add_return({reshape});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze"), x);
auto broadcast = m2.add_instruction(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 4096, 320}}}), squeeze);
m2.add_return({broadcast});
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment