"src/vscode:/vscode.git/clone" did not exist on "29e60c22948d79837a5156a594f7bd9d4ff5c1b9"
Unverified Commit 6072b2c4 authored by music-dino's avatar music-dino Committed by GitHub
Browse files

Add MeanVarianceNormalization ONNX parsing (#2255)

parent c8f1cd93
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mean_variance_normalization : op_parser<parse_mean_variance_normalization>
{
std::vector<op_desc> operators() const { return {{"MeanVarianceNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
auto&& data = args.front();
auto data_rank = data->get_shape().ndim();
std::vector<int64_t> axes{0, 2, 3};
if(contains(info.attributes, "axes"))
{
const auto& axes_attr = info.attributes["axes"].ints();
axes.assign(axes_attr.begin(), axes_attr.end());
}
else if(data_rank != 4)
{
MIGRAPHX_THROW(
"Input tensor needs to be rank 4 when axes is not specified. Instead it is rank " +
std::to_string(data_rank));
}
if(axes.size() != data_rank - 1)
{
MIGRAPHX_THROW("Length of axes array needs to be equal to input tensor rank - 1");
}
auto data_mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), data);
auto data_mean_squared = info.add_common_op("mul", data_mean, data_mean);
auto data_squared = info.add_common_op("mul", data, data);
auto data_squared_mean =
info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), data_squared);
auto mean_sub = info.add_common_op("sub", data_squared_mean, data_mean_squared);
auto std = info.add_common_op("sqrt", mean_sub);
auto dividend = info.add_common_op("sub", data, data_mean);
auto epsilon =
info.add_literal({data->get_shape().type(),
{data->get_shape().type() == shape::half_type ? 1e-7 : 1e-9}});
auto divisor = info.add_common_op("add", std, epsilon);
return info.add_common_op("div", dividend, divisor);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -4681,6 +4681,77 @@ def mean_integral_test(): ...@@ -4681,6 +4681,77 @@ def mean_integral_test():
return ([node], data, [mean]) return ([node], data, [mean])
def mvn_default_axes_test_base(dims, type=TensorProto.FLOAT):
data = helper.make_tensor_value_info("data", type, dims)
out = helper.make_tensor_value_info("out", type, dims)
node = helper.make_node("MeanVarianceNormalization",
inputs=["data"],
outputs=["out"])
return ([node], [data], [out])
@onnx_test()
def mvn_default_axes_test():
return mvn_default_axes_test_base([2, 2, 2, 2])
@onnx_test()
def mvn_default_axes_fp16_test():
return mvn_default_axes_test_base([2, 2, 2, 2], TensorProto.FLOAT16)
@onnx_test()
def mvn_default_axes_rank_too_small_test():
return mvn_default_axes_test_base([2, 2, 2])
@onnx_test()
def mvn_default_axes_rank_too_big_test():
return mvn_default_axes_test_base([2, 2, 2, 2, 2])
def mvn_n_rank_test_base(axes, dims, type=TensorProto.FLOAT):
data = helper.make_tensor_value_info("data", type, dims)
out = helper.make_tensor_value_info("out", type, dims)
node = helper.make_node("MeanVarianceNormalization",
inputs=["data"],
outputs=["out"],
axes=axes)
return ([node], [data], [out])
@onnx_test()
def mvn_rank_2_test():
return mvn_n_rank_test_base([1], [2, 2])
@onnx_test()
def mvn_rank_2_fp16_test():
return mvn_n_rank_test_base([1], [2, 2], TensorProto.FLOAT16)
@onnx_test()
def mvn_rank_3_test():
return mvn_n_rank_test_base([0, 1], [2, 2, 2])
@onnx_test()
def mvn_rank_3_fp16_test():
return mvn_n_rank_test_base([0, 1], [2, 2, 2], TensorProto.FLOAT16)
@onnx_test()
def mvn_axes_rank_too_small_test():
return mvn_n_rank_test_base([0, 1, 2], [2, 2, 2])
@onnx_test()
def mvn_axes_rank_too_big_test():
return mvn_n_rank_test_base([0], [2, 2, 2])
@onnx_test() @onnx_test()
def min_test(): def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
 mvn_default_axes_fp16_test:
&
dataout"MeanVarianceNormalizationmvn_default_axes_fp16_testZ
data





b
out





B
\ No newline at end of file
 $mvn_default_axes_rank_too_small_test:
&
dataout"MeanVarianceNormalization$mvn_default_axes_rank_too_small_testZ
data



b
out



B
\ No newline at end of file
 mvn_default_axes_test:~
&
dataout"MeanVarianceNormalizationmvn_default_axes_testZ
data




b
out




B
\ No newline at end of file
 mvn_rank_2_fp16_test:z
3
dataout"MeanVarianceNormalization*
axes@mvn_rank_2_fp16_testZ
data



b
out



B
\ No newline at end of file
 mvn_rank_2_test:u
3
dataout"MeanVarianceNormalization*
axes@mvn_rank_2_testZ
data


b
out


B
\ No newline at end of file
...@@ -4501,6 +4501,66 @@ TEST_CASE(mean_integral_test) ...@@ -4501,6 +4501,66 @@ TEST_CASE(mean_integral_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
void mvn_n_rank_test(std::vector<int64_t> axes,
std::vector<size_t> input_shape,
const std::string& test_file)
{
using migraphx::make_op;
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("data", {migraphx::shape::float_type, std::move(input_shape)});
auto data_mean = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), data);
auto data_mean_squared = add_common_op(*mm, make_op("mul"), {data_mean, data_mean});
auto data_squared = add_common_op(*mm, make_op("mul"), {data, data});
auto data_squared_mean =
mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), data_squared);
auto mean_sub = add_common_op(*mm, make_op("sub"), {data_squared_mean, data_mean_squared});
auto std = add_common_op(*mm, make_op("sqrt"), {mean_sub});
auto dividend = add_common_op(*mm, make_op("sub"), {data, data_mean});
auto epsilon = mm->add_literal({migraphx::shape::float_type, {1e-9}});
auto divisor = add_common_op(*mm, make_op("add"), {std, epsilon});
add_common_op(*mm, make_op("div"), {dividend, divisor});
auto prog = optimize_onnx(test_file);
EXPECT(p == prog);
}
TEST_CASE(mvn_default_axes_test)
{
mvn_n_rank_test({0, 2, 3}, {2, 2, 2, 2}, "mvn_default_axes_test.onnx");
}
TEST_CASE(mvn_default_axes_rank_too_small_test)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("mvn_default_axes_rank_too_small_test.onnx"); }));
}
TEST_CASE(mvn_default_axes_rank_too_big_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_default_axes_rank_too_big_test.onnx"); }));
}
TEST_CASE(mvn_rank_2_test) { mvn_n_rank_test({1}, {2, 2}, "mvn_rank_2_test.onnx"); }
TEST_CASE(mvn_rank_3_test) { mvn_n_rank_test({0, 1}, {2, 2, 2}, "mvn_rank_3_test.onnx"); }
TEST_CASE(mvn_axes_rank_too_small_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_axes_rank_too_small_test.onnx"); }));
}
TEST_CASE(mvn_axes_rank_too_big_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("mvn_axes_rank_too_big_test.onnx"); }));
}
TEST_CASE(min_test) TEST_CASE(min_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -1211,6 +1211,115 @@ TEST_CASE(mean_integral_test) ...@@ -1211,6 +1211,115 @@ TEST_CASE(mean_integral_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
template <typename T = float>
std::vector<T> mvn_test(std::vector<size_t> data_lens, const std::string& test_file)
{
migraphx::program p = migraphx::parse_onnx(test_file);
p.compile(migraphx::make_target("ref"));
migraphx::shape data_shape(migraphx::shape::get_type<T>{}, std::move(data_lens));
std::vector<T> data(data_shape.elements());
std::iota(begin(data), end(data), 0);
migraphx::parameter_map pm;
pm["data"] = migraphx::argument(data_shape, data.data());
auto result = p.eval(pm).back();
std::vector<T> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
return result_vector;
}
TEST_CASE(mvn_default_axes_test)
{
auto result = mvn_test({2, 2, 2, 2}, "mvn_default_axes_test.onnx");
std::vector<float> gold{-1.32424438,
-1.08347268,
-0.84270097,
-0.60192927,
-1.32424438,
-1.08347268,
-0.84270097,
-0.60192927,
0.60192927,
0.84270097,
1.08347268,
1.32424438,
0.60192927,
0.84270097,
1.08347268,
1.32424438};
EXPECT(migraphx::verify::verify_rms_range(result, gold));
}
TEST_CASE(mvn_default_axes_fp16_test)
{
using migraphx::half;
auto result = mvn_test<half>({2, 2, 2, 2}, "mvn_default_axes_fp16_test.onnx");
std::vector<half> gold{half{-1.324},
half{-1.084},
half{-0.843},
half{-0.602},
half{-1.324},
half{-1.084},
half{-0.843},
half{-0.602},
half{0.602},
half{0.843},
half{1.084},
half{1.324},
half{0.602},
half{0.843},
half{1.084},
half{1.324}};
EXPECT(migraphx::verify::verify_rms_range(result, gold));
}
TEST_CASE(mvn_rank_2_test)
{
auto result = mvn_test({2, 2}, "mvn_rank_2_test.onnx");
std::vector<float> gold{-1, 1, -1, 1};
EXPECT(migraphx::verify::verify_rms_range(result, gold));
}
TEST_CASE(mvn_rank_2_fp16_test)
{
using migraphx::half;
auto result = mvn_test<migraphx::half>({2, 2}, "mvn_rank_2_fp16_test.onnx");
std::vector<migraphx::half> gold{half{-1}, half{1}, half{-1}, half{1}};
EXPECT(migraphx::verify::verify_rms_range(result, gold));
}
TEST_CASE(mvn_rank_3_test)
{
auto result = mvn_test({2, 2, 2}, "mvn_rank_3_test.onnx");
std::vector<float> gold{-1.34164079,
-1.34164079,
-0.4472136,
-0.4472136,
0.4472136,
0.4472136,
1.34164079,
1.34164079};
EXPECT(migraphx::verify::verify_rms_range(result, gold));
}
TEST_CASE(mvn_rank_3_fp16_test)
{
using migraphx::half;
auto result = mvn_test<half>({2, 2, 2}, "mvn_rank_3_fp16_test.onnx");
std::vector<half> gold{half{-1.342},
half{-1.342},
half{-0.4473},
half{-0.4473},
half{0.4473},
half{0.4473},
half{1.342},
half{1.342}};
EXPECT(migraphx::verify::verify_rms_range(result, gold));
}
TEST_CASE(mod_test) TEST_CASE(mod_test)
{ {
migraphx::program p = migraphx::parse_onnx("mod_test.onnx"); migraphx::program p = migraphx::parse_onnx("mod_test.onnx");
......
...@@ -154,7 +154,6 @@ def disabled_tests_onnx_1_7_0(backend_test): ...@@ -154,7 +154,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu') backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu')
backend_test.exclude(r'test_mod_mixed_sign_int32_cpu') backend_test.exclude(r'test_mod_mixed_sign_int32_cpu')
backend_test.exclude(r'test_mod_mixed_sign_int8_cpu') backend_test.exclude(r'test_mod_mixed_sign_int8_cpu')
backend_test.exclude(r'test_mvn_cpu')
backend_test.exclude( backend_test.exclude(
r'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu' r'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu'
) )
...@@ -803,7 +802,6 @@ def disabled_tests_onnx_1_13_0(backend_test): ...@@ -803,7 +802,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
backend_test.exclude(r'test_group_normalization_example_cpu') backend_test.exclude(r'test_group_normalization_example_cpu')
backend_test.exclude(r'test_group_normalization_example_expanded_cpu') backend_test.exclude(r'test_group_normalization_example_expanded_cpu')
backend_test.exclude(r'test_mish_cpu') backend_test.exclude(r'test_mish_cpu')
backend_test.exclude(r'test_mvn_expanded_ver18_cpu')
backend_test.exclude(r'test_optional_get_element_optional_sequence_cpu') backend_test.exclude(r'test_optional_get_element_optional_sequence_cpu')
backend_test.exclude(r'test_optional_get_element_optional_tensor_cpu') backend_test.exclude(r'test_optional_get_element_optional_tensor_cpu')
backend_test.exclude(r'test_optional_get_element_tensor_cpu') backend_test.exclude(r'test_optional_get_element_tensor_cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment