Unverified Commit f25606f9 authored by Charlie Lin, committed by GitHub

2 Input Reshape `ref` implementation (#2304)

parent a7200610
@@ -36,6 +36,22 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* 1 input version:
* reshape(input_data)
* this.dims = output_dims
* Makes a copy of input_data to the output shape.
*
* 2 input version:
* reshape(input_data, output_buffer)
* this.dims = unset
* Copies input_data to output_buffer; output_buffer already has the output shape.
* This version does not fail gracefully if the input shape and the output_buffer shape are
* incompatible; a throw at runtime only catches the case where the number of elements does not
* match. This version should only be used for dynamic reshapes (output dimensions known only at
* runtime). If output_buffer has a static shape at compile/parse time, use the 1 input
* version instead.
*/
struct reshape
{
std::vector<int64_t> dims;
@@ -215,32 +231,56 @@ struct reshape
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1, 2);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs.front();
if(inputs.size() == 1)
{
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
else
{
return inputs.back();
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(dyn_out.computed_shape.standard());
if(args.size() == 1)
{
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
else
{
// 2 arg
if(args[0].get_shape().elements() != args[1].get_shape().elements())
{
MIGRAPHX_THROW("Reshape: Number of elements must match at runtime. Input: " +
std::to_string(args[0].get_shape().elements()) +
" Output buffer: " + std::to_string(args[1].get_shape().elements()));
}
visit_all(args[1], args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return args[1];
}
}
};
...
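As a quick orientation before the parser and test changes below, here is a minimal illustrative sketch of how the two forms described in the header comment above are built with the MIGraphX C++ API. It is not part of this commit; it simply mirrors the patterns used by the ref tests added later in this diff, and the parameter names "X"/"Y" and the shapes are placeholders.

// Illustrative sketch only; shapes and parameter names are placeholders.
migraphx::program p;
auto* mm = p.get_main_module();

// 1 input version: output dims are an attribute known when the program is built.
auto x = mm->add_parameter("X", migraphx::shape{migraphx::shape::float_type, {2, 24, 1, 1}});
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 4, 1}}}), x);

// 2 input version: no dims attribute; the second argument is an output buffer whose
// (possibly dynamic) shape determines the result. Intended for shapes only known at runtime.
migraphx::shape buf{migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}, {1, 1}}};
auto y = mm->add_parameter("Y", buf);
mm->add_instruction(migraphx::make_op("reshape"), x, y);

p.compile(migraphx::make_target("ref"));
// At eval time, "X" and "Y" are bound to fixed-shape arguments with matching element counts,
// as in the reshape_2in_* tests further down.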
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -45,15 +45,25 @@ struct parse_reshape : op_parser<parse_reshape>
{
literal s = parser.parse_value(info.attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
else
{
// 2 inputs
auto s = args[1]->eval();
if(s.empty())
{
// arg[1] not eval-able
auto alloc_ins = info.add_instruction(
make_op("allocate", {{"buf_type", args[0]->get_shape().type()}}), args[1]);
return info.add_instruction(make_op("reshape"), args[0], alloc_ins);
}
else
{
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
}
}
};
...
@@ -6065,6 +6065,24 @@ def reshape_non_standard_test():
return ([trans, res], [x], [y])
@onnx_test()
def reshape_variable_input_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3])
x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 8])
node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2'])
return ([node], [x, x_shape], [y])
@onnx_test()
def reshape_variable_input_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 2, 3])
x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None, 6])
node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2'])
return ([node], [x, x_shape], [y])
@onnx_test()
def resize_downsample_f_test():
scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32)
...
@@ -362,10 +362,10 @@ TEST_CASE(averagepool_notset_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {2, 2, 2, 2}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins);
@@ -382,11 +382,11 @@ TEST_CASE(averagepool_nt_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
ins_pad);
mm->add_return({ret});
@@ -426,11 +426,11 @@ TEST_CASE(averagepool_sl_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
ins_pad);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");
@@ -444,10 +444,10 @@ TEST_CASE(averagepool_same_upper_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins);
@@ -1634,7 +1634,7 @@ TEST_CASE(conv_transpose_input_pads_asymm_1d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}),
@@ -1668,7 +1668,7 @@ TEST_CASE(conv_transpose_output_padding_3d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
@@ -1701,7 +1701,7 @@ TEST_CASE(conv_transpose_output_shape_3d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
@@ -1996,7 +1996,7 @@ TEST_CASE(equal_test)
auto eq = mm->add_instruction(migraphx::make_op("equal"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
eq);
mm->add_return({ret});
@@ -2016,7 +2016,7 @@ TEST_CASE(equal_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2);
mm->add_return({ret});
@@ -2726,7 +2726,7 @@ TEST_CASE(greater_test)
auto gr = mm->add_instruction(migraphx::make_op("greater"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
gr);
mm->add_return({ret});
@@ -2745,7 +2745,7 @@ TEST_CASE(greater_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2);
mm->add_return({ret});
@@ -3602,7 +3602,7 @@ TEST_CASE(less_test)
auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le);
mm->add_return({ret});
@@ -3621,7 +3621,7 @@ TEST_CASE(less_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2);
mm->add_return({ret});
@@ -5463,7 +5463,7 @@ TEST_CASE(reducel1_dyn_test)
// a shape with 4 dynamic dimensions
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins);
@@ -5483,7 +5483,7 @@ TEST_CASE(reducel1_dyn_test)
// No axes given in the onnx file. Parser should default to all axes.
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins);
@@ -5719,6 +5719,38 @@ TEST_CASE(reshape_non_standard_test)
EXPECT(p == prog);
}
TEST_CASE(reshape_variable_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
auto prog = optimize_onnx("reshape_variable_input_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(reshape_variable_input_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {3, 3}}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
auto reshape = mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
mm->add_return({reshape});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("reshape_variable_input_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(resize_downsample_c_test)
{
migraphx::program p;
@@ -7169,7 +7201,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test)
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
@@ -7249,7 +7281,7 @@ TEST_CASE(sum_int_test)
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto cin0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}),
input0);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
...
(binary ONNX protobuf for the reshape_variable_input_test graph; no newline at end of file)
@@ -2684,7 +2684,7 @@ TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_dyn_1in)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
@@ -2708,6 +2708,27 @@ TEST_CASE(reshape_dyn_shape)
}
}
TEST_CASE(reshape_dyn_2in_0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {8, 8}, {3, 3}, {1, 1}}};
expect_shape(output, migraphx::make_op("reshape"), input, output);
}
TEST_CASE(reshape_dyn_2in_1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape output{migraphx::shape::float_type, {{12, 12}, {2, 2}, {1, 1}, {1, 4}}};
expect_shape(output, migraphx::make_op("reshape"), input, output);
}
TEST_CASE(reshape_dyn_2in_2)
{
migraphx::shape input{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape output{migraphx::shape::float_type, {{1, 2}, {6, 12}, {1, 1}, {4, 4}}};
expect_shape(output, migraphx::make_op("reshape"), input, output);
}
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 20}, {1, 1}}};
...
@@ -153,7 +153,7 @@ TEST_CASE(reshape_test2)
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_dyn_1in_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
@@ -173,3 +173,79 @@ TEST_CASE(reshape_dyn_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
migraphx::shape s_out{migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(input_fixed_shape, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 4, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
auto result = p.eval(params).back();
EXPECT(result.get_shape() == output_fixed_shape);
std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(reshape_2in_elements_runtime_error)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s_in{migraphx::shape::float_type, {2, 24, 1, 1}};
migraphx::shape s_out{migraphx::shape::float_type, {{2, 4}, {6, 6}, {2, 4}, {1, 1}}};
auto input = mm->add_parameter("X", s_in);
auto output_buffer = mm->add_parameter("Y", s_out);
mm->add_instruction(migraphx::make_op("reshape"), input, output_buffer);
p.compile(migraphx::make_target("ref"));
std::vector<float> gold(48);
std::iota(gold.begin(), gold.end(), -3.);
std::vector<float> buffer(48);
std::iota(buffer.begin(), buffer.end(), 0.);
migraphx::parameter_map params;
// elements do not match up
migraphx::shape output_fixed_shape{migraphx::shape::float_type, {2, 6, 2, 1}};
params["X"] = migraphx::argument(s_in, gold.data());
params["Y"] = migraphx::argument(output_fixed_shape, buffer.data());
EXPECT(test::throws([&] { std::ignore = p.eval(params).back(); }));
}