"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "193f105d6b27cc58224790c375895ed7b4e59fe6"
Unverified Commit 5b37c53c authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Celu ONNX parser and tests (#1114)

Add Celu ONNX operator
parent 4467c158
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
    std::vector<op_desc> operators() const { return {{"Celu"}}; }

    // Lowers ONNX Celu into elementwise primitives:
    //   celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Spec default; overridden by the node's "alpha" attribute when present.
        float alpha = 1.0;
        if(contains(info.attributes, "alpha"))
        {
            alpha = info.attributes.at("alpha").f();
        }
        // The lowering divides by alpha, so a zero value must be rejected up front.
        if(float_equal(alpha, 0.0f))
        {
            MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
        }
        auto out_lens = args[0]->get_shape().lens();
        auto dtype    = args[0]->get_shape().type();
        // ONNX defines Celu for float32 tensors only.
        if(dtype != migraphx::shape::float_type)
        {
            MIGRAPHX_THROW("CELU: input tensor not float type");
        }

        // Scalar constants, broadcast to the input shape.
        auto zeros = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}),
            info.add_literal(migraphx::literal{migraphx::shape{dtype}, {0.}}));
        auto ones = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}),
            info.add_literal(migraphx::literal{migraphx::shape{dtype}, {1.}}));
        auto alphas = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}),
            info.add_literal(migraphx::literal{migraphx::shape{dtype}, {alpha}}));

        // max(0, x) contributes the positive half; min(0, ...) the negative half.
        auto positive = info.add_instruction(migraphx::make_op("max"), zeros, args[0]);
        auto ratio    = info.add_instruction(migraphx::make_op("div"), args[0], alphas);
        auto expd     = info.add_instruction(migraphx::make_op("exp"), ratio);
        auto shifted  = info.add_instruction(migraphx::make_op("sub"), expd, ones);
        auto scaled   = info.add_instruction(migraphx::make_op("mul"), alphas, shifted);
        auto negative = info.add_instruction(migraphx::make_op("min"), zeros, scaled);
        return info.add_instruction(migraphx::make_op("add"), positive, negative);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
celu_alpha_test:R

xy"Celu*
alphaL?celu_alpha_testZ
x

b
y

B
\ No newline at end of file
celu_default_test:K
xy"Celucelu_default_testZ
x


b
y


B
\ No newline at end of file
celu_wrong_type_test:N
xy"Celucelu_wrong_type_testZ
x



b
y



B
\ No newline at end of file
...@@ -351,6 +351,65 @@ def ceil_test(): ...@@ -351,6 +351,65 @@ def ceil_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def celu_alpha_test():
    # Celu node carrying an explicit, non-default alpha attribute.
    node = onnx.helper.make_node('Celu',
                                 inputs=['x'],
                                 outputs=['y'],
                                 alpha=0.8)
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])
    return ([node], [x], [y])
@onnx_test
def celu_default_test():
    # No alpha attribute: the parser must fall back to the spec default of 1.0.
    dims = [2, 3]
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, dims)
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, dims)
    node = onnx.helper.make_node('Celu', inputs=['x'], outputs=['y'])
    return ([node], [x], [y])
@onnx_test
def celu_verify_test():
    # Model used by the ref-target verify test; alpha chosen != default.
    dims = [2, 3]
    node = onnx.helper.make_node('Celu',
                                 inputs=['x'],
                                 outputs=['y'],
                                 alpha=0.5)
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, dims)
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, dims)
    return ([node], [x], [y])
@onnx_test
def celu_wrong_type_test():
    # float16 input: the parser only accepts float32 and should throw.
    dims = [2, 3]
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, dims)
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, dims)
    node = onnx.helper.make_node('Celu', inputs=['x'], outputs=['y'])
    return ([node], [x], [y])
@onnx_test
def celu_zero_alpha_test():
    # alpha == 0 would divide by zero inside celu; the parser should throw.
    dims = [2, 3]
    node = onnx.helper.make_node('Celu',
                                 inputs=['x'],
                                 outputs=['y'],
                                 alpha=0.0)
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, dims)
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, dims)
    return ([node], [x], [y])
@onnx_test @onnx_test
def clip_test(): def clip_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
...@@ -46,6 +46,29 @@ migraphx::program optimize_onnx(const std::string& name, bool run_passes = false ...@@ -46,6 +46,29 @@ migraphx::program optimize_onnx(const std::string& name, bool run_passes = false
return prog; return prog;
} }
// Builds the reference instruction sequence for Celu on module `mm`:
//   celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
// The insertion order mirrors the ONNX parser's lowering exactly so that
// the resulting programs compare equal in the TEST_CASEs below.
void add_celu_instruction(migraphx::module* mm, const migraphx::shape& s, float alpha)
{
    auto x           = mm->add_parameter("x", s);
    const auto& lens = s.lens();
    auto dtype       = s.type();
    auto bcast       = migraphx::make_op("multibroadcast", {{"out_lens", lens}});
    auto zeros       = mm->add_instruction(
        bcast, mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.}}));
    auto ones = mm->add_instruction(
        bcast, mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {1.}}));
    auto alphas = mm->add_instruction(
        bcast, mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {alpha}}));
    auto positive = mm->add_instruction(migraphx::make_op("max"), zeros, x);
    auto ratio    = mm->add_instruction(migraphx::make_op("div"), x, alphas);
    auto expd     = mm->add_instruction(migraphx::make_op("exp"), ratio);
    auto shifted  = mm->add_instruction(migraphx::make_op("sub"), expd, ones);
    auto scaled   = mm->add_instruction(migraphx::make_op("mul"), alphas, shifted);
    auto negative = mm->add_instruction(migraphx::make_op("min"), zeros, scaled);
    mm->add_instruction(migraphx::make_op("add"), positive, negative);
}
static std::vector<double> make_r_eyelike(size_t num_rows, size_t num_cols, size_t k) static std::vector<double> make_r_eyelike(size_t num_rows, size_t num_cols, size_t k)
{ {
std::vector<double> eyelike_mat(num_rows * num_cols, 0); std::vector<double> eyelike_mat(num_rows * num_cols, 0);
...@@ -380,6 +403,42 @@ TEST_CASE(ceil_test) ...@@ -380,6 +403,42 @@ TEST_CASE(ceil_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(celu_alpha_test)
{
    // Parsing must honor the alpha=0.8 attribute baked into the model file.
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {3}};
    add_celu_instruction(p.get_main_module(), s, 0.8f);
    auto prog = optimize_onnx("celu_alpha_test.onnx");
    EXPECT(p == prog);
}
TEST_CASE(celu_default_test)
{
    // No alpha attribute in the model: the parser uses the spec default 1.0.
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    add_celu_instruction(p.get_main_module(), s, 1.0f);
    auto prog = optimize_onnx("celu_default_test.onnx");
    EXPECT(p == prog);
}
// The parser rejects non-float32 inputs; this model declares float16 tensors.
TEST_CASE(celu_wrong_type_test)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("celu_wrong_type_test.onnx"); }));
}
// alpha == 0 would divide by zero in the lowering, so parsing must throw.
TEST_CASE(celu_zero_alpha_test)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("celu_zero_alpha_test.onnx"); }));
}
TEST_CASE(clip_test) TEST_CASE(clip_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -45,6 +45,28 @@ TEST_CASE(averagepool_nt_cip_test) ...@@ -45,6 +45,28 @@ TEST_CASE(averagepool_nt_cip_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(celu_verify_test)
{
    // End-to-end check on the ref target against a hand-computed celu,
    // using the alpha=0.5 model produced by celu_verify_test() in gen_onnx.py.
    auto p = migraphx::parse_onnx("celu_verify_test.onnx");
    p.compile(migraphx::ref::target{});

    const float alpha = 0.5;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    std::vector<float> data = {-5.5, 2.0, 100., 7.0, 0., -1.};

    migraphx::parameter_map pp;
    pp["x"]     = migraphx::argument(s, data.data());
    auto result = p.eval(pp).back();

    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    std::vector<float> correct;
    correct.reserve(data.size());
    for(auto v : data)
    {
        correct.push_back(std::max(0.0f, v) + std::min(0.0f, alpha * std::expm1(v / alpha)));
    }
    EXPECT(migraphx::verify_range(result_vector, correct));
}
TEST_CASE(clip_args_type_mismatch) TEST_CASE(clip_args_type_mismatch)
{ {
auto p = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx"); auto p = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx");
......
...@@ -96,6 +96,7 @@ def create_backend_test(testname=None, target_device=None): ...@@ -96,6 +96,7 @@ def create_backend_test(testname=None, target_device=None):
backend_test.include(r'.*test_AvgPool.*') backend_test.include(r'.*test_AvgPool.*')
backend_test.include(r'.*test_BatchNorm.*eval.*') backend_test.include(r'.*test_BatchNorm.*eval.*')
backend_test.include(r'.*test_ceil.*') backend_test.include(r'.*test_ceil.*')
backend_test.include(r'.*test_celu.*')
backend_test.include(r'.*test_clip.*') backend_test.include(r'.*test_clip.*')
backend_test.include(r'.*test_concat.*') backend_test.include(r'.*test_concat.*')
backend_test.include(r'.*test_constant.*') backend_test.include(r'.*test_constant.*')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment