Commit 954bca60 authored by charlie's avatar charlie
Browse files

draft of matchers

parent 34208e6d
......@@ -74,6 +74,50 @@ struct find_static_2in_broadcasts
*/
struct find_const_2in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == slice_op.ends_axes)
{
// slice(data, starts)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_vec = slice_op.ends;
axes_vec = slice_op.axes;
}
else if(set_attrs == slice_op.starts_axes)
{
// slice(data, ends)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
axes_vec = slice_op.axes;
}
else
{
// slice(data, axes)
inputs.at(1)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_op.starts;
ends_vec = slice_op.ends;
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
};
/**
......@@ -96,7 +140,6 @@ struct find_const_3in_slice
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_val = ins->get_operator().to_value();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
......@@ -109,7 +152,7 @@ struct find_const_3in_slice
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_vec = slice_val.at("axes").to_vector<int64_t>();
axes_vec = slice_op.axes;
}
else if(set_attrs == slice_op.ends_only)
{
......@@ -118,7 +161,7 @@ struct find_const_3in_slice
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
ends_vec = slice_val.at("ends").to_vector<int64_t>();
ends_vec = slice_op.ends;
}
else
{
......@@ -127,7 +170,7 @@ struct find_const_3in_slice
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_val.at("starts").to_vector<int64_t>();
starts_vec = slice_op.starts;
}
m.replace_instruction(
ins,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment