"...frameworkcontroller/frameworkcontrollerjob-crd-v1.json" did not exist on "ae7a72bc66496a501824c95bc6dc10e6cd45be0a"
Commit c90969eb authored by charlie's avatar charlie
Browse files

Dynamic weights shape test and fix

parent f656ffe7
...@@ -98,8 +98,27 @@ struct convolution ...@@ -98,8 +98,27 @@ struct convolution
if(input.dynamic() or weights.dynamic()) if(input.dynamic() or weights.dynamic())
{ {
std::vector<shape::dynamic_dimension> output_dyn_dims = {input.dyn_dims().at(0), std::vector<shape::dynamic_dimension> output_dyn_dims = {};
input.dyn_dims().at(1)}; if(input.dynamic())
{
output_dyn_dims.push_back(input.dyn_dims().at(0));
}
else
{
auto l = input.lens().at(0);
output_dyn_dims.push_back({l, l, 0});
}
if(weights.dynamic())
{
output_dyn_dims.push_back(weights.dyn_dims().at(0));
}
else
{
auto l = weights.lens().at(0);
output_dyn_dims.push_back({l, l, 0});
}
auto min_spatial_dims = calc_output_lens(input.min_lens(), weights.min_lens()); auto min_spatial_dims = calc_output_lens(input.min_lens(), weights.min_lens());
auto max_spatial_dims = calc_output_lens(input.max_lens(), weights.max_lens()); auto max_spatial_dims = calc_output_lens(input.max_lens(), weights.max_lens());
auto opt_spatial_dims = calc_output_lens(input.opt_lens(), weights.opt_lens()); auto opt_spatial_dims = calc_output_lens(input.opt_lens(), weights.opt_lens());
......
...@@ -1056,6 +1056,83 @@ TEST_CASE(conv_dynamic_img_shape_test) ...@@ -1056,6 +1056,83 @@ TEST_CASE(conv_dynamic_img_shape_test)
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify_range(results_vector, sol));
} }
TEST_CASE(conv_dynamic_weights_shape)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {1, 1}}}),
input,
weights);
p.compile(migraphx::ref::target{});
std::vector<float> a = {0.28007596, 0.46114671, 0.12171969, 0.52260835, 0.40916841, 0.07163955,
0.09896668, 0.98628836, 0.69406788, 0.44868846, 0.64017681, 0.27048886,
0.30187397, 0.07334207, 0.05258557, 0.80747513, 0.81330534, 0.00497161,
0.33005534, 0.08908686, 0.46794691, 0.61768946, 0.55104806, 0.13406187,
0.70244284, 0.61296941, 0.46742536, 0.29712714, 0.91839388, 0.0834397,
0.14476327, 0.37857075, 0.25922384, 0.61620963, 0.69455439, 0.70389431,
0.77388606, 0.1752363, 0.74631394, 0.24604889, 0.53600244, 0.22116457,
0.81217463, 0.10789447, 0.43083784, 0.63371852, 0.69742316, 0.09536905};
std::vector<float> c = {0.98411968,
0.2899219,
0.44638833,
0.30390816,
0.03989896,
0.2445332,
0.32700131,
0.57517075,
0.06956476,
0.93079306,
0.19882314,
0.52940601};
std::vector<float> sol = {1.9939406,
2.2703054,
1.8896171,
2.062202,
2.3035214,
1.629366,
2.1606991,
2.1917608,
1.6797699};
migraphx::shape weight_fixed_shape0{migraphx::shape::float_type, {1, 3, 2, 2}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_shape, a.data());
params0["W"] = migraphx::argument(weight_fixed_shape0, c.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(72);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
c = {0.98411968, 0.2899219, 0.44638833, 0.30390816, 0.03989896, 0.2445332, 0.32700131,
0.57517075, 0.06956476, 0.93079306, 0.19882314, 0.52940601, 0.35624753, 0.35938406,
0.9111428, 0.88923574, 0.61040283, 0.2797513, 0.15479768, 0.46534674, 0.16970931,
0.49704618, 0.07062198, 0.01678321, 0.53150934, 0.39244495, 0.9963813};
sol = {6.1329393, 4.3199925, 5.448438, 3.8497565};
migraphx::shape weights_fixed_shape1{migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::parameter_map params1;
params1["X"] = migraphx::argument(input_shape, a.data());
params1["W"] = migraphx::argument(weights_fixed_shape1, c.data());
result = p.eval(params1).back();
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv2d_padding_stride_test) TEST_CASE(conv2d_padding_stride_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment