Unverified Commit 7477aeb8 authored by turneram's avatar turneram Committed by GitHub
Browse files

Add HardSwish op ONNX parser (#1066)

Add HardSwish to HardSigmoid parser

HardSwish formula is y = x * HardSigmoid<alpha=1/6, beta=0.5>(x)
HardSigmoid parser sets alpha to 1/6 and adds the mul instruction if op name is HardSwish

Resolves #1062
parent 60aa1c85
......@@ -10,20 +10,27 @@ namespace onnx {
struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
{
std::vector<op_desc> operators() const { return {{"HardSigmoid"}}; }
std::vector<op_desc> operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; }
instruction_ref parse(const op_desc& /*opd*/,
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 0.2;
float beta = 0.5;
if(contains(info.attributes, "alpha"))
alpha = info.attributes.at("alpha").f();
if(opd.onnx_name == "HardSwish")
{
alpha = 1.0 / 6.0;
}
else
{
if(contains(info.attributes, "alpha"))
alpha = info.attributes.at("alpha").f();
if(contains(info.attributes, "beta"))
beta = info.attributes.at("beta").f();
if(contains(info.attributes, "beta"))
beta = info.attributes.at("beta").f();
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
......@@ -40,9 +47,13 @@ struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul);
return info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul);
auto hardsigmoid = info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
if(opd.onnx_name == "HardSwish")
return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid);
return hardsigmoid;
}
};
......
......@@ -1694,6 +1694,16 @@ def hardsigmoid_verify_test():
return ([node], [x], [y])
@onnx_test
def hardswish_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 5])
node = onnx.helper.make_node('HardSwish', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
@onnx_test
def if_else_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
......
hardswish_test:M

xy" HardSwishhardswish_testZ
x


b
y


B
\ No newline at end of file
......@@ -1680,6 +1680,41 @@ TEST_CASE(hardsigmoid_half_test)
EXPECT(p == prog);
}
TEST_CASE(hardswish_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<std::size_t> input_lens{2, 5};
auto input_type = migraphx::shape::float_type;
migraphx::shape s{input_type, input_lens};
auto x = mm->add_parameter("x", s);
float alpha = 1.0 / 6.0;
float beta = 0.5;
auto mb_alpha = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto mb_beta = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}}));
auto mb_zero =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0}}));
auto mb_one =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
auto mul = mm->add_instruction(migraphx::make_op("mul"), mb_alpha, x);
auto add = mm->add_instruction(migraphx::make_op("add"), mb_beta, mul);
auto hardsigmoid = mm->add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
mm->add_instruction(migraphx::make_op("mul"), x, hardsigmoid);
auto prog = optimize_onnx("hardswish_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_else_test)
{
migraphx::program p;
......
......@@ -119,6 +119,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_globalmaxpool.*')
backend_test.include(r'.*test_greater.*')
backend_test.include(r'.*test_hardsigmoid.*')
backend_test.include(r'.*test_hardswish.*')
backend_test.include(r'.*test_identity.*')
backend_test.include(r'.*test_if.*')
backend_test.include(r'.*test_LeakyReLU*')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment