Unverified Commit 97d4bb6c authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into add_parity_check_ci

parents 39b097c7 bdbc38bc
This diff is collapsed.
...@@ -453,37 +453,143 @@ TEST_CASE(contiguous_shape_singleton_dim) ...@@ -453,37 +453,143 @@ TEST_CASE(contiguous_shape_singleton_dim)
expect_shape(output, migraphx::make_op("contiguous"), input); expect_shape(output, migraphx::make_op("contiguous"), input);
} }
TEST_CASE(deconvolution_shape) TEST_CASE(convolution_backwards_1d)
{
migraphx::shape input_1d{migraphx::shape::float_type, {4, 4, 1}};
migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape output_1d{migraphx::shape::float_type, {4, 3, 3}};
expect_shape(output_1d,
migraphx::make_op("convolution_backwards",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input_1d,
weights_1d);
}
TEST_CASE(convolution_backwards_2d)
{ {
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}}; migraphx::shape output{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::make_op("convolution_backwards"), input, weights);
throws_shape(migraphx::make_op("convolution_backwards"), input);
throws_shape(migraphx::make_op("convolution_backwards",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input);
}
TEST_CASE(convolution_backwards_1padding)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}}; migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
expect_shape(output, migraphx::make_op("deconvolution"), input, weights); migraphx::shape output{migraphx::shape::float_type, {4, 3, 1, 1}};
throws_shape(migraphx::make_op("deconvolution"), input); expect_shape(output,
throws_shape( migraphx::make_op("convolution_backwards",
migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), {{"padding", {1, 1}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input); input,
weights);
}
migraphx::shape input_1d{migraphx::shape::float_type, {4, 4, 1}}; TEST_CASE(convolution_backwards_2stride)
migraphx::shape output_1d{migraphx::shape::float_type, {4, 3, 3}}; {
migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}}; migraphx::shape input{migraphx::shape::float_type, {4, 4, 4, 4}};
expect_shape( migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
output_1d, migraphx::shape output{migraphx::shape::float_type, {4, 3, 9, 9}};
migraphx::make_op("deconvolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), expect_shape(output,
input_1d, migraphx::make_op("convolution_backwards",
weights_1d); {{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
input,
weights);
}
TEST_CASE(convolution_backwards_2dilation)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 4, 4}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {4, 3, 8, 8}};
expect_shape(output,
migraphx::make_op("convolution_backwards",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {2, 2}}}),
input,
weights);
}
TEST_CASE(convolution_backwards_3d)
{
migraphx::shape input_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}}; migraphx::shape input_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}};
migraphx::shape output_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}}; migraphx::shape output_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}}; migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
expect_shape( expect_shape(
output_3d, output_3d,
migraphx::make_op("deconvolution", migraphx::make_op("convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}), {{"padding", {0, 0, 0}}, {"stride", {1, 1, 1}}, {"dilation", {1, 1, 1}}}),
input_3d, input_3d,
weights_3d); weights_3d);
} }
TEST_CASE(convolution_backwards_channel_mismatch)
{
migraphx::shape input{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape weights{migraphx::shape::float_type, {3, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution_backwards"), input, weights);
}
TEST_CASE(convolution_backwards_dyn_batch_2d)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {1, 1}, {1, 1}}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}, {3, 3}}};
expect_shape(output, migraphx::make_op("convolution_backwards"), input, weights);
}
TEST_CASE(convolution_backwards_dyn_img_2d)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 1}, {4, 4}, {1, 5}, {1, 5}}};
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape output{migraphx::shape::float_type, {{1, 1}, {3, 3}, {3, 7}, {3, 7}}};
expect_shape(output, migraphx::make_op("convolution_backwards"), input, weights);
}
TEST_CASE(convolution_backwards_dyn_kernel_2d)
{
migraphx::shape input{migraphx::shape::float_type, {1, 4, 1, 1}};
migraphx::shape weights{migraphx::shape::float_type, {{4, 4}, {3, 3}, {2, 6}, {2, 6}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6}, {2, 6}}};
expect_shape(output, migraphx::make_op("convolution_backwards"), input, weights);
}
TEST_CASE(dimensions_of0)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 2, 1}};
migraphx::shape output{migraphx::shape::int64_type, {4}};
expect_shape(output, migraphx::make_op("dimensions_of", {{"end", 4}}), input);
}
TEST_CASE(dimensions_of1)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 2, 1}};
migraphx::shape output{migraphx::shape::int64_type, {2}};
expect_shape(output, migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), input);
}
TEST_CASE(dimensions_of2)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, {2}}, {2, 4}, {2, 4}, {1, 6, {2}}}};
migraphx::shape output{migraphx::shape::int64_type, {2}};
expect_shape(output, migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), input);
}
TEST_CASE(dimensions_of_error0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, {2}}, {2, 4}}};
throws_shape(migraphx::make_op("dimensions_of", {{"start", 3}, {"end", 3}}), input);
}
TEST_CASE(dimensions_of_error1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, {2}}, {2, 4}}};
throws_shape(migraphx::make_op("dimensions_of", {{"start", 3}, {"end", 0}}), input);
}
TEST_CASE(dot_ndim_error0) TEST_CASE(dot_ndim_error0)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {5}}; migraphx::shape s_m1{migraphx::shape::float_type, {5}};
...@@ -1134,7 +1240,7 @@ TEST_CASE(inconsistent_attr_shape) ...@@ -1134,7 +1240,7 @@ TEST_CASE(inconsistent_attr_shape)
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}), {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input, input,
weights); weights);
throws_shape(migraphx::make_op("deconvolution", throws_shape(migraphx::make_op("convolution_backwards",
{{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}), {{"padding", {1, 1}}, {"stride", {2}}, {"dilation", {3, 3, 3}}}),
input, input,
weights); weights);
......
...@@ -33,8 +33,8 @@ def test_conv_relu(): ...@@ -33,8 +33,8 @@ def test_conv_relu():
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx") p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p) print(p)
print("Compiling ...") print("Compiling ...")
# set offload_copy, fast_match and exhaustive_tune to true # set offload_copy, fast_match to true
p.compile(migraphx.get_target("gpu"), True, True, True) p.compile(migraphx.get_target("gpu"), True, True)
print(p) print(p)
params = {} params = {}
......
...@@ -379,10 +379,7 @@ TEST_CASE(fp16_subgraph) ...@@ -379,10 +379,7 @@ TEST_CASE(fp16_subgraph)
auto create_fp16_program = [] { auto create_fp16_program = [] {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sd{migraphx::shape::float_type, {1}}; migraphx::shape sd{migraphx::shape::half_type, {1}};
auto l1 = mm->add_literal(migraphx::literal(sd, {1}));
auto l2 = mm->add_literal(migraphx::literal(sd, {2}));
auto l3 = mm->add_literal(migraphx::literal(sd, {3}));
migraphx::shape sx{migraphx::shape::float_type, {1, 4}}; migraphx::shape sx{migraphx::shape::float_type, {1, 4}};
migraphx::shape sy{migraphx::shape::float_type, {3, 4}}; migraphx::shape sy{migraphx::shape::float_type, {3, 4}};
migraphx::shape sc{migraphx::shape::bool_type}; migraphx::shape sc{migraphx::shape::bool_type};
...@@ -390,17 +387,15 @@ TEST_CASE(fp16_subgraph) ...@@ -390,17 +387,15 @@ TEST_CASE(fp16_subgraph)
auto x = mm->add_parameter("x", sx); auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy); auto y = mm->add_parameter("y", sy);
auto* then_mod = p.create_module("If_6_if"); auto* then_mod = p.create_module("If_6_if");
auto hl1 = then_mod->add_instruction( auto hl2 = then_mod->add_literal(migraphx::literal(sd, {2}));
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l1); auto hl1 = then_mod->add_literal(migraphx::literal(sd, {1}));
auto mhl1 = then_mod->add_instruction( auto mhl1 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl1);
auto hx = then_mod->add_instruction( auto hx = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1); auto ad = then_mod->add_instruction(migraphx::make_op("add"), hx, mhl1);
auto fad = then_mod->add_instruction( auto fad = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad); migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), ad);
auto hl2 = then_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2);
auto mhl2 = then_mod->add_instruction( auto mhl2 = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2); migraphx::make_op("multibroadcast", {{"out_lens", {3, 4}}}), hl2);
auto hy1 = then_mod->add_instruction( auto hy1 = then_mod->add_instruction(
...@@ -411,9 +406,8 @@ TEST_CASE(fp16_subgraph) ...@@ -411,9 +406,8 @@ TEST_CASE(fp16_subgraph)
then_mod->add_return({fad, fmu, mu}); then_mod->add_return({fad, fmu, mu});
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto hl3 = else_mod->add_instruction( auto hl3 = else_mod->add_literal(migraphx::literal(sd, {3}));
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l3); auto mhl3 = else_mod->add_instruction(
auto mhl3 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3); migraphx::make_op("multibroadcast", {{"out_lens", {1, 4}}}), hl3);
auto hx2 = else_mod->add_instruction( auto hx2 = else_mod->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x); migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
...@@ -1020,7 +1014,7 @@ TEST_CASE(target_copy) ...@@ -1020,7 +1014,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result; std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result); run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify_range(ref_result, orig_result)); EXPECT(migraphx::verify::verify_range(ref_result, orig_result));
} }
} }
...@@ -1084,7 +1078,7 @@ TEST_CASE(int8_quantization_dot) ...@@ -1084,7 +1078,7 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result, 30000)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000));
} }
} }
...@@ -1129,7 +1123,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -1129,7 +1123,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result); run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result));
} }
} }
...@@ -1281,7 +1275,7 @@ TEST_CASE(test_op_capture) ...@@ -1281,7 +1275,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec)); EXPECT(migraphx::verify::verify_range(vec, cap_vec));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors) ...@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std::vector<float> results_vector(64); std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol)); EXPECT(migraphx::verify::verify_range(results_vector, sol));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -80,7 +80,7 @@ void dot_2d_test() ...@@ -80,7 +80,7 @@ void dot_2d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(dot_2d_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
...@@ -131,7 +131,7 @@ void dot_4d_test() ...@@ -131,7 +131,7 @@ void dot_4d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE_REGISTER(dot_4d_test<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
...@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test) ...@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_3D_C_test0) TEST_CASE(dot_3D_C_test0)
...@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0) ...@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_3D_C_test1) TEST_CASE(dot_3D_C_test1)
...@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1) ...@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-0.95536130, -0.95536130,
2.27996211}; 2.27996211};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_test1) TEST_CASE(dot_4D_test1)
...@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1) ...@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164, -0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906}; 3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_alpha_beta_test) TEST_CASE(dot_4D_alpha_beta_test)
...@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test) ...@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_4D_alpha_beta_C_test) TEST_CASE(dot_4D_alpha_beta_C_test)
...@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test) ...@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_range(m, m_res));
} }
TEST_CASE(dot_2D_C_test0) TEST_CASE(dot_2D_C_test0)
...@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0) ...@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product) ...@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -642,7 +642,7 @@ TEST_CASE(dot_vm) ...@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -679,7 +679,7 @@ TEST_CASE(dot_vm) ...@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -726,7 +726,7 @@ TEST_CASE(dot_vm) ...@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -774,7 +774,7 @@ TEST_CASE(dot_vm) ...@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -813,7 +813,7 @@ TEST_CASE(dot_mv) ...@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -851,7 +851,7 @@ TEST_CASE(dot_mv) ...@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -895,7 +895,7 @@ TEST_CASE(dot_mv) ...@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1) ...@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1) ...@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2) ...@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2) ...@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2) ...@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2) ...@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE(dot_dyn_4D_test) TEST_CASE(dot_dyn_4D_test)
...@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range(c, results_vector));
} }
TEST_CASE(quant_dot_2args_multi4) TEST_CASE(quant_dot_2args_multi4)
...@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4) ...@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general) ...@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general) ...@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
...@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
{ {
...@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch) ...@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify::verify_range(m, gold));
} }
} }
......
...@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
} }
TEST_CASE(argmin_test_nonstd_shape) TEST_CASE(argmin_test_nonstd_shape)
...@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec));
} }
TEST_CASE(isnan_broadcast_test) TEST_CASE(isnan_broadcast_test)
...@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test) ...@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1}; std::vector<float> correct = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_range(results_vector, correct));
} }
TEST_CASE(squeeze_transpose_test) TEST_CASE(squeeze_transpose_test)
......
This diff is collapsed.
This diff is collapsed.
...@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2.compile(migraphx::make_target("ref")); p2.compile(migraphx::make_target("ref"));
auto result1 = p1.eval({}).back(); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back(); auto result2 = p2.eval({}).back();
visit_all(result1, visit_all(result1, result2)(
result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); }); [&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); });
}; };
test_rewrite_pooling(migraphx::op::pooling_mode::max, test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target) ...@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125}; std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_range(results_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -228,6 +228,15 @@ TEST_CASE(test_shape_dynamic_errors) ...@@ -228,6 +228,15 @@ TEST_CASE(test_shape_dynamic_errors)
EXPECT(test::throws([&] { s.index(std::vector<std::size_t>{0, 1}); })); EXPECT(test::throws([&] { s.index(std::vector<std::size_t>{0, 1}); }));
EXPECT(test::throws([&] { s.with_lens({3, 5}); })); EXPECT(test::throws([&] { s.with_lens({3, 5}); }));
EXPECT(test::throws([&] { s.with_lens(shape::float_type, {3, 5}); })); EXPECT(test::throws([&] { s.with_lens(shape::float_type, {3, 5}); }));
EXPECT(test::throws([&] { s.lens(); }));
EXPECT(test::throws([&] { s.strides(); }));
}
TEST_CASE(test_shape_static_dyn_dim_error)
{
using migraphx::shape;
migraphx::shape s{shape::float_type, {2, 3, 4}};
EXPECT(test::throws([&] { s.dyn_dims(); }));
} }
TEST_CASE(test_shape_dynamic_serialize) TEST_CASE(test_shape_dynamic_serialize)
...@@ -947,13 +956,13 @@ TEST_CASE(test_with_type) ...@@ -947,13 +956,13 @@ TEST_CASE(test_with_type)
TEST_CASE(test_multi_index) TEST_CASE(test_multi_index)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}}; migraphx::shape s{migraphx::shape::float_type, {2, 4, 6}};
EXPECT(migraphx::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0})); EXPECT(migraphx::verify::verify_range(s.multi(0), std::vector<size_t>{0, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4})); EXPECT(migraphx::verify::verify_range(s.multi(4), std::vector<size_t>{0, 0, 4}));
EXPECT(migraphx::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0})); EXPECT(migraphx::verify::verify_range(s.multi(6), std::vector<size_t>{0, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2})); EXPECT(migraphx::verify::verify_range(s.multi(8), std::vector<size_t>{0, 1, 2}));
EXPECT(migraphx::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0})); EXPECT(migraphx::verify::verify_range(s.multi(24), std::vector<size_t>{1, 0, 0}));
EXPECT(migraphx::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0})); EXPECT(migraphx::verify::verify_range(s.multi(30), std::vector<size_t>{1, 1, 0}));
EXPECT(migraphx::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4})); EXPECT(migraphx::verify::verify_range(s.multi(34), std::vector<size_t>{1, 1, 4}));
} }
TEST_CASE(find_permutation_2d_standard) TEST_CASE(find_permutation_2d_standard)
......
...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness) ...@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back(); auto result2 = p2.eval({{"input", input}, {"weights", weights}}).back();
std::vector<float> rv2(16); std::vector<float> rv2(16);
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_range(rv1, rv2));
} }
TEST_CASE(dot_correctness) TEST_CASE(dot_correctness)
...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness) ...@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto result2 = p2.eval({{"a", a}, {"b", b}}).back(); auto result2 = p2.eval({{"a", a}, {"b", b}}).back();
std::vector<float> rv2(sh3.elements()); std::vector<float> rv2(sh3.elements());
result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { rv2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(rv1, rv2)); EXPECT(migraphx::verify::verify_range(rv1, rv2));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -357,6 +357,106 @@ TEST_CASE(nop_convert) ...@@ -357,6 +357,106 @@ TEST_CASE(nop_convert)
EXPECT(std::distance(m.begin(), m.end()) == n - 1); EXPECT(std::distance(m.begin(), m.end()) == n - 1);
} }
TEST_CASE(nested_reshape)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4, 5, 6, 7}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto rshp1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4, 5, 42}}}), x);
auto rshp2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 12, 5, 42}}}), rshp1);
auto rshp3 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12, 5, 42}}}), rshp2);
auto rshp4 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 60, 42}}}), rshp3);
auto rshp5 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {120, 42}}}), rshp4);
auto rshp6 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), rshp5);
m1.add_return({rshp6});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(nested_reshape_contiguous)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4, 5, 6, 7}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto rshp1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4, 5, 42}}}), x);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), rshp1);
auto rshp2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 12, 5, 42}}}), c1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), rshp2);
auto rshp3 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12, 5, 42}}}), c2);
auto c3 = m1.add_instruction(migraphx::make_op("contiguous"), rshp3);
auto rshp4 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 60, 42}}}), c3);
auto c4 = m1.add_instruction(migraphx::make_op("contiguous"), rshp4);
auto rshp5 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {120, 42}}}), c4);
auto c5 = m1.add_instruction(migraphx::make_op("contiguous"), rshp5);
auto rshp6 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), c5);
m1.add_return({rshp6});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(nested_reshape_squeeze)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto rshp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 12}}}), x);
auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), rshp);
m1.add_return({squeeze});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(nested_squeeze_reshape)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), x);
auto rshp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12}}}), squeeze);
m1.add_return({rshp});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(concat_multibroadcasts1) TEST_CASE(concat_multibroadcasts1)
{ {
// Broadcasted batch dim, new axis < old axis // Broadcasted batch dim, new axis < old axis
......
...@@ -196,7 +196,6 @@ TEST_CASE(batchnorm_test) ...@@ -196,7 +196,6 @@ TEST_CASE(batchnorm_test)
std::vector<float> scale_data(32, 1.0); std::vector<float> scale_data(32, 1.0);
auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data);
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-4f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-4f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -204,11 +203,11 @@ TEST_CASE(batchnorm_test) ...@@ -204,11 +203,11 @@ TEST_CASE(batchnorm_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_tf("batchnorm_test.pb", true); auto prog = optimize_tf("batchnorm_test.pb", true);
...@@ -227,7 +226,6 @@ TEST_CASE(batchnorm_half_test) ...@@ -227,7 +226,6 @@ TEST_CASE(batchnorm_half_test)
std::vector<float> scale_data(32, 1.0); std::vector<float> scale_data(32, 1.0);
auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data);
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-4f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::half_type, {1e-4f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -235,11 +233,11 @@ TEST_CASE(batchnorm_half_test) ...@@ -235,11 +233,11 @@ TEST_CASE(batchnorm_half_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_tf("batchnorm_half_test.pb", true); auto prog = optimize_tf("batchnorm_half_test.pb", true);
...@@ -258,7 +256,6 @@ TEST_CASE(batchnormv3_test) ...@@ -258,7 +256,6 @@ TEST_CASE(batchnormv3_test)
std::vector<float> scale_data(32, 1.0); std::vector<float> scale_data(32, 1.0);
auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data); auto scale = mm->add_literal(migraphx::shape{migraphx::shape::float_type, {32}}, scale_data);
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}}); auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-6f}});
auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); auto usq_scale = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -266,11 +263,11 @@ TEST_CASE(batchnormv3_test) ...@@ -266,11 +263,11 @@ TEST_CASE(batchnormv3_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean); auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var); auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);
auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean}); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps}); auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto denom = add_common_op(*mm, migraphx::make_op("pow"), {var_eps, rt}); auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("div"), {numer, denom}); auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale}); auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias}); add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});
auto prog = optimize_tf("batchnormv3_test.pb", true); auto prog = optimize_tf("batchnormv3_test.pb", true);
......
...@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p, ...@@ -88,10 +88,31 @@ inline void compile_check(migraphx::program& p,
auto num = shapes.size(); auto num = shapes.size();
for(std::size_t i = 0; i < num; ++i) for(std::size_t i = 0; i < num; ++i)
{ {
if(p.get_output_shapes()[i].lens() != shapes[i].lens()) auto output_shape = p.get_output_shapes()[i];
if(output_shape.dynamic() and shapes[i].dynamic())
{
if(output_shape.dyn_dims() != shapes[i].dyn_dims())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name +
" alters its dynamic output dimensions");
}
}
else if(not(output_shape.dynamic() or shapes[i].dynamic()))
{
if(output_shape.lens() != shapes[i].lens())
{
std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name +
" alters its static output dimensions");
}
}
else
{ {
std::cout << ss.str() << std::endl; std::cout << ss.str() << std::endl;
throw std::runtime_error("Compiling program with " + name + " alters its shape"); throw std::runtime_error(
"Compiling program with " + name +
" alters its output dimensions (static shape vs dynamic shape)");
} }
} }
if(t.name() != "ref") if(t.name() != "ref")
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv : verify_program<test_deconv> struct test_convolution_backwards : verify_program<test_convolution_backwards>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -37,7 +37,7 @@ struct test_deconv : verify_program<test_deconv> ...@@ -37,7 +37,7 @@ struct test_deconv : verify_program<test_deconv>
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}});
mm->add_instruction(migraphx::make_op("deconvolution"), input, weights); mm->add_instruction(migraphx::make_op("convolution_backwards"), input, weights);
return p; return p;
} }
}; };
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_deconv_1d : verify_program<test_deconv_1d> struct test_convolution_backwards_1d : verify_program<test_convolution_backwards_1d>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -38,7 +38,7 @@ struct test_deconv_1d : verify_program<test_deconv_1d> ...@@ -38,7 +38,7 @@ struct test_deconv_1d : verify_program<test_deconv_1d>
auto weights = auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3}}); mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("deconvolution", migraphx::make_op("convolution_backwards",
{{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input, input,
weights); weights);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment