Unverified Commit 49b341d3 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Convert Fp16 instance-norm to FP32 temporarily (#1779)

By converting to fp32 : fp16 3d-unet model accuracy comes out the same as FP32 accuracy.

By using reduce_sum method on Fp16 : accuracy comes out ~0.9% lower compared to fp32 while keeping entire model in fp16.
parent 37711924
...@@ -21,10 +21,14 @@ ...@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iterator>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT);
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,22 +43,43 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -39,22 +43,43 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> oargs) const
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x) // mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2) // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool convert_fp16 = true;
if(enabled(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT{}))
{
convert_fp16 = false;
}
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
} }
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
// cppcheck-suppress knownConditionTrueFalse
if(dtype == shape::half_type and convert_fp16)
{
std::transform(oargs.begin(), oargs.end(), std::back_inserter(args), [&](const auto i) {
return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), i);
});
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0]; auto x = args[0];
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype)) if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double)."); ". Valid types are 1 (float), 10 (half), and 11 (double).");
...@@ -65,14 +90,29 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -65,14 +90,29 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
std::vector<int64_t> axes(kdims); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
// for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
// reduce_sum to calculate variance i.e.
// var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
std::string reduce_op_name =
(dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
if(dtype == shape::half_type and not convert_fp16)
{
double n =
std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
return i * j;
});
n = 1.0 / std::sqrt(n);
auto n_literal = info.add_literal(literal{dtype, {n}});
mean_bcast = info.add_common_op("mul", {mean_bcast, n_literal});
x = info.add_common_op("mul", {x, n_literal});
}
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto variance = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
auto epsilon_bcast = auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
...@@ -82,11 +122,16 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -82,11 +122,16 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto l4 = info.add_instruction(make_op("mul"), l1, l3); auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast = auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast); auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
} }
}; };
......
...@@ -3176,9 +3176,9 @@ TEST_CASE(instance_norm_test) ...@@ -3176,9 +3176,9 @@ TEST_CASE(instance_norm_test)
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x); auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast = auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast); auto l0 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0); auto l1 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast); auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
auto epsilon_literal = auto epsilon_literal =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}}); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
auto epsilon_bcast = mm->add_instruction( auto epsilon_bcast = mm->add_instruction(
...@@ -3187,7 +3187,7 @@ TEST_CASE(instance_norm_test) ...@@ -3187,7 +3187,7 @@ TEST_CASE(instance_norm_test)
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast); auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2); auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3); auto l4 = mm->add_instruction(migraphx::make_op("mul"), l0, l3);
auto scale_bcast = mm->add_instruction( auto scale_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
...@@ -3207,33 +3207,41 @@ TEST_CASE(instance_norm_half_test) ...@@ -3207,33 +3207,41 @@ TEST_CASE(instance_norm_half_test)
migraphx::shape s2{migraphx::shape::half_type, {2}}; migraphx::shape s2{migraphx::shape::half_type, {2}};
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("0", s1); auto x_fp16 = mm->add_parameter("0", s1);
auto scale = mm->add_parameter("1", s2); auto scale_fp16 = mm->add_parameter("1", s2);
auto bias = mm->add_parameter("2", s2); auto bias_fp16 = mm->add_parameter("2", s2);
auto x = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x_fp16);
auto scale = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), scale_fp16);
auto bias = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), bias_fp16);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x); auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast = auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast); auto l0 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0); auto l1 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast); auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
auto epsilon_literal = auto epsilon_literal =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {1e-5}}); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
auto epsilon_bcast = mm->add_instruction( auto epsilon_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast); auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2); auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3); auto l4 = mm->add_instruction(migraphx::make_op("mul"), l0, l3);
auto scale_bcast = mm->add_instruction( auto scale_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = mm->add_instruction( auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast); auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast); auto instance_norm_fp32 = mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
mm->add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}),
instance_norm_fp32);
auto prog = optimize_onnx("instance_norm_half_test.onnx"); auto prog = optimize_onnx("instance_norm_half_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -74,7 +74,9 @@ int main(int argc, const char* argv[]) ...@@ -74,7 +74,9 @@ int main(int argc, const char* argv[])
"test_select_module_add", "test_select_module_add",
"test_select_module_reduce", "test_select_module_reduce",
"test_select_module_conv", "test_select_module_conv",
"test_split_single_dyn_dim"}); "test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"}); rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv); rv.run(argc, argv);
} }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
migraphx::instruction_ref add_instancenorm(migraphx::module& m,
migraphx::instruction_ref x,
const std::vector<size_t>& dims,
float eps = 1e-5f)
{
auto mgx_type = x->get_shape().type();
auto x_lens = x->get_shape().lens();
std::vector<size_t> axes(x_lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, dims});
auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, dims});
auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), mean);
auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto l0 = m.add_instruction(migraphx::make_op("sqdiff"), {x, mean_mbcast});
auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), {l0});
auto epsilon_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), epsilon);
auto var_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), var);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var_mbcast, epsilon_mbcast);
auto rsqrt = m.add_instruction(migraphx::make_op("rsqrt"), add_epsilon);
auto l1 = m.add_instruction(migraphx::make_op("mul"), {sub, rsqrt});
auto scale_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), scale);
auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, l1);
auto bias_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), bias);
return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
}
template <migraphx::shape::type_t TYPE>
struct test_instancenorm : verify_program<test_instancenorm<TYPE>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 5, 5};
auto x = mm->add_parameter("x", migraphx::shape{TYPE, dims});
add_instancenorm(*mm, x, {1, 2, 1, 1});
return p;
}
};
template struct test_instancenorm<migraphx::shape::float_type>;
template struct test_instancenorm<migraphx::shape::half_type>;
template <migraphx::shape::type_t TYPE>
struct test_instancenorm_large_3d : verify_program<test_instancenorm_large_3d<TYPE>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 32, 64, 64, 64};
auto x = mm->add_parameter("x", migraphx::shape{TYPE, dims});
add_instancenorm(*mm, x, {1, 32, 1, 1, 1});
return p;
}
};
template struct test_instancenorm_large_3d<migraphx::shape::float_type>;
template struct test_instancenorm_large_3d<migraphx::shape::half_type>;
...@@ -27,7 +27,9 @@ ...@@ -27,7 +27,9 @@
#include <functional> #include <functional>
#include <migraphx/auto_register.hpp> #include <migraphx/auto_register.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/ranges.hpp>
struct program_info struct program_info
{ {
...@@ -47,7 +49,18 @@ struct register_verify_program_action ...@@ -47,7 +49,18 @@ struct register_verify_program_action
{ {
T x; T x;
program_info pi; program_info pi;
pi.name = migraphx::get_type_name<T>(); const std::string& test_type_name = migraphx::get_type_name<T>();
const auto& split_name = migraphx::split_string(test_type_name, ':');
std::vector<std::string> name_without_version = {};
// test_type_name could contain internal namespace name with version_x_y_z i.e.
// test_instancenorm<migraphx::version_1::shape::float_type> remove version and construct
// test_name such as test_instancenorm<migraphx::shape::float_type>
std::copy_if(
split_name.begin(),
split_name.end(),
std::back_inserter(name_without_version),
[&](const auto& i) { return not i.empty() and not migraphx::contains(i, "version"); });
pi.name = migraphx::trim(migraphx::join_strings(name_without_version, "::"));
pi.section = x.section(); pi.section = x.section();
pi.get_program = [x] { return x.create_program(); }; pi.get_program = [x] { return x.create_program(); };
pi.compile_options = x.get_compile_options(); pi.compile_options = x.get_compile_options();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment