Commit 1ccf131d authored by Paul's avatar Paul
Browse files

Format

parent 2389bc0d
......@@ -40,7 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS {
*
* See normalize_attribute.hpp for explaining the options.
*/
template<class Message>
template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
......@@ -195,9 +195,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key();
if(val.contains(key))
{
auto message = [&] {
return op.name() + ": " + key + ": ";
};
auto message = [&] { return op.name() + ": " + key + ": "; };
auto vv = val.at(key).without_key();
if(vv.is_array())
{
......
......@@ -390,22 +390,28 @@ struct find_inner_broadcast
struct find_dot_broadcast
{
auto matcher() const { return match::name("dot")(match::all_of[match::inputs()](match::broadcast())); }
auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
if (a->get_operator().name() != b->get_operator().name())
if(a->get_operator().name() != b->get_operator().name())
return;
if (ins->get_shape().lens().size() < 3)
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
// Find leading batch axes that are broadcasted
auto p = std::mismatch(a->get_shape().strides().begin(), a->get_shape().strides().begin()+nbatch_axes, b->get_shape().strides().begin(), b->get_shape().strides().begin()+nbatch_axes, [](auto astride, auto bstride) {
return astride == 0 and bstride == 0;
});
auto p =
std::mismatch(a->get_shape().strides().begin(),
a->get_shape().strides().begin() + nbatch_axes,
b->get_shape().strides().begin(),
b->get_shape().strides().begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a->get_shape().lens().begin();
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
......@@ -415,7 +421,7 @@ struct find_dot_broadcast
auto delta = b_ins->get_shape().lens().size() - input->get_shape().lens().size();
auto squeeze_axes = axes;
squeeze_axes.erase(squeeze_axes.end() - delta, squeeze_axes.end());
if (squeeze_axes.empty())
if(squeeze_axes.empty())
return input;
return m.insert_instruction(ins, make_op("squeeze", {{"axes", squeeze_axes}}), input);
};
......@@ -423,7 +429,8 @@ struct find_dot_broadcast
auto b1 = insert_sqeeze(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), dot);
auto broadcast = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze);
m.replace_instruction(ins, broadcast);
}
};
......@@ -433,7 +440,8 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::used_once()));
}
template <class Iterator>
......@@ -452,7 +460,8 @@ struct find_concat_op
static bool is_valid_op(const operation& op)
{
return contains({"broadcast", "multibroadcast"}, op.name()) or op.attributes().contains("pointwise");
return contains({"broadcast", "multibroadcast"}, op.name()) or
op.attributes().contains("pointwise");
}
void apply(module& m, const match::matcher_result& r) const
......@@ -480,11 +489,11 @@ struct find_concat_op
op = b;
iaxis = 0;
}
else if (op.name() == "multibroadcast")
else if(op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if (iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment