"docs/source/misc/op_evo_examples.rst" did not exist on "dbb2434f5d2d976be26b594342a68cb46619ecea"
Commit 27b0820d authored by charlie's avatar charlie
Browse files

initial

parent 17abf67e
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -54,14 +55,51 @@ struct squeeze ...@@ -54,14 +55,51 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic())
{
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not contains(one_dyn_dims, input_shape.dyn_dims()[axis]);
}))
{
MIGRAPHX_THROW(
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
for(auto i : range(input_shape.ndim()))
{
auto dd = input_shape.dyn_dims()[i];
if(not contains(one_dyn_dims, dd))
{
dyn_dims.push_back(dd);
}
}
}
else
{
for(auto i : range(input_shape.ndim()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
}
}
return {input_shape.type(), dyn_dims};
}
else
{
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides(); auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; })) if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides; std::vector<std::size_t> new_strides;
...@@ -96,10 +134,11 @@ struct squeeze ...@@ -96,10 +134,11 @@ struct squeeze
return shape{type, new_lens, new_strides}; return shape{type, new_lens, new_strides};
} }
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment