"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "682968444d35efc2d0e5a203761042d63da0b7f4"
Commit f14cfd45 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_reshape' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents cb265820 239d50dc
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -46,14 +47,66 @@ struct reshape ...@@ -46,14 +47,66 @@ struct reshape
value attributes() const { return {{"require_std_shape", true}}; } value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
shape dyn_compute_shape(shape s0) const
{ {
check_shapes{inputs, *this}.has(1).standard(); auto dyn_dims = s0.dyn_dims();
int not_fixed_index = -1;
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(not_fixed_index == -1)
{
not_fixed_index = i;
}
else
{
MIGRAPHX_THROW("Reshape: Only support one non-fixed dynamic_dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("Reshape: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
if(dims[not_fixed_index] != 0 and dims[not_fixed_index] != -1)
{
MIGRAPHX_THROW("Reshape: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
for(std::size_t i = 0; i < dims.size(); ++i)
{
if(i == not_fixed_index)
{
output_dyn_dims.push_back(dyn_dims[not_fixed_index]);
}
else
{
std::size_t d = dims[i];
output_dyn_dims.push_back({d, d, 0});
}
}
return {s0.type(), output_dyn_dims};
}
template <class T>
shape static_compute_shape(std::vector<shape> inputs, T n_neg_dims) const
{
check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
for(std::size_t i = 0; i < dims.size(); i++) for(std::size_t i = 0; i < dims.size(); i++)
{ {
...@@ -86,9 +139,26 @@ struct reshape ...@@ -86,9 +139,26 @@ struct reshape
return s; return s;
} }
argument compute(shape output_shape, std::vector<argument> args) const shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
if(args.size() == 2) if(args.size() == 2)
{ {
auto s = args[1]->eval(); auto s = args[1]->eval();
check_arg_empty(s, "Reshape: dynamic shape is not supported"); check_arg_empty(s, "Reshape: non-constant shape input is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
} }
......
...@@ -2058,6 +2058,55 @@ TEST_CASE(reshape_shape) ...@@ -2058,6 +2058,55 @@ TEST_CASE(reshape_shape)
} }
} }
TEST_CASE(reshape_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
{-1, 1, 1, 24}, {0, 8, 3, 1}, {-1, 3, 4, 2}, {0, 2, 4, 3}})
{
std::vector<migraphx::shape::dynamic_dimension> out_dyn_dims{};
for(std::size_t i = 0; i < new_shape.size(); ++i)
{
if(new_shape[i] == 0 or new_shape[i] == -1)
{
out_dyn_dims.push_back(input.dyn_dims().at(i));
}
else
{
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0});
}
}
migraphx::shape output{migraphx::shape::float_type, out_dyn_dims};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
}
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 20, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 1, 0, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_fixed_ele_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 10, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 1, 5, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_non_fixed_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {2, 1, 1, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
...@@ -6114,47 +6114,76 @@ TEST_CASE(relu_dyn_test) ...@@ -6114,47 +6114,76 @@ TEST_CASE(relu_dyn_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(reshape_test) TEST_CASE(reshape_test0)
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}}; migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24); std::vector<float> data(24);
std::iota(data.begin(), data.end(), -3); std::iota(data.begin(), data.end(), -3);
{ migraphx::program p;
migraphx::program p; auto* mm = p.get_main_module();
auto* mm = p.get_main_module(); auto l = mm->add_literal(migraphx::literal{a_shape, data});
auto l = mm->add_literal(migraphx::literal{a_shape, data}); std::vector<int64_t> new_shape = {8, 3, 1, 1};
std::vector<int64_t> new_shape = {8, 3, 1, 1}; mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); p.compile(migraphx::ref::target{});
p.compile(migraphx::ref::target{}); auto result = p.eval({}).back();
auto result = p.eval({}).back(); std::vector<float> results_vector{};
std::vector<float> results_vector(3); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); EXPECT(migraphx::verify_range(results_vector, data));
EXPECT(migraphx::verify_range(results_vector, data)); }
}
{ TEST_CASE(reshape_test1)
migraphx::program p; {
auto* mm = p.get_main_module(); migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
auto l = mm->add_literal(migraphx::literal{a_shape, data}); std::vector<float> data(24);
std::vector<int64_t> new_shape = {1, 3, 4, 2}; std::iota(data.begin(), data.end(), -3);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); migraphx::program p;
p.compile(migraphx::ref::target{}); auto* mm = p.get_main_module();
auto result = p.eval({}).back(); auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<float> results_vector(3); std::vector<int64_t> new_shape = {1, 3, 4, 2};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
EXPECT(migraphx::verify_range(results_vector, data)); p.compile(migraphx::ref::target{});
} auto result = p.eval({}).back();
{ std::vector<float> results_vector{};
migraphx::program p; result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
auto* mm = p.get_main_module(); EXPECT(migraphx::verify_range(results_vector, data));
auto l = mm->add_literal(migraphx::literal{a_shape, data}); }
std::vector<int64_t> new_shape = {1, 3, 4, 2};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); TEST_CASE(reshape_test2)
p.compile(migraphx::ref::target{}); {
auto result = p.eval({}).back(); migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> results_vector(3); std::vector<float> data(24);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); std::iota(data.begin(), data.end(), -3);
EXPECT(migraphx::verify_range(results_vector, data)); migraphx::program p;
} auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data});
std::vector<int64_t> new_shape = {1, 2, 3, 4};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(reshape_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
std::vector<int64_t> new_shape = {0, 8, 3, 1};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
p.compile(migraphx::ref::target{});
std::vector<float> data(48);
std::iota(data.begin(), data.end(), -3);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
params["X"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, data));
} }
TEST_CASE(reverse_test_axis0) TEST_CASE(reverse_test_axis0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment