Commit 59f6009b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

initial testing workaround

parent e46a6a52
...@@ -627,6 +627,54 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -627,6 +627,54 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("multibroadcast")(match::all_of[match::outputs()](match::name("transpose").bind("trans_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto input = ins->inputs().front();
auto input_lens = input->get_shape().lens();
auto trans_ins = r.instructions["trans_ins"];
auto trans_shape = trans_ins->get_shape();
auto permutation = trans_ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
// {1, 3072, 1024}
// unsqueeze axes {0, 1}
// need unsqueeze to be {0, 2}
// transpose permutation {0, 2, 1}
std::vector<int64_t> unsqueeze_axes{};
auto unsqueeze_count = -1;
for(auto dim : ins_lens)
{
unsqueeze_count++;
if(contains(input_lens, dim))
{
continue;
}
unsqueeze_axes.push_back(unsqueeze_count);
}
for(auto i = 0; i < unsqueeze_axes.size(); i++)
{
if(permutation[i] > unsqueeze_axes[i])
{
unsqueeze_axes[i] += permutation[i] - i;
}
}
auto unsqueeze_ins = m.insert_instruction(trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
auto mbcast_ins = m.insert_instruction(trans_ins, make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}), unsqueeze_ins);
m.replace_instruction(trans_ins, mbcast_ins);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -793,6 +841,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -793,6 +841,7 @@ void simplify_reshapes::apply(module& m) const
find_reshaper{}, find_reshaper{},
find_reshape_cont{}, find_reshape_cont{},
find_transpose{}, find_transpose{},
find_broadcast_transpose{},
find_concat_transpose{}, find_concat_transpose{},
find_concat_multibroadcasts{}, find_concat_multibroadcasts{},
find_nested_convert{}, find_nested_convert{},
......
...@@ -201,6 +201,31 @@ TEST_CASE(reshape_transpose) ...@@ -201,6 +201,31 @@ TEST_CASE(reshape_transpose)
EXPECT(std::distance(m.begin(), m.end()) == n); EXPECT(std::distance(m.begin(), m.end()) == n);
} }
TEST_CASE(broadcast_transpose)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb);
mm->add_return({t1});
run_pass(*mm);
migraphx::program p2;
mm = p2.get_main_module();
l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2}}}), l);
mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}), unsqueeze);
mm->add_return({mb});
EXPECT(p == p2);
// EXPECT(not mm->get_output_shapes().back().standard());
// EXPECT(mm->get_output_shapes().back().transposed());
// EXPECT(std::distance(mm->begin(), mm->end()) == 3);
// auto result = p.eval({}).back();
// EXPECT(result != get_2x2());
}
TEST_CASE(transpose_contiguous) TEST_CASE(transpose_contiguous)
{ {
migraphx::module m; migraphx::module m;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment