Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6072b2c4
Unverified
Commit
6072b2c4
authored
Oct 20, 2023
by
music-dino
Committed by
GitHub
Oct 19, 2023
Browse files
Add MeanVarianceNormalization ONNX parsing (#2255)
parent
c8f1cd93
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
397 additions
and
2 deletions
+397
-2
src/onnx/parse_mean_variance_normalization.cpp
src/onnx/parse_mean_variance_normalization.cpp
+86
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+71
-0
test/onnx/mvn_axes_rank_too_big_test.onnx
test/onnx/mvn_axes_rank_too_big_test.onnx
+0
-0
test/onnx/mvn_axes_rank_too_small_test.onnx
test/onnx/mvn_axes_rank_too_small_test.onnx
+0
-0
test/onnx/mvn_default_axes_fp16_test.onnx
test/onnx/mvn_default_axes_fp16_test.onnx
+17
-0
test/onnx/mvn_default_axes_rank_too_small_test.onnx
test/onnx/mvn_default_axes_rank_too_small_test.onnx
+13
-0
test/onnx/mvn_default_axes_test.onnx
test/onnx/mvn_default_axes_test.onnx
+15
-0
test/onnx/mvn_rank_2_fp16_test.onnx
test/onnx/mvn_rank_2_fp16_test.onnx
+14
-0
test/onnx/mvn_rank_2_test.onnx
test/onnx/mvn_rank_2_test.onnx
+12
-0
test/onnx/mvn_rank_3_fp16_test.onnx
test/onnx/mvn_rank_3_fp16_test.onnx
+0
-0
test/onnx/mvn_rank_3_test.onnx
test/onnx/mvn_rank_3_test.onnx
+0
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+60
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+109
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-2
No files found.
src/onnx/parse_mean_variance_normalization.cpp
0 → 100644
View file @
6072b2c4
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx/checks.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_mean_variance_normalization
:
op_parser
<
parse_mean_variance_normalization
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"MeanVarianceNormalization"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
auto
&&
data
=
args
.
front
();
auto
data_rank
=
data
->
get_shape
().
ndim
();
std
::
vector
<
int64_t
>
axes
{
0
,
2
,
3
};
if
(
contains
(
info
.
attributes
,
"axes"
))
{
const
auto
&
axes_attr
=
info
.
attributes
[
"axes"
].
ints
();
axes
.
assign
(
axes_attr
.
begin
(),
axes_attr
.
end
());
}
else
if
(
data_rank
!=
4
)
{
MIGRAPHX_THROW
(
"Input tensor needs to be rank 4 when axes is not specified. Instead it is rank "
+
std
::
to_string
(
data_rank
));
}
if
(
axes
.
size
()
!=
data_rank
-
1
)
{
MIGRAPHX_THROW
(
"Length of axes array needs to be equal to input tensor rank - 1"
);
}
auto
data_mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
data
);
auto
data_mean_squared
=
info
.
add_common_op
(
"mul"
,
data_mean
,
data_mean
);
auto
data_squared
=
info
.
add_common_op
(
"mul"
,
data
,
data
);
auto
data_squared_mean
=
info
.
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
data_squared
);
auto
mean_sub
=
info
.
add_common_op
(
"sub"
,
data_squared_mean
,
data_mean_squared
);
auto
std
=
info
.
add_common_op
(
"sqrt"
,
mean_sub
);
auto
dividend
=
info
.
add_common_op
(
"sub"
,
data
,
data_mean
);
auto
epsilon
=
info
.
add_literal
({
data
->
get_shape
().
type
(),
{
data
->
get_shape
().
type
()
==
shape
::
half_type
?
1e-7
:
1e-9
}});
auto
divisor
=
info
.
add_common_op
(
"add"
,
std
,
epsilon
);
return
info
.
add_common_op
(
"div"
,
dividend
,
divisor
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
test/onnx/gen_onnx.py
View file @
6072b2c4
...
@@ -4681,6 +4681,77 @@ def mean_integral_test():
...
@@ -4681,6 +4681,77 @@ def mean_integral_test():
return
([
node
],
data
,
[
mean
])
return
([
node
],
data
,
[
mean
])
def
mvn_default_axes_test_base
(
dims
,
type
=
TensorProto
.
FLOAT
):
data
=
helper
.
make_tensor_value_info
(
"data"
,
type
,
dims
)
out
=
helper
.
make_tensor_value_info
(
"out"
,
type
,
dims
)
node
=
helper
.
make_node
(
"MeanVarianceNormalization"
,
inputs
=
[
"data"
],
outputs
=
[
"out"
])
return
([
node
],
[
data
],
[
out
])
@
onnx_test
()
def
mvn_default_axes_test
():
return
mvn_default_axes_test_base
([
2
,
2
,
2
,
2
])
@
onnx_test
()
def
mvn_default_axes_fp16_test
():
return
mvn_default_axes_test_base
([
2
,
2
,
2
,
2
],
TensorProto
.
FLOAT16
)
@
onnx_test
()
def
mvn_default_axes_rank_too_small_test
():
return
mvn_default_axes_test_base
([
2
,
2
,
2
])
@
onnx_test
()
def
mvn_default_axes_rank_too_big_test
():
return
mvn_default_axes_test_base
([
2
,
2
,
2
,
2
,
2
])
def
mvn_n_rank_test_base
(
axes
,
dims
,
type
=
TensorProto
.
FLOAT
):
data
=
helper
.
make_tensor_value_info
(
"data"
,
type
,
dims
)
out
=
helper
.
make_tensor_value_info
(
"out"
,
type
,
dims
)
node
=
helper
.
make_node
(
"MeanVarianceNormalization"
,
inputs
=
[
"data"
],
outputs
=
[
"out"
],
axes
=
axes
)
return
([
node
],
[
data
],
[
out
])
@
onnx_test
()
def
mvn_rank_2_test
():
return
mvn_n_rank_test_base
([
1
],
[
2
,
2
])
@
onnx_test
()
def
mvn_rank_2_fp16_test
():
return
mvn_n_rank_test_base
([
1
],
[
2
,
2
],
TensorProto
.
FLOAT16
)
@
onnx_test
()
def
mvn_rank_3_test
():
return
mvn_n_rank_test_base
([
0
,
1
],
[
2
,
2
,
2
])
@
onnx_test
()
def
mvn_rank_3_fp16_test
():
return
mvn_n_rank_test_base
([
0
,
1
],
[
2
,
2
,
2
],
TensorProto
.
FLOAT16
)
@
onnx_test
()
def
mvn_axes_rank_too_small_test
():
return
mvn_n_rank_test_base
([
0
,
1
,
2
],
[
2
,
2
,
2
])
@
onnx_test
()
def
mvn_axes_rank_too_big_test
():
return
mvn_n_rank_test_base
([
0
],
[
2
,
2
,
2
])
@
onnx_test
()
@
onnx_test
()
def
min_test
():
def
min_test
():
a
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
3
])
a
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
3
])
...
...
test/onnx/mvn_axes_rank_too_big_test.onnx
0 → 100644
View file @
6072b2c4
File added
test/onnx/mvn_axes_rank_too_small_test.onnx
0 → 100644
View file @
6072b2c4
File added
test/onnx/mvn_default_axes_fp16_test.onnx
0 → 100644
View file @
6072b2c4
mvn_default_axes_fp16_test:
&
dataout"MeanVarianceNormalizationmvn_default_axes_fp16_testZ
data
b
out
B
\ No newline at end of file
test/onnx/mvn_default_axes_rank_too_small_test.onnx
0 → 100644
View file @
6072b2c4
$mvn_default_axes_rank_too_small_test:
&
dataout"MeanVarianceNormalization$mvn_default_axes_rank_too_small_testZ
data
b
out
B
\ No newline at end of file
test/onnx/mvn_default_axes_test.onnx
0 → 100644
View file @
6072b2c4
mvn_default_axes_test:~
&
dataout"MeanVarianceNormalizationmvn_default_axes_testZ
data
b
out
B
\ No newline at end of file
test/onnx/mvn_rank_2_fp16_test.onnx
0 → 100644
View file @
6072b2c4
mvn_rank_2_fp16_test:z
3
dataout"MeanVarianceNormalization*
axes@mvn_rank_2_fp16_testZ
data
b
out
B
\ No newline at end of file
test/onnx/mvn_rank_2_test.onnx
0 → 100644
View file @
6072b2c4
mvn_rank_2_test:u
3
dataout"MeanVarianceNormalization*
axes@mvn_rank_2_testZ
data
b
out
B
\ No newline at end of file
test/onnx/mvn_rank_3_fp16_test.onnx
0 → 100644
View file @
6072b2c4
File added
test/onnx/mvn_rank_3_test.onnx
0 → 100644
View file @
6072b2c4
File added
test/onnx/onnx_test.cpp
View file @
6072b2c4
...
@@ -4501,6 +4501,66 @@ TEST_CASE(mean_integral_test)
...
@@ -4501,6 +4501,66 @@ TEST_CASE(mean_integral_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
void
mvn_n_rank_test
(
std
::
vector
<
int64_t
>
axes
,
std
::
vector
<
size_t
>
input_shape
,
const
std
::
string
&
test_file
)
{
using
migraphx
::
make_op
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
data
=
mm
->
add_parameter
(
"data"
,
{
migraphx
::
shape
::
float_type
,
std
::
move
(
input_shape
)});
auto
data_mean
=
mm
->
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
data
);
auto
data_mean_squared
=
add_common_op
(
*
mm
,
make_op
(
"mul"
),
{
data_mean
,
data_mean
});
auto
data_squared
=
add_common_op
(
*
mm
,
make_op
(
"mul"
),
{
data
,
data
});
auto
data_squared_mean
=
mm
->
add_instruction
(
make_op
(
"reduce_mean"
,
{{
"axes"
,
axes
}}),
data_squared
);
auto
mean_sub
=
add_common_op
(
*
mm
,
make_op
(
"sub"
),
{
data_squared_mean
,
data_mean_squared
});
auto
std
=
add_common_op
(
*
mm
,
make_op
(
"sqrt"
),
{
mean_sub
});
auto
dividend
=
add_common_op
(
*
mm
,
make_op
(
"sub"
),
{
data
,
data_mean
});
auto
epsilon
=
mm
->
add_literal
({
migraphx
::
shape
::
float_type
,
{
1e-9
}});
auto
divisor
=
add_common_op
(
*
mm
,
make_op
(
"add"
),
{
std
,
epsilon
});
add_common_op
(
*
mm
,
make_op
(
"div"
),
{
dividend
,
divisor
});
auto
prog
=
optimize_onnx
(
test_file
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
mvn_default_axes_test
)
{
mvn_n_rank_test
({
0
,
2
,
3
},
{
2
,
2
,
2
,
2
},
"mvn_default_axes_test.onnx"
);
}
TEST_CASE
(
mvn_default_axes_rank_too_small_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_default_axes_rank_too_small_test.onnx"
);
}));
}
TEST_CASE
(
mvn_default_axes_rank_too_big_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_default_axes_rank_too_big_test.onnx"
);
}));
}
TEST_CASE
(
mvn_rank_2_test
)
{
mvn_n_rank_test
({
1
},
{
2
,
2
},
"mvn_rank_2_test.onnx"
);
}
TEST_CASE
(
mvn_rank_3_test
)
{
mvn_n_rank_test
({
0
,
1
},
{
2
,
2
,
2
},
"mvn_rank_3_test.onnx"
);
}
TEST_CASE
(
mvn_axes_rank_too_small_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_axes_rank_too_small_test.onnx"
);
}));
}
TEST_CASE
(
mvn_axes_rank_too_big_test
)
{
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"mvn_axes_rank_too_big_test.onnx"
);
}));
}
TEST_CASE
(
min_test
)
TEST_CASE
(
min_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/onnx/verify_onnx.cpp
View file @
6072b2c4
...
@@ -1211,6 +1211,115 @@ TEST_CASE(mean_integral_test)
...
@@ -1211,6 +1211,115 @@ TEST_CASE(mean_integral_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
}
template
<
typename
T
=
float
>
std
::
vector
<
T
>
mvn_test
(
std
::
vector
<
size_t
>
data_lens
,
const
std
::
string
&
test_file
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
test_file
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
shape
data_shape
(
migraphx
::
shape
::
get_type
<
T
>
{},
std
::
move
(
data_lens
));
std
::
vector
<
T
>
data
(
data_shape
.
elements
());
std
::
iota
(
begin
(
data
),
end
(
data
),
0
);
migraphx
::
parameter_map
pm
;
pm
[
"data"
]
=
migraphx
::
argument
(
data_shape
,
data
.
data
());
auto
result
=
p
.
eval
(
pm
).
back
();
std
::
vector
<
T
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
return
result_vector
;
}
TEST_CASE
(
mvn_default_axes_test
)
{
auto
result
=
mvn_test
({
2
,
2
,
2
,
2
},
"mvn_default_axes_test.onnx"
);
std
::
vector
<
float
>
gold
{
-
1.32424438
,
-
1.08347268
,
-
0.84270097
,
-
0.60192927
,
-
1.32424438
,
-
1.08347268
,
-
0.84270097
,
-
0.60192927
,
0.60192927
,
0.84270097
,
1.08347268
,
1.32424438
,
0.60192927
,
0.84270097
,
1.08347268
,
1.32424438
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_default_axes_fp16_test
)
{
using
migraphx
::
half
;
auto
result
=
mvn_test
<
half
>
({
2
,
2
,
2
,
2
},
"mvn_default_axes_fp16_test.onnx"
);
std
::
vector
<
half
>
gold
{
half
{
-
1.324
},
half
{
-
1.084
},
half
{
-
0.843
},
half
{
-
0.602
},
half
{
-
1.324
},
half
{
-
1.084
},
half
{
-
0.843
},
half
{
-
0.602
},
half
{
0.602
},
half
{
0.843
},
half
{
1.084
},
half
{
1.324
},
half
{
0.602
},
half
{
0.843
},
half
{
1.084
},
half
{
1.324
}};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_2_test
)
{
auto
result
=
mvn_test
({
2
,
2
},
"mvn_rank_2_test.onnx"
);
std
::
vector
<
float
>
gold
{
-
1
,
1
,
-
1
,
1
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_2_fp16_test
)
{
using
migraphx
::
half
;
auto
result
=
mvn_test
<
migraphx
::
half
>
({
2
,
2
},
"mvn_rank_2_fp16_test.onnx"
);
std
::
vector
<
migraphx
::
half
>
gold
{
half
{
-
1
},
half
{
1
},
half
{
-
1
},
half
{
1
}};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_3_test
)
{
auto
result
=
mvn_test
({
2
,
2
,
2
},
"mvn_rank_3_test.onnx"
);
std
::
vector
<
float
>
gold
{
-
1.34164079
,
-
1.34164079
,
-
0.4472136
,
-
0.4472136
,
0.4472136
,
0.4472136
,
1.34164079
,
1.34164079
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mvn_rank_3_fp16_test
)
{
using
migraphx
::
half
;
auto
result
=
mvn_test
<
half
>
({
2
,
2
,
2
},
"mvn_rank_3_fp16_test.onnx"
);
std
::
vector
<
half
>
gold
{
half
{
-
1.342
},
half
{
-
1.342
},
half
{
-
0.4473
},
half
{
-
0.4473
},
half
{
0.4473
},
half
{
0.4473
},
half
{
1.342
},
half
{
1.342
}};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result
,
gold
));
}
TEST_CASE
(
mod_test
)
TEST_CASE
(
mod_test
)
{
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"mod_test.onnx"
);
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"mod_test.onnx"
);
...
...
test/py/onnx_backend_test.py
View file @
6072b2c4
...
@@ -154,7 +154,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
...
@@ -154,7 +154,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test
.
exclude
(
r
'test_maxunpool_export_without_output_shape_cpu'
)
backend_test
.
exclude
(
r
'test_maxunpool_export_without_output_shape_cpu'
)
backend_test
.
exclude
(
r
'test_mod_mixed_sign_int32_cpu'
)
backend_test
.
exclude
(
r
'test_mod_mixed_sign_int32_cpu'
)
backend_test
.
exclude
(
r
'test_mod_mixed_sign_int8_cpu'
)
backend_test
.
exclude
(
r
'test_mod_mixed_sign_int8_cpu'
)
backend_test
.
exclude
(
r
'test_mvn_cpu'
)
backend_test
.
exclude
(
backend_test
.
exclude
(
r
'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu'
r
'test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_cpu'
)
)
...
@@ -803,7 +802,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
...
@@ -803,7 +802,6 @@ def disabled_tests_onnx_1_13_0(backend_test):
backend_test
.
exclude
(
r
'test_group_normalization_example_cpu'
)
backend_test
.
exclude
(
r
'test_group_normalization_example_cpu'
)
backend_test
.
exclude
(
r
'test_group_normalization_example_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_group_normalization_example_expanded_cpu'
)
backend_test
.
exclude
(
r
'test_mish_cpu'
)
backend_test
.
exclude
(
r
'test_mish_cpu'
)
backend_test
.
exclude
(
r
'test_mvn_expanded_ver18_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_optional_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_optional_sequence_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_optional_tensor_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_optional_tensor_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_tensor_cpu'
)
backend_test
.
exclude
(
r
'test_optional_get_element_tensor_cpu'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment