Unverified Commit 48cc33e4 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic ref squeeze and unsqueeze (#1426)

Extends unsqueeze and squeeze to work for dynamic input shapes
Does not handle the steps parameter
Adds some additional negative axes shape tests
parent be70702d
......@@ -29,6 +29,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -54,52 +55,86 @@ struct squeeze
std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
if(input_shape.dynamic())
{
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
for(auto i : range(old_lens.size()))
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return input_shape.dyn_dims()[axis] != one_dyn_dim;
}))
{
MIGRAPHX_THROW(
"SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty())
{
std::copy_if(input_shape.dyn_dims().cbegin(),
input_shape.dyn_dims().cend(),
std::back_inserter(dyn_dims),
[&](auto dd) { return dd != one_dyn_dim; });
}
else
{
if(old_lens[i] != 1)
for(auto i : range(input_shape.ndim()))
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
dyn_dims.push_back(input_shape.dyn_dims()[i]);
}
}
}
return {input_shape.type(), dyn_dims};
}
else
{
for(auto i : range(old_lens.size()))
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty())
{
for(auto i : range(old_lens.size()))
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
if(old_lens[i] != 1)
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens, new_strides};
else
{
for(auto i : range(old_lens.size()))
{
if(std::find(axes.begin(), axes.end(), i) == axes.end())
{
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
}
if(new_lens.empty())
{
return shape{type};
}
else
{
return shape{type, new_lens, new_strides};
}
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -29,11 +29,20 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze
{
std::vector<int64_t> axes;
......@@ -56,63 +65,89 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
if(input_shape.dynamic())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
if(not steps.empty())
{
MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
}
else
{
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size();
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0;
for(auto i : range(new_size))
{
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0;
for(auto i : range(new_size))
{
std::int64_t step = 1;
if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
std::int64_t step = 1;
if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
}
else
{
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
return shape{type, new_lens, new_strides};
}
return shape{type, new_lens, new_strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
......
......@@ -6012,6 +6012,26 @@ def squeeze_unsqueeze_test():
return ([node, node2], [x], [y])
@onnx_test
def squeeze_unsqueeze_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
[1, None, 1, 1, None, 1])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT,
[1, 1, None, 1, None, 1])
node = onnx.helper.make_node('Squeeze',
inputs=['0'],
axes=[0, 2, 3, 5],
outputs=['1'])
node2 = onnx.helper.make_node('Unsqueeze',
inputs=['1'],
axes=[0, 1, 3, 5],
outputs=['2'])
return ([node, node2], [x], [y])
@onnx_test
def sub_bcast_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
......@@ -5802,6 +5802,29 @@ TEST_CASE(squeeze_unsqueeze_test)
EXPECT(p == prog);
}
TEST_CASE(squeeze_unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 4, 0}, {1, 1, 0}, {1, 1, 0}, {1, 4, 0}, {1, 1, 0}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto ret = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), c1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("squeeze_unsqueeze_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(squeeze_axes_input_test)
{
migraphx::program p;
......
......@@ -2139,6 +2139,30 @@ TEST_CASE(test_squeeze_all)
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s3, migraphx::make_op("squeeze"), s1);
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
}
TEST_CASE(test_squeeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
expect_shape(s3, migraphx::make_op("squeeze", {{"axes", {-2, -4}}}), s1);
}
TEST_CASE(test_squeeze_transpose)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 4, 1}, {4, 1, 4}};
......@@ -2180,6 +2204,30 @@ TEST_CASE(test_unsqueeze)
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), s1);
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
TEST_CASE(test_unsqueeze_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
......@@ -2211,13 +2259,27 @@ TEST_CASE(test_unsqueeze_mismatch_step_axis)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis)
TEST_CASE(test_unsqueeze_negative_axis1)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 1, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis2)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 5, 3, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_negative_axis3)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 5, 3}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-3}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar)
{
migraphx::shape s1{migraphx::shape::float_type, {1}, {0}};
......
......@@ -7294,6 +7294,25 @@ TEST_CASE(squeeze_test)
}
}
TEST_CASE(squeeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), p0);
p.compile(migraphx::ref::target{});
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 1, 3, 1, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(step_test)
{
{
......@@ -7592,6 +7611,25 @@ TEST_CASE(unsqueeze_test)
}
}
TEST_CASE(unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::ref::target{});
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 3, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(where_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment