Commit 6218f286 authored by charlie's avatar charlie
Browse files

Detach from dyn_squeeze and dyn_unsqueeze

parent 56b5c3be
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -55,90 +54,52 @@ struct squeeze ...@@ -55,90 +54,52 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic()) auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}}; MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { }
return not contains(one_dyn_dims, input_shape.dyn_dims()[axis]); std::vector<std::size_t> new_lens;
})) std::vector<std::size_t> new_strides;
{ if(axes.empty())
MIGRAPHX_THROW( {
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}"); for(auto i : range(old_lens.size()))
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
for(auto i : range(input_shape.ndim()))
{
auto dd = input_shape.dyn_dims()[i];
if(not contains(one_dyn_dims, dd))
{
dyn_dims.push_back(dd);
}
}
}
else
{ {
for(auto i : range(input_shape.ndim())) if(old_lens[i] != 1)
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) new_lens.push_back(old_lens[i]);
{ new_strides.push_back(old_strides[i]);
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
} }
} }
return {input_shape.type(), dyn_dims};
} }
else else
{ {
auto type = input_shape.type(); for(auto i : range(old_lens.size()))
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{ {
for(auto i : range(old_lens.size())) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
if(old_lens[i] != 1) new_lens.push_back(old_lens[i]);
{ new_strides.push_back(old_strides[i]);
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
} }
} }
else }
{ if(new_lens.empty())
for(auto i : range(old_lens.size())) {
{ return shape{type};
if(std::find(axes.begin(), axes.end(), i) == axes.end()) }
{ else
new_lens.push_back(old_lens[i]); {
new_strides.push_back(old_strides[i]); return shape{type, new_lens, new_strides};
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens, new_strides};
}
} }
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(dyn_out.computed_shape); return args[0].reshape(output_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -29,20 +29,11 @@ ...@@ -29,20 +29,11 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -65,89 +56,63 @@ struct unsqueeze ...@@ -65,89 +56,63 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type();
if(input_shape.dynamic()) auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{ {
if(not steps.empty()) if(old_lens.size() == 1 and old_lens.front() == 1)
{ return shape{type, old_lens};
MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute"); else
} MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
} }
else
{
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size()) if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis"); MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size); std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin(); std::int64_t step = 1;
if(axis_idx < axes.size()) if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
std::int64_t step = 1; if((old_lens[p] % step) != 0)
if(axis_idx < steps.size()) MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
step = steps[axis_idx]; old_lens[p] /= step;
if(step == 0) new_strides[i] = old_strides[p] * old_lens[p];
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
} }
else else
{ {
new_lens[i] = old_lens[p]; if(step != 1)
new_strides[i] = old_strides[p++]; MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
} }
} }
return shape{type, new_lens, new_strides}; else
{
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
} }
return shape{type, new_lens, new_strides};
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return args[0].reshape(dyn_out.computed_shape); return args[0].reshape(output_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -2036,30 +2036,6 @@ TEST_CASE(test_squeeze_all) ...@@ -2036,30 +2036,6 @@ TEST_CASE(test_squeeze_all)
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1); expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
} }
TEST_CASE(test_squeeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s3, migraphx::make_op("squeeze"), s1);
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s3, migraphx::make_op("squeeze", {{"axes", {-2, -4}}}), s1);
}
TEST_CASE(test_squeeze_transpose) TEST_CASE(test_squeeze_transpose)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}}; migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}};
...@@ -2101,30 +2077,6 @@ TEST_CASE(test_unsqueeze) ...@@ -2101,30 +2077,6 @@ TEST_CASE(test_unsqueeze)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), s1);
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
TEST_CASE(test_unsqueeze_step) TEST_CASE(test_unsqueeze_step)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
...@@ -2156,27 +2108,13 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis) ...@@ -2156,27 +2108,13 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
} }
TEST_CASE(test_unsqueeze_negative_axis1) TEST_CASE(test_unsqueeze_negative_axis)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}}; migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1); expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
} }
TEST_CASE(test_unsqueeze_negative_axis2)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 3, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis3)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 5, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-3}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar) TEST_CASE(test_unsqueeze_scalar)
{ {
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}}; migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
......
...@@ -7086,25 +7086,6 @@ TEST_CASE(squeeze_test) ...@@ -7086,25 +7086,6 @@ TEST_CASE(squeeze_test)
} }
} }
TEST_CASE(squeeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), p0);
p.compile(migraphx::ref::target{});
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(step_test) TEST_CASE(step_test)
{ {
{ {
...@@ -7404,25 +7385,6 @@ TEST_CASE(unsqueeze_test) ...@@ -7404,25 +7385,6 @@ TEST_CASE(unsqueeze_test)
} }
} }
TEST_CASE(unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::ref::target{});
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 3, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(where_test) TEST_CASE(where_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment