Commit f14cfd45 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_reshape' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents cb265820 239d50dc
......@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -46,14 +47,66 @@ struct reshape
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
int not_fixed_index = -1;
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(not_fixed_index == -1)
{
not_fixed_index = i;
}
else
{
MIGRAPHX_THROW("Reshape: Only support one non-fixed dynamic_dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(std::size_t i = 0; i < dims.size(); ++i)
{
if(i == not_fixed_index)
{
output_dyn_dims.push_back(dyn_dims[not_fixed_index]);
}
else
{
check_shapes{inputs, *this}.has(1).standard();
std::size_t d = dims[i];
output_dyn_dims.push_back({d, d, 0});
}
}
return {s0.type(), output_dyn_dims};
}
template <class T>
shape static_compute_shape(std::vector<shape> inputs, T n_neg_dims) const
{
check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
for(std::size_t i = 0; i < dims.size(); i++)
{
......@@ -86,9 +139,26 @@ struct reshape
return s;
}
argument compute(shape output_shape, std::vector<argument> args) const
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
......@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Reshape: dynamic shape is not supported");
check_arg_empty(s, "Reshape: non-constant shape input is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
}
......
......@@ -2058,6 +2058,55 @@ TEST_CASE(reshape_shape)
}
}
TEST_CASE(reshape_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
{-1, 1, 1, 24}, {0, 8, 3, 1}, {-1, 3, 4, 2}, {0, 2, 4, 3}})
{
std::vector<migraphx::shape::dynamic_dimension> out_dyn_dims{};
for(std::size_t i = 0; i < new_shape.size(); ++i)
{
if(new_shape[i] == 0 or new_shape[i] == -1)
{
out_dyn_dims.push_back(input.dyn_dims().at(i));
}
else
{
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0});
}
}
migraphx::shape output{migraphx::shape::float_type, out_dyn_dims};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
}
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 20, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 1, 0, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_fixed_ele_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 10, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 1, 5, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_non_fixed_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {2, 1, 1, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(rnn)
{
{
......
......@@ -6114,12 +6114,11 @@ TEST_CASE(relu_dyn_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reshape_test)
TEST_CASE(reshape_test0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
......@@ -6127,11 +6126,16 @@ TEST_CASE(reshape_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, data));
}
{
}
TEST_CASE(reshape_test1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
......@@ -6139,22 +6143,47 @@ TEST_CASE(reshape_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, data));
}
{
}
TEST_CASE(reshape_test2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3);
migraphx::program p;
auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 3, 4, 2};
std::vector<int64_t> new_shape = {1, 2, 3, 4};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(3);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(reshape_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 8, 3, 1};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
p.compile(migraphx::ref::target{});
std::vector<float> data(48);
std::iota(data.begin(), data.end(), -3);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, data));
}
}
TEST_CASE(reverse_test_axis0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment