Unverified Commit f25606f9 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

2 Input Reshape `ref` implementation (#2304)

parent a7200610
......@@ -36,6 +36,22 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* 1 input version:
* reshape(input_data)
* this.dims = output_dims
* Makes a copy of input_data to the output shape.
*
* 2 input version:
* reshape(input_data, output_buffer)
* this.dims = unset
* Copies input_data to output_buffer; output_buffer already has the output shape.
* This version will not fail gracefully if the input shape and output_buffer shape are
* incompatible. There's a throw that will catch when the number of elements do not match at
* runtime. This version should only be used for dynamic reshapes (output dimensions only known at
* runtime). If output_buffer has a static shape during compile/parse, you can use the 1 input
* version.
*/
struct reshape
{
std::vector<int64_t> dims;
......@@ -215,13 +231,15 @@ struct reshape
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
check_shapes{inputs, *this, true}.has(1, 2);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs.front();
if(inputs.size() == 1)
{
if(s0.dynamic())
{
return dyn_compute_shape(s0);
......@@ -231,10 +249,17 @@ struct reshape
return static_compute_shape(inputs, n_neg_dims);
}
}
else
{
return inputs.back();
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(dyn_out.computed_shape.standard());
if(args.size() == 1)
{
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) {
......@@ -242,6 +267,21 @@ struct reshape
});
return result;
}
else
{
// 2 arg
if(args[0].get_shape().elements() != args[1].get_shape().elements())
{
MIGRAPHX_THROW("Reshape: Number of elements must match at runtime. Input: " +
std::to_string(args[0].get_shape().elements()) +
" Output buffer: " + std::to_string(args[1].get_shape().elements()));
}
visit_all(args[1], args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return args[1];
}
}
};
} // namespace op
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -45,16 +45,26 @@ struct parse_reshape : op_parser<parse_reshape>
{
literal s = parser.parse_value(info.attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
if(args.size() == 2)
else
{
// 2 inputs
auto s = args[1]->eval();
check_arg_empty(s, "Reshape: non-constant shape input is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
if(s.empty())
{
// arg[1] not eval-able
auto alloc_ins = info.add_instruction(
make_op("allocate", {{"buf_type", args[0]->get_shape().type()}}), args[1]);
return info.add_instruction(make_op("reshape"), args[0], alloc_ins);
}
else
{
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
}
}
};
} // namespace onnx
......
......@@ -6065,6 +6065,24 @@ def reshape_non_standard_test():
return ([trans, res], [x], [y])
@onnx_test()
def reshape_variable_input_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3])
x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 8])
node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2'])
return ([node], [x, x_shape], [y])
@onnx_test()
def reshape_variable_input_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 2, 3])
x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None, 6])
node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2'])
return ([node], [x, x_shape], [y])
@onnx_test()
def resize_downsample_f_test():
scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32)
......
......@@ -5719,6 +5719,38 @@ TEST_CASE(reshape_non_standard_test)
EXPECT(p == prog);
}
TEST_CASE(reshape_variable_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
auto prog = optimize_onnx("reshape_variable_input_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reshape_variable_input_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {3, 3}}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
auto reshape = mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
mm->add_return({reshape});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("reshape_variable_input_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(resize_downsample_c_test)
{
migraphx::program p;
......
reshape_variable_input_test:p

0
12"Reshapereshape_variable_input_testZ
0



Z
1

b
2


B
\ No newline at end of file
......@@ -2684,7 +2684,7 @@ TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_dyn_shape)
TEST_CASE(reshape_dyn_1in)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
......@@ -2708,6 +2708,27 @@ TEST_CASE(reshape_dyn_shape)
}
}
TEST_CASE(reshape_dyn_2in_0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {8, 8}, {3, 3}, {1, 1}}};
expect_shape(output, migraphx::make_op("reshape"), input, output);
}
TEST_CASE(reshape_dyn_2in_1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape output{migraphx::shape::float_type, {{12, 12}, {2, 2}, {1, 1}, {1, 4}}};
expect_shape(output, migraphx::make_op("reshape"), input, output);
}
TEST_CASE(reshape_dyn_2in_2)
{
migraphx::shape input{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape output{migraphx::shape::float_type, {{1, 2}, {6, 12}, {1, 1}, {4, 4}}};
expect_shape(output, migraphx::make_op("reshape"), input, output);
}
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 20}, {1, 1}}};
......
......@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_dyn_test)
TEST_CASE(reshape_dyn_1in_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
......@@ -173,3 +173,79 @@ TEST_CASE(reshape_dyn_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape s_out{migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(input_fixed_shape, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_elements_runtime_error)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
// elements do not match up
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 2, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); }));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment