"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "1347cb14348be5082930af055bf352da17e2c662"
Unverified Commit 6f1f4b59 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Use `optimize_module` pass for the quantization to fp16 (#1974)

Fixes #1746

BatchNorm only has x as the runtime input parameter for the following equation. All the other parameters are compile-time constants and related operations can be const-folded before quantizing to fp16 to preserve precision.
parent 3216fe52
...@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -57,13 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_rank = x_lens.size(); auto x_rank = x_lens.size();
if(x_rank == 1 or x_rank == 2) if(x_rank == 1 or x_rank == 2)
{ {
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto numer = info.add_broadcastable_binary_op("sub", args[0], args[3]); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps); auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto mul0 = info.add_broadcastable_binary_op("mul", args[1], rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, args[2]); return info.add_broadcastable_binary_op("add", r0, args[2]);
} }
else if(x_rank > 2) else if(x_rank > 2)
...@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -71,7 +70,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
// unsqueeze tensors of shape (C) to broadcast correctly // unsqueeze tensors of shape (C) to broadcast correctly
std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2); std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1); std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = info.add_instruction( auto scale_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]); migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
...@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -81,11 +79,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]); migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
auto var_unsqueeze = info.add_instruction( auto var_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]); migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps); auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze); return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
} }
else else
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -48,19 +49,12 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) ...@@ -48,19 +49,12 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
// This function is to convert any instructions specified in the input // This function is to convert any instructions specified in the input
// from double or float to float16 by inserting a convert operator. // from double or float to float16 by inserting a convert operator.
// For the conversion, there could be cases of overflowing, but it // For the conversion, there could be cases of overflowing or underflowing, but it
// is very rare in the area of deeping learning, so we just do a // is uncommon. Run optimize_module() before converting to fp16 to const eval and fold in FP32 to
// truncate of the input to get the fp16. // avoid loss of precision.
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{ {
run_passes(prog, run_passes(prog, {optimize_module{}, quantize_fp16_pass{ins_names}, optimize_module{}});
{quantize_fp16_pass{ins_names},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_reshapes{},
dead_code_elimination{},
simplify_qdq{},
dead_code_elimination{}});
} }
void quantize_int8(program& prog, void quantize_int8(program& prog,
......
...@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -52,7 +52,6 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_type = args[0]->get_shape().type(); auto x_type = args[0]->get_shape().type();
// unsqueeze tensors of shape (C) to broadcast correctly // unsqueeze tensors of shape (C) to broadcast correctly
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}}); auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = auto scale_unsqueeze =
...@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -64,11 +63,11 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto var_unsqueeze = auto var_unsqueeze =
info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]); info.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze); auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps); auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt); auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom); auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze); auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze); return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
} }
}; };
......
...@@ -440,14 +440,13 @@ TEST_CASE(batch_norm_flat_test) ...@@ -440,14 +440,13 @@ TEST_CASE(batch_norm_flat_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {1}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {1}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {1}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {1}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, bias});
auto prog = optimize_onnx("batch_norm_flat_test.onnx"); auto prog = optimize_onnx("batch_norm_flat_test.onnx");
...@@ -465,14 +464,13 @@ TEST_CASE(batch_norm_rank_2_test) ...@@ -465,14 +464,13 @@ TEST_CASE(batch_norm_rank_2_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {5}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {5}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {5}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {5}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, bias});
auto prog = optimize_onnx("batch_norm_rank_2_test.onnx"); auto prog = optimize_onnx("batch_norm_rank_2_test.onnx");
...@@ -490,7 +488,6 @@ TEST_CASE(batch_norm_1d_test) ...@@ -490,7 +488,6 @@ TEST_CASE(batch_norm_1d_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-5f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), scale);
...@@ -498,11 +495,11 @@ TEST_CASE(batch_norm_1d_test) ...@@ -498,11 +495,11 @@ TEST_CASE(batch_norm_1d_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_onnx("batch_norm_1d_test.onnx"); auto prog = optimize_onnx("batch_norm_1d_test.onnx");
...@@ -520,7 +517,6 @@ TEST_CASE(batch_norm_2d_test) ...@@ -520,7 +517,6 @@ TEST_CASE(batch_norm_2d_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}}); auto mean = mm->add_parameter("mean", {migraphx::shape::float_type, {3}});
auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}}); auto var = mm->add_parameter("variance", {migraphx::shape::float_type, {3}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -528,11 +524,11 @@ TEST_CASE(batch_norm_2d_test) ...@@ -528,11 +524,11 @@ TEST_CASE(batch_norm_2d_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_onnx("batch_norm_2d_test.onnx"); auto prog = optimize_onnx("batch_norm_2d_test.onnx");
...@@ -550,7 +546,6 @@ TEST_CASE(batch_norm_3d_test) ...@@ -550,7 +546,6 @@ TEST_CASE(batch_norm_3d_test)
auto mean = mm->add_parameter("mean", {migraphx::shape::half_type, {2}}); auto mean = mm->add_parameter("mean", {migraphx::shape::half_type, {2}});
auto var = mm->add_parameter("variance", {migraphx::shape::half_type, {2}}); auto var = mm->add_parameter("variance", {migraphx::shape::half_type, {2}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-6f}});
auto usq_scale = auto usq_scale =
...@@ -561,12 +556,13 @@ TEST_CASE(batch_norm_3d_test) ...@@ -561,12 +556,13 @@ TEST_CASE(batch_norm_3d_test)
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), mean); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_onnx("batch_norm_3d_test.onnx"); auto prog = optimize_onnx("batch_norm_3d_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -908,7 +904,6 @@ TEST_CASE(constant_test) ...@@ -908,7 +904,6 @@ TEST_CASE(constant_test)
TEST_CASE(constant_fill_test) TEST_CASE(constant_fill_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
...@@ -1105,7 +1100,6 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -1105,7 +1100,6 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}}); auto p5 = mm->add_parameter("5", {migraphx::shape::float_type, {1}});
auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}}); auto p6 = mm->add_parameter("6", {migraphx::shape::float_type, {1}});
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
uint64_t axis = 1; uint64_t axis = 1;
...@@ -1120,25 +1114,12 @@ TEST_CASE(conv_bn_relu_maxpool_test) ...@@ -1120,25 +1114,12 @@ TEST_CASE(conv_bn_relu_maxpool_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p5); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p5);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p6); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), p6);
auto mb_mean = mm->add_instruction( auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {l5, usq_mean});
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), usq_mean); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto numer = mm->add_instruction(migraphx::make_op("sub"), l5, mb_mean); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto mb_eps = auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1}}}), eps); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
auto var_eps = mm->add_instruction(migraphx::make_op("add"), usq_var, mb_eps); auto l6 = add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto mb_rt =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 1}}}), rt);
auto denom = mm->add_instruction(migraphx::make_op("pow"), var_eps, mb_rt);
auto mb_denom = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), denom);
auto div0 = mm->add_instruction(migraphx::make_op("div"), numer, mb_denom);
auto mb_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), usq_scale);
auto r0 = mm->add_instruction(migraphx::make_op("mul"), div0, mb_scale);
auto mb_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 28, 28}}}), usq_bias);
auto l6 = mm->add_instruction(migraphx::make_op("add"), r0, mb_bias);
auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6); auto l7 = mm->add_instruction(migraphx::make_op("relu"), l6);
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
...@@ -7150,7 +7131,8 @@ TEST_CASE(variable_batch_user_input_test6) ...@@ -7150,7 +7131,8 @@ TEST_CASE(variable_batch_user_input_test6)
TEST_CASE(variable_batch_user_input_test7) TEST_CASE(variable_batch_user_input_test7)
{ {
// if entry in map_dyn_input_dims is all fixed dynamic_dimensions, convert it to a static shape // if entry in map_dyn_input_dims is all fixed dynamic_dimensions, convert it to a static
// shape
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
......
...@@ -379,10 +379,7 @@ TEST_CASE(fp16_subgraph) ...@@ -379,10 +379,7 @@ TEST_CASE(fp16_subgraph)
auto create_fp16_program = [] { auto create_fp16_program = [] {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}}; migraphx::shape sd{migraphx::shape::half_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type}; migraphx::shape sc{migraphx::shape::bool_type};
...@@ -390,8 +387,8 @@ TEST_CASE(fp16_subgraph) ...@@ -390,8 +387,8 @@ TEST_CASE(fp16_subgraph)
auto x = mm->add_parameter("x", sx); auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto hl1 = then_mod->add_instruction( auto hl2 = then_mod->add_literal(migraphx::literal(sd, {2}));
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l1); auto hl1 = then_mod->add_literal(migraphx::literal(sd, {1}));
auto mhl1 = then_mod->add_instruction( auto mhl1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1);
auto hx = then_mod->add_instruction( auto hx = then_mod->add_instruction(
...@@ -399,8 +396,6 @@ TEST_CASE(fp16_subgraph) ...@@ -399,8 +396,6 @@ TEST_CASE(fp16_subgraph)
auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1); auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1);
auto fad = then_mod->add_instruction( auto fad = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad);
auto hl2 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2);
auto mhl2 = then_mod->add_instruction( auto mhl2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2);
auto hy1 = then_mod->add_instruction( auto hy1 = then_mod->add_instruction(
...@@ -411,8 +406,7 @@ TEST_CASE(fp16_subgraph) ...@@ -411,8 +406,7 @@ TEST_CASE(fp16_subgraph)
then_mod->add_return({fad, fmu, mu}); then_mod->add_return({fad, fmu, mu});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto hl3 = else_mod->add_instruction( auto hl3 = else_mod->add_literal(migraphx::literal(sd, {3}));
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l3);
auto mhl3 = else_mod->add_instruction( auto mhl3 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3);
auto hx2 = else_mod->add_instruction( auto hx2 = else_mod->add_instruction(
......
...@@ -196,7 +196,6 @@ TEST_CASE(batchnorm_test) ...@@ -196,7 +196,6 @@ TEST_CASE(batchnorm_test)
std::vector<float> scale_data(32, 1.0); std::vector<float> scale_data(32, 1.0);
auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data);
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-4f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-4f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -204,11 +203,11 @@ TEST_CASE(batchnorm_test) ...@@ -204,11 +203,11 @@ TEST_CASE(batchnorm_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_tf("batchnorm_test.pb", true); auto prog = optimize_tf("batchnorm_test.pb", true);
...@@ -227,7 +226,6 @@ TEST_CASE(batchnorm_half_test) ...@@ -227,7 +226,6 @@ TEST_CASE(batchnorm_half_test)
std::vector<float> scale_data(32, 1.0); std::vector<float> scale_data(32, 1.0);
auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data);
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-4f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-4f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -235,11 +233,11 @@ TEST_CASE(batchnorm_half_test) ...@@ -235,11 +233,11 @@ TEST_CASE(batchnorm_half_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_tf("batchnorm_half_test.pb", true); auto prog = optimize_tf("batchnorm_half_test.pb", true);
...@@ -258,7 +256,6 @@ TEST_CASE(batchnormv3_test) ...@@ -258,7 +256,6 @@ TEST_CASE(batchnormv3_test)
std::vector<float> scale_data(32, 1.0); std::vector<float> scale_data(32, 1.0);
auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data);
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -266,11 +263,11 @@ TEST_CASE(batchnormv3_test) ...@@ -266,11 +263,11 @@ TEST_CASE(batchnormv3_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_tf("batchnormv3_test.pb", true); auto prog = optimize_tf("batchnormv3_test.pb", true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment