"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "48fa934d180cda7b6764b21465e203e39ca1cab3"
Commit a0dd2ef9 authored by charlie's avatar charlie
Browse files

Dynamic image size test

parent a8939c5b
......@@ -967,6 +967,103 @@ TEST_CASE(conv_dynamic_batch_test)
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv_dynamic_img_shape_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape);
auto weights = mm->add_parameter("W", weights_shape);
mm->add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {1, 1}}}),
input,
weights);
p.compile(migraphx::ref::target{});
std::vector<float> a = {0.28007596, 0.46114671, 0.12171969, 0.52260835, 0.40916841, 0.07163955,
0.09896668, 0.98628836, 0.69406788, 0.44868846, 0.64017681, 0.27048886,
0.30187397, 0.07334207, 0.05258557, 0.80747513, 0.81330534, 0.00497161,
0.33005534, 0.08908686, 0.46794691, 0.61768946, 0.55104806, 0.13406187,
0.70244284, 0.61296941, 0.46742536, 0.29712714, 0.91839388, 0.0834397,
0.14476327, 0.37857075, 0.25922384, 0.61620963, 0.69455439, 0.70389431,
0.77388606, 0.1752363, 0.74631394, 0.24604889, 0.53600244, 0.22116457,
0.81217463, 0.10789447, 0.43083784, 0.63371852, 0.69742316, 0.09536905};
std::vector<float> c = {0.98411968, 0.2899219, 0.44638833, 0.30390816, 0.03989896, 0.2445332,
0.32700131, 0.57517075, 0.06956476, 0.93079306, 0.19882314, 0.52940601,
0.35624753, 0.35938406, 0.9111428, 0.88923574, 0.61040283, 0.2797513,
0.15479768, 0.46534674, 0.16970931, 0.49704618, 0.07062198, 0.01678321,
0.53150934, 0.39244495, 0.9963813};
std::vector<float> sol = {6.1329393, 4.3199925, 5.448438, 3.8497565};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
params0["W"] = migraphx::argument(weights_shape, c.data());
auto result = p.eval(params0).back();
std::vector<float> results_vector(72);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
a = {0.95600171, 0.20768181, 0.82844489, 0.14928212, 0.51280462, 0.1359196, 0.68903648,
0.84174772, 0.425509, 0.956926, 0.82533291, 0.33821531, 0.57576055, 0.75330186,
0.82710394, 0.93343847, 0.14499469, 0.74558021, 0.13935139, 0.90652876, 0.22611443,
0.85323975, 0.30631787, 0.96983037, 0.51783421, 0.32247456, 0.28243352, 0.605865,
0.33376446, 0.67864877, 0.15442507, 0.24977552, 0.86989425, 0.60036782, 0.26198306,
0.1494149, 0.13678915, 0.24892094, 0.38282467, 0.64907906, 0.83756376, 0.77603195,
0.33951558, 0.14856874, 0.45701939, 0.43786436, 0.57421759, 0.37326922, 0.63382506,
0.11464436, 0.23309047, 0.76724102, 0.98712427, 0.80800108, 0.84296564, 0.79568268,
0.45684131, 0.73867068, 0.57845499, 0.45073557, 0.27102442, 0.86460315, 0.06865567,
0.81673446, 0.881835, 0.42351639, 0.83322931, 0.34101671, 0.51979151, 0.54920645,
0.19287718, 0.33321689, 0.27752456, 0.45755893, 0.67484562, 0.68383122, 0.52361312,
0.46437257, 0.50862936, 0.32460429, 0.1726007, 0.29933345, 0.64856728, 0.06471591,
0.63370843, 0.27900152, 0.18595992, 0.48904812, 0.35368508, 0.09620202, 0.709561,
0.7916206, 0.0443115, 0.62592275, 0.2498623, 0.42725624, 0.7905135, 0.53160169,
0.01303743, 0.01987505, 0.39041803, 0.89530203, 0.23155373, 0.44435213, 0.14407301,
0.80968594, 0.38216188, 0.35692557};
c = {0.2568538, 0.83587388, 0.43654904, 0.04974508, 0.80375029, 0.25350374, 0.1820275,
0.23369029, 0.54358755, 0.96287212, 0.28424067, 0.45639522, 0.61295404, 0.97581672,
0.95342667, 0.39949156, 0.37287137, 0.42897821, 0.11085312, 0.83015689, 0.88845748,
0.37558172, 0.72528733, 0.74167964, 0.4398981, 0.85575732, 0.97880085};
sol = {6.1561007,
6.7845025,
7.718525,
7.520974,
6.490427,
6.963689,
8.200459,
7.9006085,
7.348745,
6.753414,
7.1623836,
7.8356404,
6.903219,
6.956274,
7.2062597,
7.544957};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 3, 6, 6}};
migraphx::parameter_map params1;
params1["X"] = migraphx::argument(input_fixed_shape1, a.data());
params1["W"] = migraphx::argument(weights_shape, c.data());
result = p.eval(params1).back();
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
}
TEST_CASE(conv2d_padding_stride_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment