Commit 13452c07 authored by Paul's avatar Paul
Browse files

Format

parent 12281355
...@@ -404,7 +404,7 @@ struct find_dot_broadcast ...@@ -404,7 +404,7 @@ struct find_dot_broadcast
return; return;
if(ins->get_shape().lens().size() < 3) if(ins->get_shape().lens().size() < 3)
return; return;
auto nbatch_axes = ins->get_shape().lens().size() - 2; auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides(); const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides(); const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted // Find leading batch axes that are broadcasted
...@@ -420,18 +420,20 @@ struct find_dot_broadcast ...@@ -420,18 +420,20 @@ struct find_dot_broadcast
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
auto insert_squeeze = [&](instruction_ref b_ins) -> instruction_ref { auto insert_squeeze = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0]; auto input = b_ins->inputs()[0];
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes, b_ins->get_shape().lens().end()); std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
if (b_ins->name() == "multibroadcast") b_ins->get_shape().lens().end());
if(b_ins->name() == "multibroadcast")
{ {
return m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), input); return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
} }
else if (b_ins->name() == "broadcast") else if(b_ins->name() == "broadcast")
{ {
auto v = b_ins->get_operator().to_value(); auto v = b_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes; auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input); return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
} }
assert(false); assert(false);
return m.end(); return m.end();
......
...@@ -3007,10 +3007,12 @@ TEST_CASE(dot_broadcast_different_rank) ...@@ -3007,10 +3007,12 @@ TEST_CASE(dot_broadcast_different_rank)
{ {
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {768}}); auto x = m1.add_parameter("x", {migraphx::shape::float_type, {768}});
auto y = m1.add_parameter("y", {migraphx::shape::float_type, {768, 3072}}); auto y = m1.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}), x); auto xb = m1.add_instruction(
auto yb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), y); migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}), x);
auto yb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), y);
auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb); auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb);
m1.add_return({dot}); m1.add_return({dot});
}; };
...@@ -3019,10 +3021,13 @@ TEST_CASE(dot_broadcast_different_rank) ...@@ -3019,10 +3021,13 @@ TEST_CASE(dot_broadcast_different_rank)
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {768}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {768}});
auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}}); auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x); auto xb =
auto yb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb); auto yb =
auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb);
auto broadcast = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot);
m2.add_return({broadcast}); m2.add_return({broadcast});
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment