"include/vscode:/vscode.git/clone" did not exist on "5ae304df0a62a94488a9043370207dbcc113ecca"
Unverified Commit a46f378e authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Handle broadcasts across dot and concat (#1689) (#1731)



Improves constant propagation for BERT models: larger batch sizes no longer produce correspondingly large constants. Also improves the speed of model compilation.
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent ed6542ee
...@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
* *
* See normalize_attribute.hpp for explaining the options. * See normalize_attribute.hpp for explaining the options.
*/ */
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
const std::vector<std::size_t>& lens) const std::vector<std::size_t>& lens,
Message m)
{ {
std::vector<int64_t> result(vec); std::vector<int64_t> result(vec);
int64_t n_rank = lens.size(); int64_t n_rank = lens.size();
...@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!"); MIGRAPHX_THROW(m() + "value out of range!");
} }
} }
} }
...@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if(not std::equal( if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{})) min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
else else
{ {
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{})) if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{ {
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!"); MIGRAPHX_THROW(m() + "attribute out of range!");
} }
} }
} }
...@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key(); const auto& key = rv.get_key();
if(val.contains(key)) if(val.contains(key))
{ {
auto vv = val.at(key).without_key(); auto message = [&] { return op.name() + ": " + key + ": "; };
auto vv = val.at(key).without_key();
if(vv.is_array()) if(vv.is_array())
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>(); axes = val.at("axes").without_key().to_vector<int64_t>();
} }
auto vec = vv.to_vector<int64_t>(); auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens); auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
val[key] = result; val[key] = result;
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
...@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else else
{ {
auto num = vv.to<int64_t>(); auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens); auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
val[key] = result.front(); val[key] = result.front();
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
......
...@@ -361,30 +361,118 @@ struct find_inner_broadcast ...@@ -361,30 +361,118 @@ struct find_inner_broadcast
{ {
auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); } auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
// Returns a predicate over instructions that holds only when the instruction's
// shape is not scalar and its name equals `name`.
static auto non_scalar_op(const std::string& name)
{
    return [name](instruction_ref candidate) {
        // The scalar check runs first; the name is only compared for non-scalar shapes.
        return not candidate->get_shape().scalar() and candidate->name() == name;
    };
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto broadcasts = ins->inputs(); auto broadcasts = ins->inputs();
if(broadcasts.empty()) if(broadcasts.empty())
return; return;
bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
any_of(broadcasts, non_scalar_op("multibroadcast"));
// If a non-scalar broadcast's input has more than one non-unit dimension, then don't perform inner_broadcast
if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
if(i->get_shape().scalar())
return false;
if(i->name() == "multibroadcast")
return false;
auto input = i->inputs().at(0);
const auto& lens = input->get_shape().lens();
return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
return d == 1;
}) < (lens.size() - 1);
}))
return;
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(), std::transform(broadcasts.begin(),
broadcasts.end(), broadcasts.end(),
std::back_inserter(inputs), std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); }); [&](instruction_ref i) {
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { auto input = i->inputs().front();
return i->get_shape() != inputs.front()->get_shape() and if(mixed_broadcasts and not i->get_shape().scalar() and
i->get_shape().elements() != 1; i->get_shape().lens().size() > 1)
})) return m.insert_instruction(i, make_op("squeeze"), input);
return; return input;
});
auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
return not i->get_shape().scalar(); std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
}); if(i->get_shape().scalar())
if(b_it == broadcasts.end()) return 2;
b_it = broadcasts.begin(); else if(i->name() == "broadcast")
return 0;
if(i->name() == "multibroadcast")
return 1;
return 3;
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs); auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
// Rewrite rule for a "dot" whose inputs are all broadcasts: the leading batch axes on
// which BOTH operands have stride 0 are removed from the operands' broadcasts, the dot
// is computed on the smaller shapes, and the result is multibroadcast back to the
// original output lens.
struct find_dot_broadcast
{
// Matches a "dot" instruction for which every input satisfies match::broadcast().
auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
// Applies the rewrite to the matched dot (r.result) inside module m.
// Returns without changing anything when the two inputs have different operator
// names or when the dot output has fewer than 3 dimensions.
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
// Both operands must be produced by the same kind of operator.
if(a->get_operator().name() != b->get_operator().name())
return;
// Fewer than 3 dims means there is no batch axis in front of the last two dims.
if(ins->get_shape().lens().size() < 3)
return;
// Every dimension except the last two is treated as a batch axis.
auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted
// (std::mismatch stops at the first batch axis where the strides are not both 0).
auto p =
std::mismatch(a_strides.begin(),
a_strides.begin() + nbatch_axes,
b_strides.begin(),
b_strides.begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
// Number of leading batch axes with stride 0 on both operands.
// NOTE(review): there is no early exit when naxes == 0; the rewrite below then
// re-emits broadcasts with unchanged lens -- confirm this is intended.
auto naxes = p.first - a_strides.begin();
assert(naxes <= nbatch_axes);
// NOTE(review): `axes` is filled with 0..naxes-1 but is not read again in this function.
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
// Re-creates the broadcast b_ins on its original input, with the first `naxes`
// output dimensions dropped. The new instruction is inserted at `ins`.
auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0];
// Output lens of b_ins without the leading `naxes` dimensions.
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
b_ins->get_shape().lens().end());
if(b_ins->name() == "multibroadcast")
{
return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if(b_ins->name() == "broadcast")
{
auto v = b_ins->get_operator().to_value();
// Shift the broadcast axis left by the number of removed leading dims.
// NOTE(review): the subtraction involves a std::size_t; assumes axis >= naxes -- TODO confirm.
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
}
// Any other operator name is not handled here.
assert(false);
return m.end();
};
auto a1 = insert_broadcast(a);
auto b1 = insert_broadcast(b);
// Dot on the reduced operands, then broadcast the result back to the original output lens.
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
m.replace_instruction(ins, broadcast);
} }
}; };
...@@ -393,7 +481,8 @@ struct find_concat_op ...@@ -393,7 +481,8 @@ struct find_concat_op
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::any_of[match::inputs()]( return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once())); match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::used_once()));
} }
template <class Iterator> template <class Iterator>
...@@ -412,7 +501,8 @@ struct find_concat_op ...@@ -412,7 +501,8 @@ struct find_concat_op
static bool is_valid_op(const operation& op) static bool is_valid_op(const operation& op)
{ {
return op.name() == "broadcast" or op.attributes().contains("pointwise"); return contains({"broadcast", "multibroadcast"}, op.name()) or
op.attributes().contains("pointwise");
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -440,6 +530,16 @@ struct find_concat_op ...@@ -440,6 +530,16 @@ struct find_concat_op
op = b; op = b;
iaxis = 0; iaxis = 0;
} }
else if(op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
iaxis -= delta;
}
std::vector<instruction_ref> concats; std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++) for(std::size_t i = 0; i < x->inputs().size(); i++)
...@@ -1260,6 +1360,7 @@ void simplify_algebra::apply(module& m) const ...@@ -1260,6 +1360,7 @@ void simplify_algebra::apply(module& m) const
{ {
match::find_matches(m, match::find_matches(m,
find_inner_broadcast{}, find_inner_broadcast{},
find_dot_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
find_add_lit_broadcast{}, find_add_lit_broadcast{},
find_add_convs{}, find_add_convs{},
......
...@@ -613,6 +613,60 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -613,6 +613,60 @@ TEST_CASE(simplify_inner_broadcast_scalar)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// Two inputs of different rank ({384, 768} and {768}) are each multibroadcast to
// {2, 384, 768} and added. After run_pass the module is expected to equal one that
// multibroadcasts only "y" to {384, 768}, adds, and applies the outer multibroadcast
// once to the sum.
TEST_CASE(simplify_inner_broadcast_different_dims)
{
    const migraphx::shape matrix_shape{migraphx::shape::int32_type, {384, 768}};
    const migraphx::shape vector_shape{migraphx::shape::int32_type, {768}};
    auto outer_bcast = migraphx::op::multibroadcast{{2, 384, 768}};

    // Module handed to the pass: add(outer_bcast(x), outer_bcast(y)).
    migraphx::module m1;
    {
        auto lhs      = m1.add_parameter("x", matrix_shape);
        auto rhs      = m1.add_parameter("y", vector_shape);
        auto lhs_wide = m1.add_instruction(outer_bcast, lhs);
        auto rhs_wide = m1.add_instruction(outer_bcast, rhs);
        auto total    = m1.add_instruction(migraphx::make_op("add"), lhs_wide, rhs_wide);
        m1.add_instruction(pass_op{}, total);
    }
    run_pass(m1);

    // Expected module: outer_bcast(add(x, multibroadcast{384, 768}(y))).
    migraphx::module m2;
    {
        auto lhs        = m2.add_parameter("x", matrix_shape);
        auto rhs        = m2.add_parameter("y", vector_shape);
        auto rhs_inner  = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, rhs);
        auto total      = m2.add_instruction(migraphx::make_op("add"), lhs, rhs_inner);
        auto total_wide = m2.add_instruction(outer_bcast, total);
        m2.add_instruction(pass_op{}, total_wide);
    }
    EXPECT(m1 == m2);
}
// Mixes a `broadcast` (axis 1) of a {24} input with a `multibroadcast` of a {24, 1, 1}
// input, both to {1, 24, 112, 112}, feeding an add. After run_pass the module is expected
// to equal one that squeezes both inputs, adds them, and applies the `broadcast` once to
// the sum.
TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{
    const migraphx::shape flat_shape{migraphx::shape::int32_type, {24}};
    const migraphx::shape padded_shape{migraphx::shape::int32_type, {24, 1, 1}};
    auto axis_bcast  = migraphx::op::broadcast{1, {1, 24, 112, 112}};
    auto multi_bcast = migraphx::op::multibroadcast{{1, 24, 112, 112}};

    // Module handed to the pass: add(axis_bcast(x), multi_bcast(y)).
    migraphx::module m1;
    {
        auto lhs      = m1.add_parameter("x", flat_shape);
        auto rhs      = m1.add_parameter("y", padded_shape);
        auto lhs_wide = m1.add_instruction(axis_bcast, lhs);
        auto rhs_wide = m1.add_instruction(multi_bcast, rhs);
        auto total    = m1.add_instruction(migraphx::make_op("add"), lhs_wide, rhs_wide);
        m1.add_instruction(pass_op{}, total);
    }
    run_pass(m1);

    // Expected module: axis_bcast(add(squeeze(x), squeeze(y))).
    migraphx::module m2;
    {
        auto lhs        = m2.add_parameter("x", flat_shape);
        auto rhs        = m2.add_parameter("y", padded_shape);
        auto lhs_flat   = m2.add_instruction(migraphx::make_op("squeeze"), lhs);
        auto rhs_flat   = m2.add_instruction(migraphx::make_op("squeeze"), rhs);
        auto total      = m2.add_instruction(migraphx::make_op("add"), lhs_flat, rhs_flat);
        auto total_wide = m2.add_instruction(axis_bcast, total);
        m2.add_instruction(pass_op{}, total_wide);
    }
    EXPECT(m1 == m2);
}
TEST_CASE(simplify_add_conv1) TEST_CASE(simplify_add_conv1)
{ {
migraphx::module m; migraphx::module m;
...@@ -3003,6 +3057,38 @@ TEST_CASE(reorder_slice_ins_deps) ...@@ -3003,6 +3057,38 @@ TEST_CASE(reorder_slice_ins_deps)
EXPECT(m == create_module()); EXPECT(m == create_module());
} }
// A {768} input and a {768, 3072} input are multibroadcast to batched shapes
// ({2, 384, 768} and {2, 768, 3072}) and fed to a dot. After run_pass the sorted module
// is expected to equal one where the operands are multibroadcast without the leading
// batch dimension, and the dot result is multibroadcast to {2, 384, 3072}.
TEST_CASE(dot_broadcast_different_rank)
{
    const migraphx::shape vec_shape{migraphx::shape::float_type, {768}};
    const migraphx::shape mat_shape{migraphx::shape::float_type, {768, 3072}};

    // Module handed to the pass: dot of two batched multibroadcasts.
    migraphx::module m1;
    {
        auto lhs      = m1.add_parameter("x", vec_shape);
        auto rhs      = m1.add_parameter("y", mat_shape);
        auto lhs_wide = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}), lhs);
        auto rhs_wide = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 768, 3072}}}), rhs);
        auto product = m1.add_instruction(migraphx::make_op("dot"), lhs_wide, rhs_wide);
        m1.add_return({product});
    }

    // Expected module: un-batched dot followed by a single multibroadcast of the result.
    migraphx::module m2;
    {
        auto lhs      = m2.add_parameter("x", vec_shape);
        auto rhs      = m2.add_parameter("y", mat_shape);
        auto lhs_wide = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), lhs);
        auto rhs_wide = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {768, 3072}}}), rhs);
        auto product = m2.add_instruction(migraphx::make_op("dot"), lhs_wide, rhs_wide);
        auto batched = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 3072}}}), product);
        m2.add_return({batched});
    }

    run_pass(m1);
    EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(dot_fusion_reshape) TEST_CASE(dot_fusion_reshape)
{ {
migraphx::module m1; migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment