Commit 12281355 authored by Paul's avatar Paul
Browse files

Fix issues

parent 1ccf131d
......@@ -405,32 +405,42 @@ struct find_dot_broadcast
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted
auto p =
std::mismatch(a->get_shape().strides().begin(),
a->get_shape().strides().begin() + nbatch_axes,
b->get_shape().strides().begin(),
b->get_shape().strides().begin() + nbatch_axes,
std::mismatch(a_strides.begin(),
a_strides.begin() + nbatch_axes,
b_strides.begin(),
b_strides.begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a->get_shape().lens().begin();
auto naxes = p.first - a_strides.begin();
assert(naxes <= nbatch_axes);
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
auto insert_sqeeze = [&](instruction_ref b_ins) {
auto insert_squeeze = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0];
auto delta = b_ins->get_shape().lens().size() - input->get_shape().lens().size();
auto squeeze_axes = axes;
squeeze_axes.erase(squeeze_axes.end() - delta, squeeze_axes.end());
if(squeeze_axes.empty())
return input;
return m.insert_instruction(ins, make_op("squeeze", {{"axes", squeeze_axes}}), input);
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes, b_ins->get_shape().lens().end());
if (b_ins->name() == "multibroadcast")
{
return m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if (b_ins->name() == "broadcast")
{
auto v = b_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
}
assert(false);
return m.end();
};
auto a1 = insert_sqeeze(a);
auto b1 = insert_sqeeze(b);
auto a1 = insert_squeeze(a);
auto b1 = insert_squeeze(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), dot);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze);
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
m.replace_instruction(ins, broadcast);
}
};
......@@ -1319,6 +1329,7 @@ void simplify_algebra::apply(module& m) const
{
match::find_matches(m,
find_inner_broadcast{},
find_dot_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_add_convs{},
......
......@@ -3003,6 +3003,33 @@ TEST_CASE(reorder_slice_ins_deps)
EXPECT(m == create_module());
}
TEST_CASE(dot_broadcast_different_rank)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {768}});
auto y = m1.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}), x);
auto yb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), y);
auto dot = m1.add_instruction(migraphx::make_op("dot"), xb, yb);
m1.add_return({dot});
};
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {768}});
auto y = m2.add_parameter("y", {migraphx::shape::float_type, {768, 3072}});
auto xb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), x);
auto yb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), y);
auto dot = m2.add_instruction(migraphx::make_op("dot"), xb, yb);
auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), dot);
m2.add_return({broadcast});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_fusion_reshape)
{
migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment