Commit aeea7344 authored by charlie's avatar charlie
Browse files

Progress

parent e8640da6
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/simplify_dyn_ops.hpp> #include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
...@@ -65,8 +66,22 @@ struct find_static_2in_broadcasts ...@@ -65,8 +66,22 @@ struct find_static_2in_broadcasts
}; };
/** /**
* Simplify slice with variable `starts` and `ends` to the constant version if * Simplify slice with 2 inputs to the 1 input version if inputs[1] is constant.
* the `input_starts` and `input_ends` inputs are constant. * From:
* slice(data, constant_input); two attributes set
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/
struct find_const_2in_slice
{
};
/**
* Simplify slice with 3 inputs to the 1 input version if inputs[1:2] are constant.
* From:
* slice(data, constant_input1, constant_input2); one attribute set
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/ */
struct find_const_3in_slice struct find_const_3in_slice
{ {
...@@ -81,27 +96,52 @@ struct find_const_3in_slice ...@@ -81,27 +96,52 @@ struct find_const_3in_slice
{ {
auto ins = mr.result; auto ins = mr.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(); auto slice_val = ins->get_operator().to_value();
argument ends_arg = inputs.at(2)->eval(); auto slice_op = any_cast<op::slice>(ins->get_operator());
if(not starts_arg.empty() and not ends_arg.empty()) auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
if(set_attrs == slice_op.axes_only)
{ {
std::vector<int64_t> starts_vec; // slice(data, starts, ends)
std::vector<int64_t> ends_vec; inputs.at(1)->eval().visit(
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); }); [&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); }); inputs.at(2)->eval().visit(
auto slice_val = ins->get_operator().to_value(); [&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto axes_vec = slice_val.at("axes").to_vector<int64_t>(); axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction( }
ins, else if(set_attrs == slice_op.ends_only)
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}), {
inputs.at(0)); // slice(data, starts, axes)
inputs.at(1)->eval().visit(
[&](auto output) { starts_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
ends_vec = slice_val.at("ends").to_vector<int64_t>();
} }
else
{
// slice(data, ends, axes)
inputs.at(1)->eval().visit(
[&](auto output) { ends_vec.assign(output.begin(), output.end()); });
inputs.at(2)->eval().visit(
[&](auto output) { axes_vec.assign(output.begin(), output.end()); });
starts_vec = slice_val.at("starts").to_vector<int64_t>();
}
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
} }
}; };
/** /**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if * Simplify slice with 4 inputs to the 1 input version if inputs[1:3] are constant.
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant. * From:
* slice(data, constant_starts, constant_ends, constant_axes)
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/ */
struct find_const_4in_slice struct find_const_4in_slice
{ {
...@@ -117,9 +157,9 @@ struct find_const_4in_slice ...@@ -117,9 +157,9 @@ struct find_const_4in_slice
{ {
auto ins = mr.result; auto ins = mr.result;
auto inputs = ins->inputs(); auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval(); argument starts_arg = inputs.at(1)->eval(false);
argument ends_arg = inputs.at(2)->eval(); argument ends_arg = inputs.at(2)->eval(false);
argument axes_arg = inputs.at(3)->eval(); argument axes_arg = inputs.at(3)->eval(false);
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty()) if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{ {
std::vector<int64_t> starts_vec; std::vector<int64_t> starts_vec;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment