Commit 1ccf131d authored by Paul's avatar Paul
Browse files

Format

parent 2389bc0d
...@@ -40,7 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -40,7 +40,7 @@ inline namespace MIGRAPHX_INLINE_NS {
* *
* See normalize_attribute.hpp for explaining the options. * See normalize_attribute.hpp for explaining the options.
*/ */
template<class Message> template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
...@@ -195,9 +195,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -195,9 +195,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const auto& key = rv.get_key(); const auto& key = rv.get_key();
if(val.contains(key)) if(val.contains(key))
{ {
auto message = [&] { auto message = [&] { return op.name() + ": " + key + ": "; };
return op.name() + ": " + key + ": ";
};
auto vv = val.at(key).without_key(); auto vv = val.at(key).without_key();
if(vv.is_array()) if(vv.is_array())
{ {
......
...@@ -390,22 +390,28 @@ struct find_inner_broadcast ...@@ -390,22 +390,28 @@ struct find_inner_broadcast
struct find_dot_broadcast struct find_dot_broadcast
{ {
auto matcher() const { return match::name("dot")(match::all_of[match::inputs()](match::broadcast())); } auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a = ins->inputs()[0]; auto a = ins->inputs()[0];
auto b = ins->inputs()[1]; auto b = ins->inputs()[1];
if (a->get_operator().name() != b->get_operator().name()) if(a->get_operator().name() != b->get_operator().name())
return; return;
if (ins->get_shape().lens().size() < 3) if(ins->get_shape().lens().size() < 3)
return; return;
auto nbatch_axes = ins->get_shape().lens().size() - 2; auto nbatch_axes = ins->get_shape().lens().size() - 2;
// Find leading batch axes that are broadcasted // Find leading batch axes that are broadcasted
auto p = std::mismatch(a->get_shape().strides().begin(), a->get_shape().strides().begin()+nbatch_axes, b->get_shape().strides().begin(), b->get_shape().strides().begin()+nbatch_axes, [](auto astride, auto bstride) { auto p =
return astride == 0 and bstride == 0; std::mismatch(a->get_shape().strides().begin(),
}); a->get_shape().strides().begin() + nbatch_axes,
b->get_shape().strides().begin(),
b->get_shape().strides().begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a->get_shape().lens().begin(); auto naxes = p.first - a->get_shape().lens().begin();
std::vector<std::size_t> axes(naxes); std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
...@@ -415,7 +421,7 @@ struct find_dot_broadcast ...@@ -415,7 +421,7 @@ struct find_dot_broadcast
auto delta = b_ins->get_shape().lens().size() - input->get_shape().lens().size(); auto delta = b_ins->get_shape().lens().size() - input->get_shape().lens().size();
auto squeeze_axes = axes; auto squeeze_axes = axes;
squeeze_axes.erase(squeeze_axes.end() - delta, squeeze_axes.end()); squeeze_axes.erase(squeeze_axes.end() - delta, squeeze_axes.end());
if (squeeze_axes.empty()) if(squeeze_axes.empty())
return input; return input;
return m.insert_instruction(ins, make_op("squeeze", {{"axes", squeeze_axes}}), input); return m.insert_instruction(ins, make_op("squeeze", {{"axes", squeeze_axes}}), input);
}; };
...@@ -423,7 +429,8 @@ struct find_dot_broadcast ...@@ -423,7 +429,8 @@ struct find_dot_broadcast
auto b1 = insert_sqeeze(b); auto b1 = insert_sqeeze(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1); auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), dot); auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", axes}}), dot);
auto broadcast = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze); auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), unsqueeze);
m.replace_instruction(ins, broadcast); m.replace_instruction(ins, broadcast);
} }
}; };
...@@ -433,7 +440,8 @@ struct find_concat_op ...@@ -433,7 +440,8 @@ struct find_concat_op
auto matcher() const auto matcher() const
{ {
return match::name("concat")(match::any_of[match::inputs()]( return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")), match::used_once())); match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::used_once()));
} }
template <class Iterator> template <class Iterator>
...@@ -452,7 +460,8 @@ struct find_concat_op ...@@ -452,7 +460,8 @@ struct find_concat_op
static bool is_valid_op(const operation& op) static bool is_valid_op(const operation& op)
{ {
return contains({"broadcast", "multibroadcast"}, op.name()) or op.attributes().contains("pointwise"); return contains({"broadcast", "multibroadcast"}, op.name()) or
op.attributes().contains("pointwise");
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -480,11 +489,11 @@ struct find_concat_op ...@@ -480,11 +489,11 @@ struct find_concat_op
op = b; op = b;
iaxis = 0; iaxis = 0;
} }
else if (op.name() == "multibroadcast") else if(op.name() == "multibroadcast")
{ {
shape bshape = (*start)->get_shape(); shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0]; auto input = (*start)->inputs()[0];
if (iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0) if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last}; return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}}); op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size(); auto delta = bshape.lens().size() - input->get_shape().lens().size();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment