Unverified Commit c3990622 authored by Zakor Gyula, committed by GitHub

Add support for Shrink ONNX operator (#2240)

parent 5139b930
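
For reference, ONNX defines Shrink elementwise from two attributes, lambd (default 0.5) and bias (default 0.0):

$$
\mathrm{Shrink}(x) =
\begin{cases}
x + \mathrm{bias}, & x < -\mathrm{lambd} \\
x - \mathrm{bias}, & x > \mathrm{lambd} \\
0, & \text{otherwise}
\end{cases}
$$

The parser below lowers this without control flow, as a sum of two mask-scaled terms.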
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_shrink : op_parser<parse_shrink>
{
    std::vector<op_desc> operators() const { return {{"Shrink"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Attribute defaults follow the ONNX specification.
        float bias = 0.0;
        if(contains(info.attributes, "bias"))
        {
            bias = parser.parse_value(info.attributes.at("bias")).at<float>();
        }
        float lambd = 0.5;
        if(contains(info.attributes, "lambd"))
        {
            lambd = parser.parse_value(info.attributes.at("lambd")).at<float>();
        }

        auto x       = args[0];
        auto x_shape = x->get_shape();
        auto x_type  = x_shape.type();

        auto lit_bias      = info.add_literal(bias);
        auto lit_neg_lambd = info.add_literal(-lambd);
        auto lit_lambd     = info.add_literal(lambd);

        // Shrink is computed branch-free as a sum of two masked terms:
        // (x < -lambd) * (x + bias) + (!(x < -lambd) and x > lambd) * (x - bias).
        auto x_plus_bias = info.add_common_op("add", x, lit_bias);
        auto x_min_bias  = info.add_common_op("sub", x, lit_bias);
        auto cond1       = info.add_common_op("less", x, lit_neg_lambd);
        auto cond2_a     = info.add_common_op("not", cond1);
        auto cond2_b     = info.add_common_op("greater", x, lit_lambd);
        auto cond2       = info.add_common_op("logical_and", cond2_a, cond2_b);

        // Convert the boolean masks to the input type so they can scale the terms.
        auto mul1 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond1);
        auto mul2 = info.add_instruction(make_op("convert", {{"target_type", x_type}}), cond2);

        auto first  = info.add_common_op("mul", mul1, x_plus_bias);
        auto second = info.add_common_op("mul", mul2, x_min_bias);
        auto ret    = info.add_common_op("add", first, second);
        // The common ops may promote the computation type (e.g. int8 -> float);
        // convert back so the output type matches the input.
        if(ret->get_shape().type() != x_type)
        {
            ret = info.add_instruction(make_op("convert", {{"target_type", x_type}}), ret);
        }
        return ret;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
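
For intuition, here is a minimal NumPy sketch of the same masked-sum decomposition the parser emits (illustrative only; shrink_reference is a hypothetical name, not part of this commit):

    import numpy as np

    def shrink_reference(x, bias=0.0, lambd=0.5):
        # Two disjoint boolean masks, cast to the input dtype to scale the shifted terms.
        cond1 = x < -lambd                                         # selects x + bias
        cond2 = np.logical_and(np.logical_not(cond1), x > lambd)   # selects x - bias
        return cond1.astype(x.dtype) * (x + bias) + cond2.astype(x.dtype) * (x - bias)

This matches the shrink_hard_test gold below: with lambd=1.5 and bias=0.0, the input [-2, -1, 0, 1, 2] comes back as [-2, 0, 0, 0, 2].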
@@ -7135,6 +7135,101 @@ def shape_gather_test():
    return ([node_const, node_shape, node_gather], [x], [z])
@onnx_test()
def shrink_hard_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])
    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=1.5,
    )
    return ([node], [x], [y])


@onnx_test()
def shrink_soft_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])
    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=1.5,
        bias=1.5,
    )
    return ([node], [x], [y])


@onnx_test()
def shrink_verify_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5])
    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=-5.0,
        bias=1.0,
    )
    return ([node], [x], [y])


@onnx_test()
def shrink_verify2_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [5])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [5])
    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=-6.0,
        bias=5.0,
    )
    return ([node], [x], [y])


@onnx_test()
def shrink_int8_test():
    x = helper.make_tensor_value_info('x', TensorProto.INT8, [3, 3])
    y = helper.make_tensor_value_info('y', TensorProto.INT8, [3, 3])
    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=1.5,
        bias=1.5,
    )
    return ([node], [x], [y])


@onnx_test()
def shrink_uint8_test():
    x = helper.make_tensor_value_info('x', TensorProto.UINT8, [3, 3])
    y = helper.make_tensor_value_info('y', TensorProto.UINT8, [3, 3])
    node = onnx.helper.make_node(
        "Shrink",
        inputs=["x"],
        outputs=["y"],
        lambd=5.0,
        bias=-4.5,
    )
    return ([node], [x], [y])
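
Each generator above returns a (nodes, inputs, outputs) tuple; the @onnx_test harness turns it into a .onnx file named after the function. A rough standalone equivalent for the hard-shrink case (a sketch of the harness's effect, not its actual code):

    import onnx
    from onnx import helper, TensorProto

    # Build and serialize the same single-node Shrink graph by hand.
    node = helper.make_node("Shrink", inputs=["x"], outputs=["y"], lambd=1.5)
    graph = helper.make_graph(
        [node],
        "shrink_hard_test",
        [helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])],
        [helper.make_tensor_value_info('y', TensorProto.FLOAT, [5])])
    onnx.save(helper.make_model(graph), "shrink_hard_test.onnx")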
@onnx_test()
def sign_test():
    x = helper.make_tensor_value_info('x', TensorProto.DOUBLE, [10, 5])
...
@@ -7006,6 +7006,73 @@ TEST_CASE(shape_gather_test)
    EXPECT(p == prog);
}
TEST_CASE(shrink_hard_test)
{
    migraphx::program p;
    float bias  = 0.0;
    float lambd = 1.5;
    std::vector<size_t> lens{5};
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens});
    auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
    auto lit_neg_lambd =
        mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
    auto lit_lambd   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
    auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
    auto x_min_bias  = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
    auto cond1       = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
    auto cond2_a     = add_common_op(*mm, migraphx::make_op("not"), {cond1});
    auto cond2_b     = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
    auto cond2       = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
    auto mul1        = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond1);
    auto mul2 = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), cond2);
    auto first  = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
    auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
    add_common_op(*mm, migraphx::make_op("add"), {first, second});

    auto prog = optimize_onnx("shrink_hard_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(shrink_int8_test)
{
    migraphx::program p;
    float bias  = 1.5;
    float lambd = 1.5;
    std::vector<size_t> lens{3, 3};
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter("x", migraphx::shape{migraphx::shape::int8_type, lens});
    auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
    auto lit_neg_lambd =
        mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
    auto lit_lambd   = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
    auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
    auto x_min_bias  = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
    auto cond1       = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
    auto cond2_a     = add_common_op(*mm, migraphx::make_op("not"), {cond1});
    auto cond2_b     = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
    auto cond2       = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
    auto mul1        = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond1);
    auto mul2 = mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), cond2);
    auto first  = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
    auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
    auto ret    = add_common_op(*mm, migraphx::make_op("add"), {first, second});
    // The float literals promote the computation, so the parser appends a final
    // convert back to int8.
    mm->add_instruction(
        migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), ret);

    auto prog = optimize_onnx("shrink_int8_test.onnx");
    EXPECT(p == prog);
}
TEST_CASE(sign_test)
{
    migraphx::program p;
...
@@ -1807,6 +1807,112 @@ TEST_CASE(selu_test)
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(shrink_hard_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_hard_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::float_type, {5}};
    std::vector<float> data{-2, -1, 0, 1, 2};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    // lambd=1.5, bias=0: values in [-1.5, 1.5] shrink to zero, the rest pass through.
    std::vector<float> gold = {-2, 0, 0, 0, 2};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_soft_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_soft_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::float_type, {5}};
    std::vector<float> data{-2, -1, 0, 1, 2};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    // lambd=1.5, bias=1.5: the surviving values are also pulled toward zero by bias.
    std::vector<float> gold = {-0.5, 0, 0, 0, 0.5};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_verify_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_verify_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::half_type, {5}};
    std::vector<float> tmp = {-10.0, -5.0, 0.0, 5.0, 10.0};
    std::vector<migraphx::half> data{tmp.cbegin(), tmp.cend()};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<migraphx::half> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    // A negative lambd=-5 leaves no zero band: x < 5 gives x + 1, otherwise x - 1.
    tmp = {-9.0, -4.0, 1.0, 4.0, 9.0};
    std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_verify2_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_verify2_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::half_type, {5}};
    std::vector<float> tmp = {-10.0, -5.0, 0.0, 5.0, 10.0};
    std::vector<migraphx::half> data{tmp.cbegin(), tmp.cend()};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<migraphx::half> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    // lambd=-6, bias=5: x < 6 gives x + 5, otherwise x - 5.
    tmp = {-5.0, 0.0, 5.0, 10.0, 5.0};
    std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_int8_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_int8_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::int8_type, {3, 3}};
    std::vector<int8_t> data{-4, -3, -2, -1, 0, 1, 2, 3, 4};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<int8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    // Fractional results (e.g. -4 + 1.5 = -2.5) truncate on the convert back to int8.
    std::vector<int8_t> gold = {-2, -1, 0, 0, 0, 0, 0, 1, 2};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(shrink_uint8_test)
{
    migraphx::program p = migraphx::parse_onnx("shrink_uint8_test.onnx");
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s{migraphx::shape::uint8_type, {3, 3}};
    std::vector<uint8_t> data{1, 2, 3, 4, 5, 6, 7, 8, 9};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();
    std::vector<uint8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    // lambd=5, bias=-4.5: x > 5 maps to x + 4.5 (then truncates); x < -5 cannot occur for uint8.
    std::vector<uint8_t> gold = {0, 0, 0, 0, 0, 10, 11, 12, 13};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
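
The uint8 gold is the least obvious case; a quick NumPy check (computed in float, then truncated on the cast back to uint8, mirroring the program's final convert) reproduces it:

    import numpy as np

    x = np.arange(1, 10, dtype=np.float32)  # the uint8 inputs 1..9, widened to float
    lambd, bias = 5.0, -4.5
    y = (x < -lambd) * (x + bias) + (x > lambd) * (x - bias)  # x > 5 -> x + 4.5
    print(y.astype(np.uint8).reshape(3, 3))  # [[0 0 0] [0 0 10] [11 12 13]]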
TEST_CASE(size_verify_test)
{
    migraphx::program p = migraphx::parse_onnx("size_verify_test.onnx");
...
@@ -249,8 +249,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
     backend_test.exclude(r'test_reversesequence_time_cpu')
     backend_test.exclude(r'test_scan9_sum_cpu')
     backend_test.exclude(r'test_scan_sum_cpu')
-    backend_test.exclude(r'test_shrink_hard_cpu')
-    backend_test.exclude(r'test_shrink_soft_cpu')
     backend_test.exclude(r'test_slice_cpu')
     backend_test.exclude(r'test_slice_default_axes_cpu')
     backend_test.exclude(r'test_slice_default_steps_cpu')
@@ -463,7 +461,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
     backend_test.exclude(r'test_sequence_model6_cpu')
     backend_test.exclude(r'test_sequence_model7_cpu')
     backend_test.exclude(r'test_sequence_model8_cpu')
-    backend_test.exclude(r'test_shrink_cpu')
     backend_test.exclude(r'test_strnorm_model_monday_casesensintive_lower_cpu')
     backend_test.exclude(
         r'test_strnorm_model_monday_casesensintive_nochangecase_cpu')
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
template <migraphx::shape::type_t T>
struct test_shrink : verify_program<test_shrink<T>>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        float bias  = 1.5;
        float lambd = 1.5;
        auto* mm    = p.get_main_module();
        migraphx::shape is{T, {2, 3}};
        std::vector<float> data;
        // Unsigned types get non-negative inputs; signed types cover both branches.
        migraphx::shape::visit(T, [&](auto as) {
            as.is_signed() ? data.assign({-3.0, -2.0, -1.0, 0.0, 1.0, 2.0})
                           : data.assign({3.0, 2.0, 1.0, 0.0, 1.0, 2.0});
        });
        auto x        = mm->add_literal(migraphx::literal{is, data});
        auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
        auto lit_neg_lambd =
            mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
        auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});
        auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
        auto x_min_bias  = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});
        auto cond1       = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
        auto cond2_a     = add_common_op(*mm, migraphx::make_op("not"), {cond1});
        auto cond2_b     = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
        auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});
        auto mul1 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond1);
        auto mul2 = mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), cond2);
        auto first  = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
        auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
        auto ret    = add_common_op(*mm, migraphx::make_op("add"), {first, second});
        // Convert back when the common ops promoted the computation type.
        if(ret->get_shape().type() != T)
        {
            mm->add_instruction(migraphx::make_op("convert", {{"target_type", T}}), ret);
        }
        return p;
    }
};

template struct test_shrink<migraphx::shape::double_type>;
template struct test_shrink<migraphx::shape::float_type>;
template struct test_shrink<migraphx::shape::half_type>;
template struct test_shrink<migraphx::shape::int64_type>;
template struct test_shrink<migraphx::shape::int32_type>;
template struct test_shrink<migraphx::shape::int16_type>;
template struct test_shrink<migraphx::shape::int8_type>;
template struct test_shrink<migraphx::shape::uint64_type>;
template struct test_shrink<migraphx::shape::uint32_type>;
template struct test_shrink<migraphx::shape::uint16_type>;
template struct test_shrink<migraphx::shape::uint8_type>;