Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c90969eb
Commit
c90969eb
authored
May 25, 2022
by
charlie
Browse files
Dynamic weights shape test and fix
parent
f656ffe7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
2 deletions
+98
-2
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+21
-2
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+77
-0
No files found.
src/include/migraphx/op/convolution.hpp
View file @
c90969eb
...
...
@@ -98,8 +98,27 @@ struct convolution
if
(
input
.
dynamic
()
or
weights
.
dynamic
())
{
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{
input
.
dyn_dims
().
at
(
0
),
input
.
dyn_dims
().
at
(
1
)};
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
=
{};
if
(
input
.
dynamic
())
{
output_dyn_dims
.
push_back
(
input
.
dyn_dims
().
at
(
0
));
}
else
{
auto
l
=
input
.
lens
().
at
(
0
);
output_dyn_dims
.
push_back
({
l
,
l
,
0
});
}
if
(
weights
.
dynamic
())
{
output_dyn_dims
.
push_back
(
weights
.
dyn_dims
().
at
(
0
));
}
else
{
auto
l
=
weights
.
lens
().
at
(
0
);
output_dyn_dims
.
push_back
({
l
,
l
,
0
});
}
auto
min_spatial_dims
=
calc_output_lens
(
input
.
min_lens
(),
weights
.
min_lens
());
auto
max_spatial_dims
=
calc_output_lens
(
input
.
max_lens
(),
weights
.
max_lens
());
auto
opt_spatial_dims
=
calc_output_lens
(
input
.
opt_lens
(),
weights
.
opt_lens
());
...
...
test/ref_ops_test.cpp
View file @
c90969eb
...
...
@@ -1056,6 +1056,83 @@ TEST_CASE(conv_dynamic_img_shape_test)
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv_dynamic_weights_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {1, 1}}}),
input,
weights);
p.compile(migraphx::ref::target{});
std::vector<float> a = {0.28007596, 0.46114671, 0.12171969, 0.52260835, 0.40916841, 0.07163955,
0.09896668, 0.98628836, 0.69406788, 0.44868846, 0.64017681, 0.27048886,
0.30187397, 0.07334207, 0.05258557, 0.80747513, 0.81330534, 0.00497161,
0.33005534, 0.08908686, 0.46794691, 0.61768946, 0.55104806, 0.13406187,
0.70244284, 0.61296941, 0.46742536, 0.29712714, 0.91839388, 0.0834397,
0.14476327, 0.37857075, 0.25922384, 0.61620963, 0.69455439, 0.70389431,
0.77388606, 0.1752363, 0.74631394, 0.24604889, 0.53600244, 0.22116457,
0.81217463, 0.10789447, 0.43083784, 0.63371852, 0.69742316, 0.09536905};
std::vector<float> c = {0.98411968,
0.2899219,
0.44638833,
0.30390816,
0.03989896,
0.2445332,
0.32700131,
0.57517075,
0.06956476,
0.93079306,
0.19882314,
0.52940601};
std::vector<float> sol = {1.9939406,
2.2703054,
1.8896171,
2.062202,
2.3035214,
1.629366,
2.1606991,
2.1917608,
1.6797699};
migraphx::shape weight_fixed_shape0{migraphx::shape::float_type, {1, 3, 2, 2}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_shape, a.data());
params0["W"] = migraphx::argument(weight_fixed_shape0, c.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(72);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
c = {0.98411968, 0.2899219, 0.44638833, 0.30390816, 0.03989896, 0.2445332, 0.32700131,
0.57517075, 0.06956476, 0.93079306, 0.19882314, 0.52940601, 0.35624753, 0.35938406,
0.9111428, 0.88923574, 0.61040283, 0.2797513, 0.15479768, 0.46534674, 0.16970931,
0.49704618, 0.07062198, 0.01678321, 0.53150934, 0.39244495, 0.9963813};
sol = {6.1329393, 4.3199925, 5.448438, 3.8497565};
migraphx::shape weights_fixed_shape1{migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::parameter_map params1;
params1["X"] = migraphx::argument(input_shape, a.data());
params1["W"] = migraphx::argument(weights_fixed_shape1, c.data());
result = p.eval(params1).back();
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv2d_padding_stride_test)
{
migraphx::program p;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment