Commit 7cd27a5f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add a unit test to verify the fix

parent 83ae8902
......@@ -101,4 +101,42 @@ TEST_CASE(after_param_broadcast)
EXPECT(not m.get_output_shapes().back().broadcasted());
}
TEST_CASE(two_transpose_gather)
{
auto create_module = []
{
migraphx::module m;
auto data = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto sd = m.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), td);
auto bd = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), bd, ind);
m.add_return({r});
return m;
};
auto m = create_module();
run_pass(m);
auto create_cont_module = []
{
migraphx::module m;
auto data = m.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto ind = m.add_parameter("ind", {migraphx::shape::float_type, {2, 3}});
auto td = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
auto ctd = m.add_instruction(migraphx::make_op("contiguous"), td);
auto sd = m.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
auto bd = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
auto cbd = m.add_instruction(migraphx::make_op("contiguous"), bd);
auto r = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
m.add_return({r});
return m;
};
EXPECT(m == create_cont_module());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment