Unverified Commit 52c74f0e authored by Attila Dusnoki's avatar Attila Dusnoki Committed by GitHub
Browse files

Add GroupNorm and LayerNorm onnx parsing (#2242)

parent f25606f9
 layer_norm_3d_test:
=
x
scale
biasy"LayerNormalization*
axislayer_norm_3d_testZ
x



Z
scale

Z
bias

b
y



B
\ No newline at end of file
layer_norm_4d_half_test:
=
x
scale
biasy"LayerNormalization*
axislayer_norm_4d_half_testZ
x





Z
scale


Z
bias


b
y





B
\ No newline at end of file
 layer_norm_4d_test:
=
x
scale
biasy"LayerNormalization*
axislayer_norm_4d_testZ
x




Z
scale

Z
bias

b
y




B
\ No newline at end of file
 )layer_norm_invalid_input_count_error_test:q

xy"LayerNormalization)layer_norm_invalid_input_count_error_testZ
x


b
y


B
\ No newline at end of file
 (layer_norm_invalid_minus_axis_error_test:
=
x
scale
biasy"LayerNormalization*
axis(layer_norm_invalid_minus_axis_error_testZ
x



Z
scale



Z
bias



b
y



B
\ No newline at end of file
layer_norm_small_eps_half_test:
4
x
scaley"LayerNormalization*
epsilon̼+layer_norm_small_eps_half_testZ
x



Z
scale


b
y



B
\ No newline at end of file
 layer_norm_without_bias_test:
!
x
scaley"LayerNormalizationlayer_norm_without_bias_testZ
x


Z
scale

b
y


B
\ No newline at end of file
......@@ -2786,6 +2786,145 @@ TEST_CASE(group_conv_test)
EXPECT(p == prog);
}
// Builds the reference program that the GroupNormalization parser tests compare against.
//
//   input_dims   - lens of the "x" parameter; also the dims of the final reshape
//   scale_dims   - lens of the "scale" parameter
//   bias_dims    - lens of the "bias" parameter
//   reshape_dims - dims "x" is reshaped to before normalizing; also the broadcast out_lens
//   reduce_axes  - axes passed to both reduce_mean instructions
//   eps_value    - value of the epsilon literal added to the variance
//   dtype        - element type of all three parameters and of the epsilon literal
//
// The instructions are inserted in a fixed order:
// reshape -> mean -> sub -> sqdiff -> variance -> add eps -> rsqrt -> mul ->
// broadcast(scale) -> broadcast(bias) -> mul -> add -> reshape back.
migraphx::program make_group_norm(const std::vector<int64_t>& input_dims,
                                  const std::vector<int64_t>& scale_dims,
                                  const std::vector<int64_t>& bias_dims,
                                  const std::vector<int64_t>& reshape_dims,
                                  const std::vector<int64_t>& reduce_axes,
                                  const float eps_value                = 1e-5f,
                                  const migraphx::shape::type_t dtype = migraphx::shape::float_type)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Operators that are used more than once; creating them does not insert anything.
    const auto reduce_mean_op = migraphx::make_op("reduce_mean", {{"axes", reduce_axes}});
    const auto bcast_op =
        migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", reshape_dims}});

    auto input     = mm->add_parameter("x", {dtype, input_dims});
    auto gamma     = mm->add_parameter("scale", {dtype, scale_dims});
    auto beta      = mm->add_parameter("bias", {dtype, bias_dims});
    auto eps_lit   = mm->add_literal(migraphx::literal{dtype, {eps_value}});
    auto grouped_x = mm->add_instruction(migraphx::make_op("reshape", {{"dims", reshape_dims}}), input);

    // (x - mean) * rsqrt(variance + eps), where variance = mean(sqdiff(x, mean))
    auto mu         = mm->add_instruction(reduce_mean_op, grouped_x);
    auto centered   = add_common_op(*mm, migraphx::make_op("sub"), {grouped_x, mu});
    auto sq_dev     = add_common_op(*mm, migraphx::make_op("sqdiff"), {grouped_x, mu});
    auto variance   = mm->add_instruction(reduce_mean_op, sq_dev);
    auto shifted    = add_common_op(*mm, migraphx::make_op("add"), {variance, eps_lit});
    auto inv_std    = mm->add_instruction(migraphx::make_op("rsqrt"), {shifted});
    auto normalized = add_common_op(*mm, migraphx::make_op("mul"), {centered, inv_std});

    // Scale and bias are broadcast along axis 1 of the reshaped tensor.
    auto gamma_b = mm->add_instruction(bcast_op, gamma);
    auto beta_b  = mm->add_instruction(bcast_op, beta);
    auto scaled  = mm->add_instruction(migraphx::make_op("mul"), {normalized, gamma_b});
    auto shifted_out = mm->add_instruction(migraphx::make_op("add"), {scaled, beta_b});

    // Restore the original input dims.
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", input_dims}}), shifted_out);
    return p;
}
// Float GroupNormalization on a {1, 4, 2} input: the program obtained from
// group_norm_3d_test.onnx must equal the one built by make_group_norm.
TEST_CASE(group_norm_3d_test)
{
    const std::vector<int64_t> input_dims{1, 4, 2};
    const std::vector<int64_t> scale_dims{2};
    const std::vector<int64_t> bias_dims{2};
    const std::vector<int64_t> reshape_dims{1, 2, 2, 2};
    const std::vector<int64_t> reduce_axes{2, 3};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-5f,
                                          migraphx::shape::float_type);
    auto prog = optimize_onnx("group_norm_3d_test.onnx");
    EXPECT(p == prog);
}
// Half-precision variant of the {1, 4, 2} GroupNormalization comparison, using
// group_norm_3d_half_test.onnx.
TEST_CASE(group_norm_3d_half_test)
{
    const std::vector<int64_t> input_dims{1, 4, 2};
    const std::vector<int64_t> scale_dims{2};
    const std::vector<int64_t> bias_dims{2};
    const std::vector<int64_t> reshape_dims{1, 2, 2, 2};
    const std::vector<int64_t> reduce_axes{2, 3};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-5f,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("group_norm_3d_half_test.onnx");
    EXPECT(p == prog);
}
// Float GroupNormalization on a {1, 4, 3, 3} input, compared against
// group_norm_4d_test.onnx.
TEST_CASE(group_norm_4d_test)
{
    const std::vector<int64_t> input_dims{1, 4, 3, 3};
    const std::vector<int64_t> scale_dims{2};
    const std::vector<int64_t> bias_dims{2};
    const std::vector<int64_t> reshape_dims{1, 2, 2, 3, 3};
    const std::vector<int64_t> reduce_axes{2, 3, 4};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-5f,
                                          migraphx::shape::float_type);
    auto prog = optimize_onnx("group_norm_4d_test.onnx");
    EXPECT(p == prog);
}
// Half-precision variant of the {1, 4, 3, 3} GroupNormalization comparison, using
// group_norm_4d_half_test.onnx.
TEST_CASE(group_norm_4d_half_test)
{
    const std::vector<int64_t> input_dims{1, 4, 3, 3};
    const std::vector<int64_t> scale_dims{2};
    const std::vector<int64_t> bias_dims{2};
    const std::vector<int64_t> reshape_dims{1, 2, 2, 3, 3};
    const std::vector<int64_t> reduce_axes{2, 3, 4};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-5f,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("group_norm_4d_half_test.onnx");
    EXPECT(p == prog);
}
// Float GroupNormalization on a {3, 3, 3, 3, 3} input with scale/bias of lens {1},
// compared against group_norm_5d_test.onnx.
TEST_CASE(group_norm_5d_test)
{
    const std::vector<int64_t> input_dims{3, 3, 3, 3, 3};
    const std::vector<int64_t> scale_dims{1};
    const std::vector<int64_t> bias_dims{1};
    const std::vector<int64_t> reshape_dims{3, 1, 3, 3, 3, 3};
    const std::vector<int64_t> reduce_axes{2, 3, 4, 5};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-5f,
                                          migraphx::shape::float_type);
    auto prog = optimize_onnx("group_norm_5d_test.onnx");
    EXPECT(p == prog);
}
// Half-precision variant of the {3, 3, 3, 3, 3} GroupNormalization comparison, using
// group_norm_5d_half_test.onnx.
TEST_CASE(group_norm_5d_half_test)
{
    const std::vector<int64_t> input_dims{3, 3, 3, 3, 3};
    const std::vector<int64_t> scale_dims{1};
    const std::vector<int64_t> bias_dims{1};
    const std::vector<int64_t> reshape_dims{3, 1, 3, 3, 3, 3};
    const std::vector<int64_t> reduce_axes{2, 3, 4, 5};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-5f,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("group_norm_5d_half_test.onnx");
    EXPECT(p == prog);
}
// Half-precision GroupNormalization built with an epsilon of 1e-7f, compared against
// group_norm_small_eps_half_test.onnx.
TEST_CASE(group_norm_small_eps_half_test)
{
    const std::vector<int64_t> input_dims{1, 4, 2};
    const std::vector<int64_t> scale_dims{2};
    const std::vector<int64_t> bias_dims{2};
    const std::vector<int64_t> reshape_dims{1, 2, 2, 2};
    const std::vector<int64_t> reduce_axes{2, 3};

    migraphx::program p = make_group_norm(input_dims,
                                          scale_dims,
                                          bias_dims,
                                          reshape_dims,
                                          reduce_axes,
                                          1e-7f,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("group_norm_small_eps_half_test.onnx");
    EXPECT(p == prog);
}
// Parsing group_norm_invalid_num_groups_error_test.onnx is expected to throw.
TEST_CASE(group_norm_invalid_num_groups_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("group_norm_invalid_num_groups_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing group_norm_missing_attribute_error_test.onnx is expected to throw.
TEST_CASE(group_norm_missing_attribute_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("group_norm_missing_attribute_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing group_norm_invalid_input_count_error_test.onnx is expected to throw.
TEST_CASE(group_norm_invalid_input_count_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("group_norm_invalid_input_count_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing group_norm_invalid_input_shape_error_test.onnx is expected to throw.
TEST_CASE(group_norm_invalid_input_shape_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("group_norm_invalid_input_shape_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing group_norm_invalid_scale_shape_test.onnx is expected to throw.
TEST_CASE(group_norm_invalid_scale_shape_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("group_norm_invalid_scale_shape_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing group_norm_invalid_bias_shape_test.onnx is expected to throw.
TEST_CASE(group_norm_invalid_bias_shape_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("group_norm_invalid_bias_shape_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
TEST_CASE(hardsigmoid_default_test)
{
migraphx::program p;
......@@ -3648,6 +3787,149 @@ TEST_CASE(lessorequal_test)
EXPECT(p == prog);
}
// Builds the reference program that the LayerNormalization parser tests compare against.
//
//   input_shape      - lens of the "x" parameter; also the broadcast out_lens
//   scale_bias_shape - lens of the "scale" parameter and, when present, of "bias"
//   reduce_axes      - axes passed to both reduce_mean instructions
//   skipped_axis     - broadcast axis for scale/bias; when 0 no broadcast is inserted
//   skip_bias        - when true, no "bias" parameter and no final add are created
//   eps_value        - value of the epsilon literal added to the variance
//   dtype            - element type of the parameters and of the epsilon literal
//
// The instructions are inserted in a fixed order:
// mean -> sub -> sqdiff -> variance -> add eps -> rsqrt -> mul ->
// [broadcast(scale), broadcast(bias)] -> mul -> [add].
migraphx::program make_layer_norm(const std::vector<int64_t>& input_shape,
                                  const std::vector<int64_t>& scale_bias_shape,
                                  const std::vector<int64_t>& reduce_axes,
                                  size_t skipped_axis,
                                  bool skip_bias                      = false,
                                  const float eps_value               = 1e-5f,
                                  const migraphx::shape::type_t dtype = migraphx::shape::float_type)
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    const bool has_bias = not skip_bias;
    // Used for both the mean and the variance; creating it does not insert anything.
    const auto reduce_mean_op = migraphx::make_op("reduce_mean", {{"axes", reduce_axes}});

    auto input = mm->add_parameter("x", {dtype, input_shape});
    auto gamma = mm->add_parameter("scale", {dtype, scale_bias_shape});
    migraphx::instruction_ref beta;
    if(has_bias)
        beta = mm->add_parameter("bias", {dtype, scale_bias_shape});
    auto eps_lit = mm->add_literal(migraphx::literal{dtype, {eps_value}});

    // (x - mean) * rsqrt(variance + eps), where variance = mean(sqdiff(x, mean))
    auto mu         = mm->add_instruction(reduce_mean_op, input);
    auto centered   = add_common_op(*mm, migraphx::make_op("sub"), {input, mu});
    auto sq_dev     = add_common_op(*mm, migraphx::make_op("sqdiff"), {input, mu});
    auto variance   = mm->add_instruction(reduce_mean_op, sq_dev);
    auto shifted    = add_common_op(*mm, migraphx::make_op("add"), {variance, eps_lit});
    auto inv_std    = mm->add_instruction(migraphx::make_op("rsqrt"), {shifted});
    auto normalized = add_common_op(*mm, migraphx::make_op("mul"), {centered, inv_std});

    // Scale and bias are used as-is unless a non-zero broadcast axis is requested.
    migraphx::instruction_ref gamma_b = gamma;
    migraphx::instruction_ref beta_b  = beta;
    if(skipped_axis > 0)
    {
        const auto bcast_op =
            migraphx::make_op("broadcast", {{"axis", skipped_axis}, {"out_lens", input_shape}});
        gamma_b = mm->add_instruction(bcast_op, gamma);
        if(has_bias)
            beta_b = mm->add_instruction(bcast_op, beta);
    }

    auto scaled = mm->add_instruction(migraphx::make_op("mul"), {normalized, gamma_b});
    if(has_bias)
        mm->add_instruction(migraphx::make_op("add"), {scaled, beta_b});
    return p;
}
// LayerNormalization on a {3, 4} input reduced over axes {0, 1}; scale/bias have the
// full {3, 4} lens and no broadcast is built (skipped axis 0). Compared against
// layer_norm_2d_axis_zero_test.onnx.
TEST_CASE(layer_norm_2d_axis_zero_test)
{
    const std::vector<int64_t> input_shape{3, 4};
    const std::vector<int64_t> scale_bias_shape{3, 4};
    const std::vector<int64_t> reduce_axes{0, 1};
    const size_t skipped_axis = 0;

    migraphx::program p =
        make_layer_norm(input_shape, scale_bias_shape, reduce_axes, skipped_axis);
    auto prog = optimize_onnx("layer_norm_2d_axis_zero_test.onnx");
    EXPECT(p == prog);
}
// LayerNormalization on a {3, 4} input reduced over axis 1, with scale/bias of lens {4}
// broadcast along axis 1. Compared against layer_norm_2d_axis_one_test.onnx.
TEST_CASE(layer_norm_2d_axis_one_test)
{
    const std::vector<int64_t> input_shape{3, 4};
    const std::vector<int64_t> scale_bias_shape{4};
    const std::vector<int64_t> reduce_axes{1};
    const size_t skipped_axis = 1;

    migraphx::program p =
        make_layer_norm(input_shape, scale_bias_shape, reduce_axes, skipped_axis);
    auto prog = optimize_onnx("layer_norm_2d_axis_one_test.onnx");
    EXPECT(p == prog);
}
// LayerNormalization on a {3, 4} input where the model specifies the axis as -1.
// The expected program is the same one built for axis 1 (reduce over axis 1, scale/bias
// of lens {4} broadcast along axis 1).
//
// This test must load its own model file: it previously loaded
// "layer_norm_2d_axis_one_test.onnx", which made it a duplicate of
// layer_norm_2d_axis_one_test and left the negative-axis model unparsed.
// NOTE(review): assumes layer_norm_2d_axis_minus_one_test.onnx is generated alongside
// the other layer_norm_2d_* models - confirm it is present in the test data directory.
TEST_CASE(layer_norm_2d_axis_minus_one_test)
{
    migraphx::program p = make_layer_norm({3, 4}, {4}, {1}, 1);
    auto prog           = optimize_onnx("layer_norm_2d_axis_minus_one_test.onnx");
    EXPECT(p == prog);
}
// LayerNormalization on a {1, 4, 2} input reduced over axis 2, with scale/bias of
// lens {2} broadcast along axis 2. Compared against layer_norm_3d_test.onnx.
TEST_CASE(layer_norm_3d_test)
{
    const std::vector<int64_t> input_shape{1, 4, 2};
    const std::vector<int64_t> scale_bias_shape{2};
    const std::vector<int64_t> reduce_axes{2};
    const size_t skipped_axis = 2;

    migraphx::program p =
        make_layer_norm(input_shape, scale_bias_shape, reduce_axes, skipped_axis);
    auto prog = optimize_onnx("layer_norm_3d_test.onnx");
    EXPECT(p == prog);
}
// Half-precision variant of the {1, 4, 2} LayerNormalization comparison, using
// layer_norm_3d_half_test.onnx.
TEST_CASE(layer_norm_3d_half_test)
{
    const std::vector<int64_t> input_shape{1, 4, 2};
    const std::vector<int64_t> scale_bias_shape{2};
    const std::vector<int64_t> reduce_axes{2};
    const size_t skipped_axis = 2;

    migraphx::program p = make_layer_norm(input_shape,
                                          scale_bias_shape,
                                          reduce_axes,
                                          skipped_axis,
                                          false,
                                          1e-5f,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("layer_norm_3d_half_test.onnx");
    EXPECT(p == prog);
}
// LayerNormalization on a {3, 3, 3, 3} input reduced over axis 3, with scale/bias of
// lens {3} broadcast along axis 3. Compared against layer_norm_4d_test.onnx.
TEST_CASE(layer_norm_4d_test)
{
    const std::vector<int64_t> input_shape{3, 3, 3, 3};
    const std::vector<int64_t> scale_bias_shape{3};
    const std::vector<int64_t> reduce_axes{3};
    const size_t skipped_axis = 3;

    migraphx::program p =
        make_layer_norm(input_shape, scale_bias_shape, reduce_axes, skipped_axis);
    auto prog = optimize_onnx("layer_norm_4d_test.onnx");
    EXPECT(p == prog);
}
// Half-precision variant of the {3, 3, 3, 3} LayerNormalization comparison, using
// layer_norm_4d_half_test.onnx.
TEST_CASE(layer_norm_4d_half_test)
{
    const std::vector<int64_t> input_shape{3, 3, 3, 3};
    const std::vector<int64_t> scale_bias_shape{3};
    const std::vector<int64_t> reduce_axes{3};
    const size_t skipped_axis = 3;

    migraphx::program p = make_layer_norm(input_shape,
                                          scale_bias_shape,
                                          reduce_axes,
                                          skipped_axis,
                                          false,
                                          1e-5f,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("layer_norm_4d_half_test.onnx");
    EXPECT(p == prog);
}
// Parsing layer_norm_invalid_axis_error_test.onnx is expected to throw.
TEST_CASE(layer_norm_invalid_axis_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("layer_norm_invalid_axis_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing layer_norm_invalid_minus_axis_error_test.onnx is expected to throw.
TEST_CASE(layer_norm_invalid_minus_axis_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("layer_norm_invalid_minus_axis_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// Parsing layer_norm_invalid_input_count_error_test.onnx is expected to throw.
TEST_CASE(layer_norm_invalid_input_count_error_test)
{
    const auto parse_model = [&] {
        migraphx::parse_onnx("layer_norm_invalid_input_count_error_test.onnx");
    };
    EXPECT(test::throws(parse_model));
}
// LayerNormalization on a {1, 2} input with a scale but no bias (skip_bias = true).
// Compared against layer_norm_without_bias_test.onnx.
TEST_CASE(layer_norm_without_bias_test)
{
    const std::vector<int64_t> input_shape{1, 2};
    const std::vector<int64_t> scale_shape{2};
    const std::vector<int64_t> reduce_axes{1};
    const size_t skipped_axis = 1;
    const bool skip_bias      = true;

    migraphx::program p =
        make_layer_norm(input_shape, scale_shape, reduce_axes, skipped_axis, skip_bias);
    auto prog = optimize_onnx("layer_norm_without_bias_test.onnx");
    EXPECT(p == prog);
}
// Half-precision LayerNormalization on a {1, 2} input without bias, built with an
// epsilon of 1e-7. Compared against layer_norm_small_eps_half_test.onnx.
TEST_CASE(layer_norm_small_eps_half_test)
{
    const std::vector<int64_t> input_shape{1, 2};
    const std::vector<int64_t> scale_shape{2};
    const std::vector<int64_t> reduce_axes{1};
    const size_t skipped_axis = 1;
    const bool skip_bias      = true;

    migraphx::program p = make_layer_norm(input_shape,
                                          scale_shape,
                                          reduce_axes,
                                          skipped_axis,
                                          skip_bias,
                                          1e-7,
                                          migraphx::shape::half_type);
    auto prog = optimize_onnx("layer_norm_small_eps_half_test.onnx");
    EXPECT(p == prog);
}
TEST_CASE(log_test)
{
migraphx::program p;
......
......@@ -538,6 +538,70 @@ TEST_CASE(gemm_half_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// Parses `onnx_file`, compiles it for the "ref" target and evaluates it once.
//
// Inputs bound for the evaluation:
//   "x"     - tensor with lens `x_dims`, filled with 1, 2, 3, ... via std::iota
//   "scale" - 1-D tensor with scale.size() elements, backed by `scale`
//   "bias"  - 1-D tensor with bias.size() elements, backed by `bias`
//
// Returns the last result of the evaluation copied into a std::vector<T>.
//
// `scale` and `bias` are taken by non-const reference because their data() pointers are
// handed directly to migraphx::argument.
template <typename T = float>
std::vector<T> norm_test(const std::vector<size_t>& x_dims,
                         std::vector<T>& scale,
                         std::vector<T>& bias,
                         const std::string& onnx_file)
{
    migraphx::program p = migraphx::parse_onnx(onnx_file);
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s_x{migraphx::shape::get_type<T>{}, x_dims};
    migraphx::shape s_s{migraphx::shape::get_type<T>{}, {scale.size()}};
    // The bias shape is sized from the bias buffer itself. It used to be sized from
    // scale.size(), which would describe the wrong number of elements whenever the two
    // vectors differ in length.
    migraphx::shape s_b{migraphx::shape::get_type<T>{}, {bias.size()}};

    std::vector<T> x(s_x.elements());
    std::iota(std::begin(x), std::end(x), 1);

    migraphx::parameter_map pp;
    pp["x"]     = migraphx::argument(s_x, x.data());
    pp["scale"] = migraphx::argument(s_s, scale.data());
    pp["bias"]  = migraphx::argument(s_b, bias.data());

    auto result = p.eval(pp).back();

    std::vector<T> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
    return result_vector;
}
// Evaluates group_norm_3d_test.onnx through norm_test on a {1, 4, 2} float input with
// scale {1.2, 0.8} and bias {0.5, 0.2}, and checks the output against gold values.
TEST_CASE(group_norm_test)
{
    std::vector<float> scale{1.2, 0.8};
    std::vector<float> bias{0.5, 0.2};

    auto result_vector = norm_test<float>({1, 4, 2}, scale, bias, "group_norm_3d_test.onnx");

    std::vector<float> gold{
        -1.10996256,
        -0.0366542,
        1.0366542,
        2.10996256,
        -0.87330837,
        -0.15776947,
        0.55776947,
        1.27330837,
    };
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// Half-precision counterpart of group_norm_test: evaluates group_norm_3d_half_test.onnx
// through norm_test on a {1, 4, 2} input and checks the output against gold values.
TEST_CASE(group_norm_half_test)
{
    using migraphx::half;

    std::vector<half> scale{half{1.2}, half{0.8}};
    std::vector<half> bias{half{0.5}, half{0.2}};

    auto result_vector = norm_test<half>({1, 4, 2}, scale, bias, "group_norm_3d_half_test.onnx");

    std::vector<half> gold{
        half{-1.10996256},
        half{-0.0366542},
        half{1.0366542},
        half{2.10996256},
        half{-0.87330837},
        half{-0.15776947},
        half{0.55776947},
        half{1.27330837},
    };
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(greaterorequal_test)
{
migraphx::program p = migraphx::parse_onnx("greaterorequal_test.onnx");
......@@ -950,6 +1014,41 @@ TEST_CASE(instance_norm_3d_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// Evaluates layer_norm_3d_test.onnx through norm_test on a {1, 4, 2} float input with
// scale {1.2, 0.8} and bias {0.5, 0.2}, and checks the output against gold values.
TEST_CASE(layer_norm_test)
{
    std::vector<float> scale{1.2, 0.8};
    std::vector<float> bias{0.5, 0.2};

    auto result_vector = norm_test<float>({1, 4, 2}, scale, bias, "layer_norm_3d_test.onnx");

    std::vector<float> gold{
        -0.69997597,
        0.99998398,
        -0.69997597,
        0.99998398,
        -0.69997597,
        0.99998398,
        -0.69997597,
        0.99998398,
    };
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
// Half-precision counterpart of layer_norm_test: evaluates layer_norm_3d_half_test.onnx
// through norm_test on a {1, 4, 2} input and checks the output against gold values.
TEST_CASE(layer_norm_half_test)
{
    using migraphx::half;

    std::vector<half> scale{half{1.2}, half{0.8}};
    std::vector<half> bias{half{0.5}, half{0.2}};

    auto result_vector = norm_test<half>({1, 4, 2}, scale, bias, "layer_norm_3d_half_test.onnx");

    std::vector<half> gold{
        half{-0.69997597},
        half{0.99998398},
        half{-0.69997597},
        half{0.99998398},
        half{-0.69997597},
        half{0.99998398},
        half{-0.69997597},
        half{0.99998398},
    };
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(lessorequal_test)
{
migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment