Commit 2389bc0d authored by Paul
Browse files

Handle broadcasts across dot and concat

parent f6e22d56
......@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
*
* See normalize_attribute.hpp for an explanation of the options.
*/
template<class Message>
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
const std::vector<std::size_t>& lens)
const std::vector<std::size_t>& lens,
Message m)
{
std::vector<int64_t> result(vec);
int64_t n_rank = lens.size();
......@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
MIGRAPHX_THROW(m() + "value out of range!");
}
}
else
{
if(not std::equal(result.begin(), result.end(), max_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: value out of range!");
MIGRAPHX_THROW(m() + "value out of range!");
}
}
}
......@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if(not std::equal(
min_vals.begin(), min_vals.end(), result.begin(), std::less_equal<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
MIGRAPHX_THROW(m() + "attribute out of range!");
}
}
else
{
if(not std::equal(result.begin(), result.end(), min_vals.begin(), std::less<>{}))
{
MIGRAPHX_THROW("TUNE_VECTOR: attribute out of range!");
MIGRAPHX_THROW(m() + "attribute out of range!");
}
}
}
......@@ -193,6 +195,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key();
if(val.contains(key))
{
auto message = [&] {
return op.name() + ": " + key + ": ";
};
auto vv = val.at(key).without_key();
if(vv.is_array())
{
......@@ -202,7 +207,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>();
}
auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens);
auto result = tune_attribute(vec, axes, rv.without_key(), lens, message);
val[key] = result;
op.from_value(val);
val = op.to_value();
......@@ -211,7 +216,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else
{
auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens);
auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message);
val[key] = result.front();
op.from_value(val);
val = op.to_value();
......
......@@ -388,12 +388,52 @@ struct find_inner_broadcast
}
};
/**
 * Hoists broadcasting of leading batch axes out of a `dot`.
 *
 * When both inputs of a `dot` are broadcast instructions of the same kind and
 * their leading batch axes have a stride of 0 in *both* operands, every batch
 * along those axes computes the same product. The rewrite squeezes those axes
 * away from the un-broadcast inputs, performs a single smaller `dot`, and then
 * restores the original output shape with `unsqueeze` + `multibroadcast`.
 *
 * The rewrite is skipped (the instruction is left untouched) when:
 *  - the two inputs are produced by different operators,
 *  - the output has fewer than 3 dimensions (no batch axes),
 *  - no leading batch axis is broadcast in both operands, or
 *  - an operand gained more dimensions through its broadcast than the number
 *    of shared broadcast axes (the squeeze axes could not be derived).
 *
 * NOTE(review): only the leading batch axes are inspected. If an input is also
 * broadcast along a non-batch axis, or is produced by the `broadcast` operator
 * (which places its input via an `axis` attribute), the un-broadcast inputs may
 * not form a valid `dot` -- confirm against the callers/tests of this pass.
 */
struct find_dot_broadcast
{
    auto matcher() const
    {
        return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto a   = ins->inputs()[0];
        auto b   = ins->inputs()[1];
        if(a->get_operator().name() != b->get_operator().name())
            return;

        // Copies keep the iterators below tied to a single, stable container.
        const auto out_lens = ins->get_shape().lens();
        if(out_lens.size() < 3)
            return;
        const std::size_t nbatch_axes = out_lens.size() - 2;

        const auto a_strides = a->get_shape().strides();
        const auto b_strides = b->get_shape().strides();
        if(a_strides.size() < nbatch_axes or b_strides.size() < nbatch_axes)
            return;

        // Find leading batch axes that are broadcasted in both operands
        auto p = std::mismatch(a_strides.begin(),
                               a_strides.begin() + nbatch_axes,
                               b_strides.begin(),
                               b_strides.begin() + nbatch_axes,
                               [](auto astride, auto bstride) {
                                   return astride == 0 and bstride == 0;
                               });
        // The distance must be measured against the container that was
        // iterated (the strides), not against the lens vector.
        const std::size_t naxes = p.first - a_strides.begin();
        if(naxes == 0)
            return;

        // Number of leading dimensions the broadcast added to its input.
        auto rank_delta = [](instruction_ref b_ins) {
            return b_ins->get_shape().lens().size() -
                   b_ins->inputs()[0]->get_shape().lens().size();
        };
        // Bail out before touching the module if the squeeze axes cannot be
        // derived; this also covers an (unexpected) unsigned wrap-around.
        if(rank_delta(a) > naxes or rank_delta(b) > naxes)
            return;

        std::vector<std::size_t> axes(naxes);
        std::iota(axes.begin(), axes.end(), 0);

        auto insert_squeeze = [&](instruction_ref b_ins) {
            auto input       = b_ins->inputs()[0];
            const auto delta = rank_delta(b_ins);
            // The first `delta` output axes do not exist in the input, so only
            // the remaining `naxes - delta` leading input axes get squeezed.
            std::vector<std::size_t> squeeze_axes(axes.begin(), axes.end() - delta);
            if(squeeze_axes.empty())
                return input;
            return m.insert_instruction(ins, make_op("squeeze", {{"axes", squeeze_axes}}), input);
        };
        auto a1  = insert_squeeze(a);
        auto b1  = insert_squeeze(b);
        auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
        // Re-introduce the dropped batch axes as size-1 dims, then broadcast
        // back to the shape the original dot produced.
        auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), dot);
        auto broadcast =
            m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", out_lens}}), unsqueeze);
        m.replace_instruction(ins, broadcast);
    }
};
struct find_concat_op
{
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")), match::used_once()));
}
template <class Iterator>
......@@ -412,7 +452,7 @@ struct find_concat_op
static bool is_valid_op(const operation& op)
{
return op.name() == "broadcast" or op.attributes().contains("pointwise");
return contains({"broadcast", "multibroadcast"}, op.name()) or op.attributes().contains("pointwise");
}
void apply(module& m, const match::matcher_result& r) const
......@@ -440,6 +480,16 @@ struct find_concat_op
op = b;
iaxis = 0;
}
else if (op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if (iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
iaxis -= delta;
}
std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment