Commit 27b0820d authored by charlie's avatar charlie
Browse files

initial

parent 17abf67e
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -54,52 +55,90 @@ struct squeeze ...@@ -54,52 +55,90 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); if(input_shape.dynamic())
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
} if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
std::vector<std::size_t> new_lens; return not contains(one_dyn_dims, input_shape.dyn_dims()[axis]);
std::vector<std::size_t> new_strides; }))
if(axes.empty()) {
{ MIGRAPHX_THROW(
for(auto i : range(old_lens.size())) "SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{ {
if(old_lens[i] != 1) for(auto i : range(input_shape.ndim()))
{ {
new_lens.push_back(old_lens[i]); auto dd = input_shape.dyn_dims()[i];
new_strides.push_back(old_strides[i]); if(not contains(one_dyn_dims, dd))
{
dyn_dims.push_back(dd);
}
} }
} }
} else
else
{
for(auto i : range(old_lens.size()))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) for(auto i : range(input_shape.ndim()))
{ {
new_lens.push_back(old_lens[i]); if(std::find(axes.begin(), axes.end(), i) == axes.end())
new_strides.push_back(old_strides[i]); {
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
} }
} }
} return {input_shape.type(), dyn_dims};
if(new_lens.empty())
{
return shape{type};
} }
else else
{ {
return shape{type, new_lens, new_strides}; auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
for(auto i : range(old_lens.size()))
{
if(old_lens[i] != 1)
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
else
{
for(auto i : range(old_lens.size()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens, new_strides};
}
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment