Commit 19072784 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add change and test

parent 1af66a1c
......@@ -182,6 +182,7 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
else
{
auto common = common_shape(to_shapes(inputs));
std::cout << common << std::endl;
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input->get_shape().lens() != common.lens())
{
......
......@@ -521,6 +521,19 @@ struct find_inner_broadcast
}) < (lens.size() - 1);
}))
return;
auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0);
for(auto i = 0; i < broadcasts.front()->get_shape().strides().size(); i++)
{
for(auto j = 0; j < broadcasts.size(); j++)
{
if(broadcasts[j]->get_shape().strides()[i] == 0)
common_axis[i]++;
}
}
if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common){ return num_common > 1 ; }) == common_axis.end())
return;
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
......
......@@ -639,6 +639,32 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast_different_dims2)
{
auto b = migraphx::op::multibroadcast{{1, 1024, 3072}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1024, 3072}});
auto y = m1.add_parameter("y", {migraphx::shape::int32_type, {1, 1024, 1}});
auto xb = m1.add_instruction(b, x);
auto yb = m1.add_instruction(b, y);
auto sum = m1.add_instruction(migraphx::make_op("add"), xb, yb);
m1.add_instruction(pass_op{}, sum);
}
run_pass(m1);
m1.debug_print();
// migraphx::module m2;
// {
// auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1024, 768}});
// auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
// auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
// auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
// auto sumb = m2.add_instruction(b, sum);
// m2.add_instruction(pass_op{}, sumb);
// }
// EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment