Unverified Commit 20b423e0 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Convert Fp16 instance-norm to FP32 temporarily (#1779) (#1799)

* Convert Fp16 instance-norm to FP32 temporarily (#1779)
* Conditionally enable GeLU approximation  (#1810)
parent 6a7de283
......@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iterator>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -39,22 +43,43 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
std::vector<instruction_ref> oargs) const
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool convert_fp16 = true;
if(enabled(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT{}))
{
convert_fp16 = false;
}
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
// cppcheck-suppress knownConditionTrueFalse
if(dtype == shape::half_type and convert_fp16)
{
std::transform(oargs.begin(), oargs.end(), std::back_inserter(args), [&](const auto i) {
return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), i);
});
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0];
auto scale = args[1];
auto bias = args[2];
auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double).");
......@@ -65,14 +90,29 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}});
// for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
// reduce_sum to calculate variance i.e.
// var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
std::string reduce_op_name =
(dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
if(dtype == shape::half_type and not convert_fp16)
{
double n =
std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
return i * j;
});
n = 1.0 / std::sqrt(n);
auto n_literal = info.add_literal(literal{dtype, {n}});
mean_bcast = info.add_common_op("mul", {mean_bcast, n_literal});
x = info.add_common_op("mul", {x, n_literal});
}
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
......@@ -82,11 +122,16 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
}
};
......
......@@ -74,6 +74,8 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
struct id_pass
{
std::string name() const { return "id"; }
......@@ -120,7 +122,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
inline_module{},
rewrite_pooling{},
dead_code_elimination{},
rewrite_gelu{},
enable_pass(not enabled(MIGRAPHX_DISABLE_FAST_GELU{}), rewrite_gelu{}),
optimize_module{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{},
......
......@@ -3176,9 +3176,9 @@ TEST_CASE(instance_norm_test)
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto l0 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto l1 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
auto epsilon_literal =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
auto epsilon_bcast = mm->add_instruction(
......@@ -3187,7 +3187,7 @@ TEST_CASE(instance_norm_test)
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l0, l3);
auto scale_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = mm->add_instruction(
......@@ -3208,32 +3208,40 @@ TEST_CASE(instance_norm_half_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("0", s1);
auto scale = mm->add_parameter("1", s2);
auto bias = mm->add_parameter("2", s2);
auto x_fp16 = mm->add_parameter("0", s1);
auto scale_fp16 = mm->add_parameter("1", s2);
auto bias_fp16 = mm->add_parameter("2", s2);
auto x = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x_fp16);
auto scale = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), scale_fp16);
auto bias = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), bias_fp16);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), x);
auto mean_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l0);
auto l1 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto l0 = mm->add_instruction(migraphx::make_op("sub"), x, mean_bcast);
auto l1 = mm->add_instruction(migraphx::make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2, 3}}}), l1);
auto epsilon_literal =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {1e-5}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5}});
auto epsilon_bcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = mm->add_instruction(migraphx::make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(migraphx::make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l1, l3);
auto l4 = mm->add_instruction(migraphx::make_op("mul"), l0, l3);
auto scale_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = mm->add_instruction(migraphx::make_op("mul"), l4, scale_bcast);
mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
auto instance_norm_fp32 = mm->add_instruction(migraphx::make_op("add"), l5, bias_bcast);
mm->add_instruction(migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}),
instance_norm_fp32);
auto prog = optimize_onnx("instance_norm_half_test.onnx");
EXPECT(p == prog);
......
......@@ -74,7 +74,9 @@ int main(int argc, const char* argv[])
"test_select_module_add",
"test_select_module_reduce",
"test_select_module_conv",
"test_split_single_dyn_dim"});
"test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
migraphx::instruction_ref add_instancenorm(migraphx::module& m,
migraphx::instruction_ref x,
const std::vector<size_t>& dims,
float eps = 1e-5f)
{
auto mgx_type = x->get_shape().type();
auto x_lens = x->get_shape().lens();
std::vector<size_t> axes(x_lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
auto scale = m.add_parameter("scale", migraphx::shape{mgx_type, dims});
auto bias = m.add_parameter("bias", migraphx::shape{mgx_type, dims});
auto epsilon = m.add_literal(migraphx::literal{migraphx::shape{mgx_type}, {eps}});
auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), mean);
auto sub = m.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto l0 = m.add_instruction(migraphx::make_op("sqdiff"), {x, mean_mbcast});
auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", axes}}), {l0});
auto epsilon_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), epsilon);
auto var_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), var);
auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var_mbcast, epsilon_mbcast);
auto rsqrt = m.add_instruction(migraphx::make_op("rsqrt"), add_epsilon);
auto l1 = m.add_instruction(migraphx::make_op("mul"), {sub, rsqrt});
auto scale_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), scale);
auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, l1);
auto bias_mbcast =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", x_lens}}), bias);
return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
}
template <migraphx::shape::type_t TYPE>
struct test_instancenorm : verify_program<test_instancenorm<TYPE>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 5, 5};
auto x = mm->add_parameter("x", migraphx::shape{TYPE, dims});
add_instancenorm(*mm, x, {1, 2, 1, 1});
return p;
}
};
template struct test_instancenorm<migraphx::shape::float_type>;
template struct test_instancenorm<migraphx::shape::half_type>;
template <migraphx::shape::type_t TYPE>
struct test_instancenorm_large_3d : verify_program<test_instancenorm_large_3d<TYPE>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 32, 64, 64, 64};
auto x = mm->add_parameter("x", migraphx::shape{TYPE, dims});
add_instancenorm(*mm, x, {1, 32, 1, 1, 1});
return p;
}
};
template struct test_instancenorm_large_3d<migraphx::shape::float_type>;
template struct test_instancenorm_large_3d<migraphx::shape::half_type>;
......@@ -27,7 +27,9 @@
#include <functional>
#include <migraphx/auto_register.hpp>
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/ranges.hpp>
struct program_info
{
......@@ -47,7 +49,18 @@ struct register_verify_program_action
{
T x;
program_info pi;
pi.name = migraphx::get_type_name<T>();
const std::string& test_type_name = migraphx::get_type_name<T>();
const auto& split_name = migraphx::split_string(test_type_name, ':');
std::vector<std::string> name_without_version = {};
// test_type_name could contain internal namespace name with version_x_y_z i.e.
// test_instancenorm<migraphx::version_1::shape::float_type> remove version and construct
// test_name such as test_instancenorm<migraphx::shape::float_type>
std::copy_if(
split_name.begin(),
split_name.end(),
std::back_inserter(name_without_version),
[&](const auto& i) { return not i.empty() and not migraphx::contains(i, "version"); });
pi.name = migraphx::trim(migraphx::join_strings(name_without_version, "::"));
pi.section = x.section();
pi.get_program = [x] { return x.create_program(); };
pi.compile_options = x.get_compile_options();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment