Unverified Commit 6b6e9362 authored by turneram, committed by GitHub

Add ThresholdedRelu to onnx parser (#937)

Add the ability to parse the ThresholdedRelu ONNX operator, which yields y = x where x > alpha and y = 0 otherwise. The parser lowers the operator to broadcast literals, a greater comparison, and a where select.

Resolves #888
Co-authored-by: Shucai Xiao <shucai@gmail.com>
parent b43562a7
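For reference, the ONNX specification defines ThresholdedRelu as y = x for x > alpha and y = 0 otherwise, with alpha defaulting to 1.0. A minimal numpy sketch of those semantics (the function name here is illustrative, not part of this change):

import numpy as np

def thresholded_relu_reference(x, alpha=1.0):
    # Strict comparison: elements equal to alpha map to zero.
    return np.where(x > alpha, x, np.zeros_like(x))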
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_thresholdedrelu : op_parser<parse_thresholdedrelu>
{
    std::vector<op_desc> operators() const { return {{"ThresholdedRelu"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // The ONNX spec defaults alpha to 1.0 when the attribute is absent.
        float alpha = 1.0;
        if(contains(info.attributes, "alpha"))
            alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();

        // Build scalar literals in the input's type and broadcast them to the
        // input shape so they can feed the elementwise ops below.
        auto x_shape  = args[0]->get_shape();
        auto lit_zero = info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {0}});
        auto lit_alpha =
            info.add_literal(migraphx::literal{migraphx::shape{x_shape.type()}, {alpha}});
        auto mb_zero = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_zero);
        auto mb_alpha = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", x_shape.lens()}}), lit_alpha);

        // y = x where x > alpha, else 0
        auto condition = info.add_instruction(migraphx::make_op("greater"), args[0], mb_alpha);
        return info.add_instruction(migraphx::make_op("where"), condition, args[0], mb_zero);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
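The parsed program can be exercised end to end through the migraphx Python bindings. A rough sketch, assuming the bindings are built and that migraphx.argument accepts a numpy array directly (the target name "ref" may also vary by build):

import numpy as np
import migraphx

p = migraphx.parse_onnx("thresholdedrelu_test.onnx")
p.compile(migraphx.get_target("ref"))  # reference (CPU) target

x = np.random.randn(2, 2, 3).astype(np.float32)
out = p.run({"x": migraphx.argument(x)})[-1]
# Expected to agree elementwise with np.where(x > 3.0, x, 0).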
@@ -4126,6 +4126,46 @@ def tanh_test():
return ([node], [x], [y])

@onnx_test
def thresholdedrelu_default_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])

    node = onnx.helper.make_node('ThresholdedRelu',
                                 inputs=['x'],
                                 outputs=['y'])

    return ([node], [x], [y])

@onnx_test
def thresholdedrelu_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
    alpha = 3.0

    node = onnx.helper.make_node('ThresholdedRelu',
                                 inputs=['x'],
                                 outputs=['y'],
                                 alpha=alpha)

    return ([node], [x], [y])

@onnx_test
def thresholdedrelu_int_test():
    x = helper.make_tensor_value_info('x', TensorProto.INT32, [2, 2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.INT32, [2, 2, 3])
    alpha = 3.0

    node = onnx.helper.make_node('ThresholdedRelu',
                                 inputs=['x'],
                                 outputs=['y'],
                                 alpha=alpha)

    return ([node], [x], [y])
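Outside the @onnx_test harness, the same graph can be assembled and validated directly with the onnx package. A small standalone sketch (the graph and file names are illustrative):

import onnx
from onnx import helper, TensorProto

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
node = helper.make_node('ThresholdedRelu', inputs=['x'], outputs=['y'], alpha=3.0)
graph = helper.make_graph([node], 'thresholdedrelu_test', [x], [y])
model = helper.make_model(graph)
onnx.checker.check_model(model)
onnx.save(model, 'thresholdedrelu_test.onnx')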
@onnx_test
def tile_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2])
@@ -3782,6 +3782,63 @@ TEST_CASE(tanh_test)
EXPECT(p == prog);
}

TEST_CASE(thresholdedrelu_default_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x  = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}});
    auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
    auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {1.0f}});
    auto mbz = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
    auto mba = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
    auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
    mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);

    auto prog = optimize_onnx("thresholdedrelu_default_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(thresholdedrelu_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x  = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2, 3}});
    auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
    auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3.0f}});
    auto mbz = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
    auto mba = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
    auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
    mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);

    auto prog = optimize_onnx("thresholdedrelu_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(thresholdedrelu_int_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x  = mm->add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2, 3}});
    auto lz = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {0}});
    auto la = mm->add_literal(migraphx::literal{migraphx::shape{x->get_shape().type()}, {3}});
    auto mbz = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lz);
    auto mba = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), la);
    auto condition = mm->add_instruction(migraphx::make_op("greater"), x, mba);
    mm->add_instruction(migraphx::make_op("where"), condition, x, mbz);

    auto prog = optimize_onnx("thresholdedrelu_int_test.onnx");
    EXPECT(p == prog);
}
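As a quick sanity check of the numerics these tests encode: with alpha = 3.0, an input of [1, 3, 4] should produce [0, 0, 4], since the greater comparison is strict and values equal to alpha fall to zero. In numpy terms:

import numpy as np

x = np.array([1.0, 3.0, 4.0], dtype=np.float32)
assert np.array_equal(np.where(x > 3.0, x, 0.0).astype(np.float32),
                      np.array([0.0, 0.0, 4.0], dtype=np.float32))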
TEST_CASE(tile_test)
{
migraphx::program p;
[Binary file thresholdedrelu_default_test.onnx: serialized ONNX protobuf for graph "thresholdedrelu_default_test", containing a single ThresholdedRelu node that maps input x to output y.]
@@ -289,9 +289,6 @@ def create_backend_test(testname=None, target_device=None):
     backend_test.exclude(r'test_softplus_example_cpu')
     backend_test.exclude(r'test_softsign_cpu')
     backend_test.exclude(r'test_softsign_example_cpu')
-    backend_test.exclude(r'test_thresholdedrelu_cpu')
-    backend_test.exclude(r'test_thresholdedrelu_default_cpu')
-    backend_test.exclude(r'test_thresholdedrelu_example_cpu')
     backend_test.exclude(r'test_Embedding_cpu')
     backend_test.exclude(r'test_Softplus_cpu')