Unverified Commit 03225b57 authored by Charlie Lin, committed by GitHub

Lp normalization op (#1129)

* LpNormalization ONNX parser
parent 548783c8
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for LpNormalization ONNX operator.
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
*/
struct parse_lpnormalization : op_parser<parse_lpnormalization>
{
    std::vector<op_desc> operators() const { return {{"LpNormalization"}}; }

    instruction_ref parse(const op_desc&,
                          const onnx_parser&,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        int p = 2;
        if(contains(info.attributes, "p"))
        {
            p = info.attributes.at("p").i();
        }
        if(p != 1 and p != 2)
        {
            MIGRAPHX_THROW("LPNORMALIZATION: only L1 and L2 norm supported");
        }

        auto input             = args.front();
        auto input_shape       = input->get_shape();
        const auto& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        std::ptrdiff_t num_axes = input_lens.size();
        std::ptrdiff_t axis     = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
            if(axis < -num_axes or axis >= num_axes)
            {
                // handled in normalize_attributes but throwing here might be clearer
                MIGRAPHX_THROW("LPNORMALIZATION: selected axis out of bounds");
            }
        }

        migraphx::instruction_ref p_val;
        if(p == 1)
        {
            p_val = info.add_instruction(migraphx::make_op("abs"), input);
        }
        else
        {
            p_val = info.add_instruction(migraphx::make_op("mul"), input, input);
        }

        // need to check for zeros from lp norm to prevent division by zero
        // change them to 1 for the element-wise division
        auto norms =
            info.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val);
        if(p == 2)
        {
            norms = info.add_instruction(migraphx::make_op("sqrt"), norms);
        }

        // broadcast back to initial shape, negative axis option doesn't work with unidirectional
        norms = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms);
        auto zero_mb = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
            info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
        auto one_mb = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
            info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
        auto is_zero = info.add_instruction(migraphx::make_op("equal"), norms, zero_mb);
        auto norms_zeros_to_one =
            info.add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms);

        return info.add_instruction(migraphx::make_op("div"), input, norms_zeros_to_one);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
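For reference, the graph built above amounts to: take abs(x) for p=1 or x*x for p=2, reduce_sum along the chosen axis, take sqrt for p=2, replace zero norms with 1, and divide the input elementwise by the broadcast norms. A minimal NumPy sketch of the same computation (illustrative only; lp_normalize is a made-up name, and this is not MIGraphX code):

```python
import numpy as np


def lp_normalize(x, p=2, axis=-1):
    """Reference Lp normalization mirroring the parser's graph above."""
    if p == 1:
        norms = np.abs(x).sum(axis=axis, keepdims=True)         # abs -> reduce_sum
    elif p == 2:
        norms = np.sqrt((x * x).sum(axis=axis, keepdims=True))  # mul -> reduce_sum -> sqrt
    else:
        raise ValueError("only L1 and L2 norms are supported")
    norms = np.where(norms == 0, 1.0, norms)  # zero norms become 1 to avoid division by zero
    return x / norms
```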
...@@ -2804,6 +2804,70 @@ def loop_test():
    return ([node], [iter, cond, a, b], [b_loop, uout])


@onnx_test
def lpnormalization_axis_error_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])
    node = onnx.helper.make_node('LpNormalization',
                                 inputs=['x'],
                                 outputs=['y'],
                                 axis=2)

    return ([node], [x], [y])


@onnx_test
def lpnormalization_default_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
    node = onnx.helper.make_node(
        'LpNormalization',
        inputs=['x'],
        outputs=['y'],
        axis=0,
    )

    return ([node], [x], [y])


@onnx_test
def lpnormalization_l1_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
    node = onnx.helper.make_node(
        'LpNormalization',
        inputs=['x'],
        outputs=['y'],
        p=1,
    )

    return ([node], [x], [y])


@onnx_test
def lpnormalization_l2_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
    node = onnx.helper.make_node('LpNormalization',
                                 inputs=['x'],
                                 outputs=['y'],
                                 p=2)

    return ([node], [x], [y])


@onnx_test
def lpnormalization_p_error_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3])
    node = onnx.helper.make_node('LpNormalization',
                                 inputs=['x'],
                                 outputs=['y'],
                                 p=3)

    return ([node], [x], [y])
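Each generator above returns a ([nodes], [inputs], [outputs]) triple rather than a finished model; the @onnx_test decorator (its implementation is not shown in this diff) is what turns that triple into a saved .onnx file. As a rough, hypothetical sketch of that step using only the standard onnx helper API (save_test_model is a made-up name, and the real decorator may differ):

```python
import onnx
from onnx import helper


def save_test_model(gen_fn):
    # Hypothetical stand-in for the @onnx_test decorator: build a graph from
    # the (nodes, inputs, outputs) triple and save it under the generator's
    # name, e.g. lpnormalization_l1_test.onnx.
    nodes, inputs, outputs = gen_fn()
    graph = helper.make_graph(nodes, gen_fn.__name__, inputs, outputs)
    model = helper.make_model(graph)
    onnx.save(model, gen_fn.__name__ + '.onnx')
    return gen_fn
```

Applied to lpnormalization_l1_test, for example, this would write lpnormalization_l1_test.onnx, which the parser and ref tests below load by name.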
@onnx_test
def lrn_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 28, 24, 24])
......
[Binary ONNX model protobufs added for the new tests: lpnormalization_axis_error_test.onnx, lpnormalization_l1_test.onnx, lpnormalization_l2_test.onnx, lpnormalization_p_error_test.onnx]
...@@ -2556,6 +2556,46 @@ TEST_CASE(loop_test)
    EXPECT(p == prog);
}
TEST_CASE(lpnormalization_default_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    std::vector<std::size_t> input_lens{3, 4};
    auto input_type = migraphx::shape::float_type;
    migraphx::shape s{input_type, input_lens};
    auto x = mm->add_parameter("x", s);

    std::ptrdiff_t axis = 0;
    auto p_val = mm->add_instruction(migraphx::make_op("mul"), x, x);
    auto norms = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val);
    norms      = mm->add_instruction(migraphx::make_op("sqrt"), norms);
    norms =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms);
    auto zero_mb =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                            mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
    auto one_mb =
        mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                            mm->add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
    auto is_zero = mm->add_instruction(migraphx::make_op("equal"), norms, zero_mb);
    auto norms_zeros_to_one =
        mm->add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms);
    mm->add_instruction(migraphx::make_op("div"), x, norms_zeros_to_one);

    auto prog = optimize_onnx("lpnormalization_default_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(lpnormalization_axis_error_test)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_axis_error_test.onnx"); }));
}

TEST_CASE(lpnormalization_p_error_test)
{
    EXPECT(test::throws([&] { migraphx::parse_onnx("lpnormalization_p_error_test.onnx"); }));
}
TEST_CASE(lrn_test)
{
    migraphx::program p;
......
...@@ -467,6 +467,62 @@ TEST_CASE(lessorequal_test)
    EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(lpnormalization_1norm)
{
    migraphx::program p = migraphx::parse_onnx("lpnormalization_l1_test.onnx");
    p.compile(migraphx::ref::target{});

    migraphx::shape s{migraphx::shape::float_type, {3, 4}};
    std::vector<float> data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> gold{0.f,
                            2.f / 5.f,
                            -2.f / 5.f,
                            1.f / 5.f,
                            1.f / 10.f,
                            -5.f / 10.f,
                            3.f / 10.f,
                            -1.f / 10.f,
                            -4.f / 7.f,
                            3.f / 7.f,
                            0.f,
                            0.f};
    EXPECT(migraphx::verify_range(result_vector, gold));
}

TEST_CASE(lpnormalization_2norm)
{
    migraphx::program p = migraphx::parse_onnx("lpnormalization_l2_test.onnx");
    p.compile(migraphx::ref::target{});

    migraphx::shape s{migraphx::shape::float_type, {3, 4}};
    std::vector<float> data{0.f, 2.f, -2.f, 1.f, 1.f, -5.f, 3.f, -1.f, -4.f, 3.f, 0.f, 0.f};
    migraphx::parameter_map pp;
    pp["x"] = migraphx::argument(s, data.data());

    auto result = p.eval(pp).back();
    std::vector<float> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    std::vector<float> correct{0.f,
                               2.f / 3.f,
                               -2.f / 3.f,
                               1.f / 3.f,
                               1.f / 6.f,
                               -5.f / 6.f,
                               3.f / 6.f,
                               -1.f / 6.f,
                               -4.f / 5.f,
                               3.f / 5.f,
                               0.f,
                               0.f};
    EXPECT(migraphx::verify_range(result_vector, correct));
}
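Both ref tests share the same 3x4 input, so the gold vectors follow directly from the per-row norms: L1 norms 5, 10 and 7, and L2 norms sqrt(9)=3, sqrt(36)=6 and sqrt(25)=5. A quick NumPy check of those numbers (illustrative only, not part of the test suite):

```python
import numpy as np

x = np.array([0., 2., -2., 1., 1., -5., 3., -1., -4., 3., 0., 0.],
             dtype=np.float32).reshape(3, 4)
print(np.abs(x).sum(axis=-1))          # [ 5. 10.  7.]  -> L1 row norms
print(np.sqrt((x * x).sum(axis=-1)))   # [ 3.  6.  5.]  -> L2 row norms
print(x / np.abs(x).sum(axis=-1, keepdims=True))         # matches `gold` above
print(x / np.sqrt((x * x).sum(axis=-1, keepdims=True)))  # matches `correct` above
```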
TEST_CASE(mean_broadcast_test)
{
    migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx");
......