Commit 13d14c66 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into dyn_resize_gather

parents f4e7d9d9 d1abf06f
...@@ -21,56 +21,42 @@ ...@@ -21,56 +21,42 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iostream> #include <migraphx/apply_alpha_beta.hpp>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp>
template <class T> template <class T>
void dot_2d_test() void dot_2d_test()
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632, -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814}; -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01, std::vector<T> b = {6.09568541e-01,
-6.10527007e-01, -6.10527007e-01,
3.66646462e-01, 3.66646462e-01,
1.18951101e-01, 1.18951101e-01,
5.58777432e-01, 5.58777432e-01,
-3.21296298e-01, -3.21296298e-01,
-5.95997198e-01, -5.95997198e-01,
-5.01425721e-01, -5.01425721e-01,
-2.84606807e-01, -2.84606807e-01,
-5.73673557e-01, -5.73673557e-01,
-8.99430260e-01, -8.99430260e-01,
-4.25103093e-01, -4.25103093e-01,
1.53027987e+00, 1.53027987e+00,
-3.81407415e-04, -3.81407415e-04,
-3.29650255e-01}; -3.29650255e-01};
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}}; migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {4, 5}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}}; migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {5, 3}};
...@@ -80,7 +66,20 @@ void dot_2d_test() ...@@ -80,7 +66,20 @@ void dot_2d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(c, results_vector)); std::vector<T> gold = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify::verify_range_with_tolerance(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
} }
TEST_CASE_REGISTER(dot_2d_test<float>) TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>) TEST_CASE_REGISTER(dot_2d_test<double>)
...@@ -90,38 +89,38 @@ void dot_4d_test() ...@@ -90,38 +89,38 @@ void dot_4d_test()
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885, std::vector<T> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027, 1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632, -0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814}; -1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01, std::vector<T> b = {6.09568541e-01,
-6.10527007e-01, -6.10527007e-01,
3.66646462e-01, 3.66646462e-01,
1.18951101e-01, 1.18951101e-01,
5.58777432e-01, 5.58777432e-01,
-3.21296298e-01, -3.21296298e-01,
-5.95997198e-01, -5.95997198e-01,
-5.01425721e-01, -5.01425721e-01,
-2.84606807e-01, -2.84606807e-01,
-5.73673557e-01, -5.73673557e-01,
-8.99430260e-01, -8.99430260e-01,
-4.25103093e-01, -4.25103093e-01,
1.53027987e+00, 1.53027987e+00,
-3.81407415e-04, -3.81407415e-04,
-3.29650255e-01}; -3.29650255e-01};
std::vector<float> c = {-1.56327541e+00, std::vector<T> gold = {-1.56327541e+00,
-7.09570140e-01, -7.09570140e-01,
-5.37424982e-01, -5.37424982e-01,
-2.22994831e-01, -2.22994831e-01,
-2.15586437e+00, -2.15586437e+00,
2.09177941e-03, 2.09177941e-03,
-1.47279677e+00, -1.47279677e+00,
2.02627040e-01, 2.02627040e-01,
-6.04527691e-01, -6.04527691e-01,
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}}; migraphx::shape a_shape{migraphx::shape::get_type<T>{}, {1, 1, 4, 5}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::get_type<T>{}, {1, 1, 5, 3}};
...@@ -131,8 +130,10 @@ void dot_4d_test() ...@@ -131,8 +130,10 @@ void dot_4d_test()
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<T> results_vector; std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_range_with_tolerance(
results_vector, migraphx::verify::expected{gold}, migraphx::verify::tolerance{9e-6}));
} }
TEST_CASE_REGISTER(dot_4d_test<float>) TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>) TEST_CASE_REGISTER(dot_4d_test<double>)
...@@ -169,24 +170,24 @@ TEST_CASE(dot_3D_test) ...@@ -169,24 +170,24 @@ TEST_CASE(dot_3D_test)
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394, std::vector<float> gold = {0.18208394,
-0.49276402, -0.49276402,
0.87189133, 0.87189133,
0.75150114, 0.75150114,
-0.55909610, -0.55909610,
1.00521735, 1.00521735,
-0.95536130, -0.95536130,
2.27996211, 2.27996211,
0.06239879, 0.06239879,
0.74700068, 0.74700068,
-0.01570983, -0.01570983,
-0.85920856, -0.85920856,
-0.59070835, -0.59070835,
-1.70729902, -1.70729902,
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_3D_C_test0) TEST_CASE(dot_3D_C_test0)
...@@ -245,24 +246,24 @@ TEST_CASE(dot_3D_C_test0) ...@@ -245,24 +246,24 @@ TEST_CASE(dot_3D_C_test0)
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394, std::vector<float> gold = {0.18208394,
-0.49276402, -0.49276402,
0.87189133, 0.87189133,
0.75150114, 0.75150114,
-0.55909610, -0.55909610,
1.00521735, 1.00521735,
-0.95536130, -0.95536130,
2.27996211, 2.27996211,
0.06239879, 0.06239879,
0.74700068, 0.74700068,
-0.01570983, -0.01570983,
-0.85920856, -0.85920856,
-0.59070835, -0.59070835,
-1.70729902, -1.70729902,
0.40245487, 0.40245487,
1.80182751}; 1.80182751};
EXPECT(migraphx::verify::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_3D_C_test1) TEST_CASE(dot_3D_C_test1)
...@@ -312,16 +313,16 @@ TEST_CASE(dot_3D_C_test1) ...@@ -312,16 +313,16 @@ TEST_CASE(dot_3D_C_test1)
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.18208394, std::vector<float> gold = {0.18208394,
-0.49276402, -0.49276402,
0.87189133, 0.87189133,
0.75150114, 0.75150114,
-0.55909610, -0.55909610,
1.00521735, 1.00521735,
-0.95536130, -0.95536130,
2.27996211}; 2.27996211};
EXPECT(migraphx::verify::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_4D_test1) TEST_CASE(dot_4D_test1)
...@@ -354,13 +355,13 @@ TEST_CASE(dot_4D_test1) ...@@ -354,13 +355,13 @@ TEST_CASE(dot_4D_test1)
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170, std::vector<float> gold = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939, -0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406, 0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164, -0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906}; 3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_4D_alpha_beta_test) TEST_CASE(dot_4D_alpha_beta_test)
...@@ -408,13 +409,13 @@ TEST_CASE(dot_4D_alpha_beta_test) ...@@ -408,13 +409,13 @@ TEST_CASE(dot_4D_alpha_beta_test)
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586, std::vector<float> gold = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650, 0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252, -0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_4D_alpha_beta_C_test) TEST_CASE(dot_4D_alpha_beta_C_test)
...@@ -460,13 +461,13 @@ TEST_CASE(dot_4D_alpha_beta_C_test) ...@@ -460,13 +461,13 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586, std::vector<float> gold = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650, 0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252, -0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824, -0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845}; 0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify::verify_range(m, m_res)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_2D_C_test0) TEST_CASE(dot_2D_C_test0)
...@@ -522,7 +523,7 @@ TEST_CASE(dot_2D_C_test0) ...@@ -522,7 +523,7 @@ TEST_CASE(dot_2D_C_test0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_vv_inner_product_1) TEST_CASE(dot_vv_inner_product_1)
...@@ -558,7 +559,7 @@ TEST_CASE(dot_vv_inner_product_1) ...@@ -558,7 +559,7 @@ TEST_CASE(dot_vv_inner_product_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_vv_inner_product_2) TEST_CASE(dot_vv_inner_product_2)
...@@ -596,7 +597,7 @@ TEST_CASE(dot_vv_inner_product_2) ...@@ -596,7 +597,7 @@ TEST_CASE(dot_vv_inner_product_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_vm_1) TEST_CASE(dot_vm_1)
...@@ -631,7 +632,7 @@ TEST_CASE(dot_vm_1) ...@@ -631,7 +632,7 @@ TEST_CASE(dot_vm_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_vm_2) TEST_CASE(dot_vm_2)
...@@ -668,7 +669,7 @@ TEST_CASE(dot_vm_2) ...@@ -668,7 +669,7 @@ TEST_CASE(dot_vm_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_vm_3) TEST_CASE(dot_vm_3)
...@@ -714,7 +715,7 @@ TEST_CASE(dot_vm_3) ...@@ -714,7 +715,7 @@ TEST_CASE(dot_vm_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_vm_4) TEST_CASE(dot_vm_4)
...@@ -761,7 +762,7 @@ TEST_CASE(dot_vm_4) ...@@ -761,7 +762,7 @@ TEST_CASE(dot_vm_4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mv_1) TEST_CASE(dot_mv_1)
...@@ -798,7 +799,7 @@ TEST_CASE(dot_mv_1) ...@@ -798,7 +799,7 @@ TEST_CASE(dot_mv_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mv_2) TEST_CASE(dot_mv_2)
...@@ -837,7 +838,7 @@ TEST_CASE(dot_mv_2) ...@@ -837,7 +838,7 @@ TEST_CASE(dot_mv_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mv_3) TEST_CASE(dot_mv_3)
...@@ -881,7 +882,7 @@ TEST_CASE(dot_mv_3) ...@@ -881,7 +882,7 @@ TEST_CASE(dot_mv_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mm1_1) TEST_CASE(dot_mm1_1)
...@@ -932,7 +933,7 @@ TEST_CASE(dot_mm1_1) ...@@ -932,7 +933,7 @@ TEST_CASE(dot_mm1_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mm1_2) TEST_CASE(dot_mm1_2)
...@@ -985,7 +986,7 @@ TEST_CASE(dot_mm1_2) ...@@ -985,7 +986,7 @@ TEST_CASE(dot_mm1_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mm2_1) TEST_CASE(dot_mm2_1)
...@@ -1027,7 +1028,7 @@ TEST_CASE(dot_mm2_1) ...@@ -1027,7 +1028,7 @@ TEST_CASE(dot_mm2_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mm2_2) TEST_CASE(dot_mm2_2)
...@@ -1070,7 +1071,7 @@ TEST_CASE(dot_mm2_2) ...@@ -1070,7 +1071,7 @@ TEST_CASE(dot_mm2_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mm2_3) TEST_CASE(dot_mm2_3)
...@@ -1119,7 +1120,7 @@ TEST_CASE(dot_mm2_3) ...@@ -1119,7 +1120,7 @@ TEST_CASE(dot_mm2_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_mm2_4) TEST_CASE(dot_mm2_4)
...@@ -1165,7 +1166,7 @@ TEST_CASE(dot_mm2_4) ...@@ -1165,7 +1166,7 @@ TEST_CASE(dot_mm2_4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(dot_dyn_2D_test) TEST_CASE(dot_dyn_2D_test)
...@@ -1205,19 +1206,19 @@ TEST_CASE(dot_dyn_2D_test) ...@@ -1205,19 +1206,19 @@ TEST_CASE(dot_dyn_2D_test)
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00, std::vector<float> gold = {-1.56327541e+00,
-7.09570140e-01, -7.09570140e-01,
-5.37424982e-01, -5.37424982e-01,
-2.22994831e-01, -2.22994831e-01,
-2.15586437e+00, -2.15586437e+00,
2.09177941e-03, 2.09177941e-03,
-1.47279677e+00, -1.47279677e+00,
2.02627040e-01, 2.02627040e-01,
-6.04527691e-01, -6.04527691e-01,
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(dot_dyn_4D_test) TEST_CASE(dot_dyn_4D_test)
...@@ -1259,19 +1260,19 @@ TEST_CASE(dot_dyn_4D_test) ...@@ -1259,19 +1260,19 @@ TEST_CASE(dot_dyn_4D_test)
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00, std::vector<float> gold = {-1.56327541e+00,
-7.09570140e-01, -7.09570140e-01,
-5.37424982e-01, -5.37424982e-01,
-2.22994831e-01, -2.22994831e-01,
-2.15586437e+00, -2.15586437e+00,
2.09177941e-03, 2.09177941e-03,
-1.47279677e+00, -1.47279677e+00,
2.02627040e-01, 2.02627040e-01,
-6.04527691e-01, -6.04527691e-01,
-1.29885596e+00, -1.29885596e+00,
2.16294914e+00, 2.16294914e+00,
-1.48101497e-01}; -1.48101497e-01};
EXPECT(migraphx::verify::verify_range(c, results_vector)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(quant_dot_2args_multi4_1) TEST_CASE(quant_dot_2args_multi4_1)
...@@ -1298,7 +1299,7 @@ TEST_CASE(quant_dot_2args_multi4_1) ...@@ -1298,7 +1299,7 @@ TEST_CASE(quant_dot_2args_multi4_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_multi4_2) TEST_CASE(quant_dot_2args_multi4_2)
...@@ -1326,7 +1327,7 @@ TEST_CASE(quant_dot_2args_multi4_2) ...@@ -1326,7 +1327,7 @@ TEST_CASE(quant_dot_2args_multi4_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_multi4_3) TEST_CASE(quant_dot_2args_multi4_3)
...@@ -1354,7 +1355,7 @@ TEST_CASE(quant_dot_2args_multi4_3) ...@@ -1354,7 +1355,7 @@ TEST_CASE(quant_dot_2args_multi4_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_multi4_4) TEST_CASE(quant_dot_2args_multi4_4)
...@@ -1383,7 +1384,7 @@ TEST_CASE(quant_dot_2args_multi4_4) ...@@ -1383,7 +1384,7 @@ TEST_CASE(quant_dot_2args_multi4_4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_general_1) TEST_CASE(quant_dot_2args_general_1)
...@@ -1408,7 +1409,7 @@ TEST_CASE(quant_dot_2args_general_1) ...@@ -1408,7 +1409,7 @@ TEST_CASE(quant_dot_2args_general_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_general_2) TEST_CASE(quant_dot_2args_general_2)
...@@ -1435,7 +1436,7 @@ TEST_CASE(quant_dot_2args_general_2) ...@@ -1435,7 +1436,7 @@ TEST_CASE(quant_dot_2args_general_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_general_3) TEST_CASE(quant_dot_2args_general_3)
...@@ -1463,7 +1464,7 @@ TEST_CASE(quant_dot_2args_general_3) ...@@ -1463,7 +1464,7 @@ TEST_CASE(quant_dot_2args_general_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_2args_general_4) TEST_CASE(quant_dot_2args_general_4)
...@@ -1491,7 +1492,7 @@ TEST_CASE(quant_dot_2args_general_4) ...@@ -1491,7 +1492,7 @@ TEST_CASE(quant_dot_2args_general_4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_general_1) TEST_CASE(quant_dot_3args_general_1)
...@@ -1521,7 +1522,7 @@ TEST_CASE(quant_dot_3args_general_1) ...@@ -1521,7 +1522,7 @@ TEST_CASE(quant_dot_3args_general_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_general_2) TEST_CASE(quant_dot_3args_general_2)
...@@ -1549,7 +1550,7 @@ TEST_CASE(quant_dot_3args_general_2) ...@@ -1549,7 +1550,7 @@ TEST_CASE(quant_dot_3args_general_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_general_3) TEST_CASE(quant_dot_3args_general_3)
...@@ -1580,7 +1581,7 @@ TEST_CASE(quant_dot_3args_general_3) ...@@ -1580,7 +1581,7 @@ TEST_CASE(quant_dot_3args_general_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_general_4) TEST_CASE(quant_dot_3args_general_4)
...@@ -1611,7 +1612,7 @@ TEST_CASE(quant_dot_3args_general_4) ...@@ -1611,7 +1612,7 @@ TEST_CASE(quant_dot_3args_general_4)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_general_5) TEST_CASE(quant_dot_3args_general_5)
...@@ -1643,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general_5) ...@@ -1643,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general_5)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_batch_1) TEST_CASE(quant_dot_3args_batch_1)
...@@ -1677,7 +1678,7 @@ TEST_CASE(quant_dot_3args_batch_1) ...@@ -1677,7 +1678,7 @@ TEST_CASE(quant_dot_3args_batch_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
TEST_CASE(quant_dot_3args_batch_2) TEST_CASE(quant_dot_3args_batch_2)
...@@ -1716,5 +1717,5 @@ TEST_CASE(quant_dot_3args_batch_2) ...@@ -1716,5 +1717,5 @@ TEST_CASE(quant_dot_3args_batch_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> m; std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(m, gold)); EXPECT(migraphx::verify::verify_rms_range(m, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,7 +45,7 @@ TEST_CASE(elu_test) ...@@ -45,7 +45,7 @@ TEST_CASE(elu_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)}; std::vector<float> gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(elu_dyn_test) TEST_CASE(elu_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(elu_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(elu_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)}; std::vector<float> gold{elu(alpha, -1), elu(alpha, 2), elu(alpha, -3), elu(alpha, 4)};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,7 +45,7 @@ TEST_CASE(erf_test) ...@@ -45,7 +45,7 @@ TEST_CASE(erf_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(erf_dyn_test) TEST_CASE(erf_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(erf_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(erf_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return erff(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,7 +45,7 @@ TEST_CASE(exp_test) ...@@ -45,7 +45,7 @@ TEST_CASE(exp_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(exp_dyn_test) TEST_CASE(exp_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(exp_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(exp_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return expf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(fill_static_int)
{
// Note this case can be simplified to a literal
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape lit_shape{migraphx::shape::int64_type, {1}, {0}};
std::vector<int64_t> lit_data = {3};
auto l = mm->add_literal(migraphx::literal{lit_shape, lit_data});
migraphx::shape data_shape{migraphx::shape::int64_type, {3, 4, 4}};
auto input = mm->add_parameter("x", data_shape);
mm->add_instruction(migraphx::make_op("fill"), l, input);
p.compile(migraphx::make_target("ref"));
std::vector<int64_t> input_data(48);
migraphx::parameter_map params;
params["x"] = migraphx::argument(data_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<int64_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold(48, 3);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(fill_dyn_float)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape lit_shape{migraphx::shape::float_type, {1}, {0}};
std::vector<float> lit_data = {7.36};
auto l = mm->add_literal(migraphx::literal{lit_shape, lit_data});
migraphx::shape data_shape{migraphx::shape::float_type,
{{1, 4}, {4, 8, {4, 6, 8}}, {4, 8, {4, 6, 8}}}};
auto input = mm->add_parameter("x", data_shape);
mm->add_instruction(migraphx::make_op("fill"), l, input);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data(72);
migraphx::parameter_map params;
migraphx::shape static_shape = {migraphx::shape::float_type, {2, 6, 6}};
params["x"] = migraphx::argument(static_shape, input_data.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(72, 7.36);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
TEST_CASE(fill_var_default_value)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape dv_shape{migraphx::shape::int64_type, {1}, {0}};
auto dv = mm->add_parameter("dv", dv_shape);
migraphx::shape data_shape{migraphx::shape::int64_type, {3, 4, 4}};
auto input = mm->add_parameter("x", data_shape);
mm->add_instruction(migraphx::make_op("fill"), dv, input);
p.compile(migraphx::make_target("ref"));
std::vector<int64_t> dv_data = {2};
std::vector<int64_t> input_data(48);
migraphx::parameter_map params;
params["x"] = migraphx::argument(data_shape, input_data.data());
params["dv"] = migraphx::argument(dv_shape, dv_data.data());
auto result = p.eval(params).back();
std::vector<int64_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold(48, 2);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
}
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,7 +45,7 @@ TEST_CASE(floor_test) ...@@ -45,7 +45,7 @@ TEST_CASE(floor_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(floor_dyn_test) TEST_CASE(floor_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(floor_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(floor_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return floor(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,7 +45,7 @@ TEST_CASE(fmod_test) ...@@ -45,7 +45,7 @@ TEST_CASE(fmod_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1, 0, -3}; std::vector<float> gold{-1, 0, -3};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(fmod_dyn_test) TEST_CASE(fmod_dyn_test)
...@@ -73,7 +73,7 @@ TEST_CASE(fmod_dyn_test) ...@@ -73,7 +73,7 @@ TEST_CASE(fmod_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1, 0, -3}; std::vector<float> gold{-1, 0, -3};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(fmod_float_test) TEST_CASE(fmod_float_test)
...@@ -92,5 +92,5 @@ TEST_CASE(fmod_float_test) ...@@ -92,5 +92,5 @@ TEST_CASE(fmod_float_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1.2f, 0.5f, -3.3f}; std::vector<float> gold{-1.2f, 0.5f, -3.3f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -52,7 +52,7 @@ TEST_CASE(gather_non_std_test) ...@@ -52,7 +52,7 @@ TEST_CASE(gather_non_std_test)
0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f, 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f, 0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
std::vector<float> res_data; std::vector<float> res_data;
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
} }
...@@ -75,7 +75,7 @@ TEST_CASE(gather_test_1) ...@@ -75,7 +75,7 @@ TEST_CASE(gather_test_1)
std::vector<float> res_data(4 * 5); std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_test_2) TEST_CASE(gather_test_2)
...@@ -97,7 +97,7 @@ TEST_CASE(gather_test_2) ...@@ -97,7 +97,7 @@ TEST_CASE(gather_test_2)
std::vector<float> res_data(4 * 5); std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_test_3) TEST_CASE(gather_test_3)
...@@ -119,7 +119,7 @@ TEST_CASE(gather_test_3) ...@@ -119,7 +119,7 @@ TEST_CASE(gather_test_3)
std::vector<float> res_data(4 * 5); std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f}; std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_test_4) TEST_CASE(gather_test_4)
...@@ -141,7 +141,7 @@ TEST_CASE(gather_test_4) ...@@ -141,7 +141,7 @@ TEST_CASE(gather_test_4)
std::vector<float> res_data(4 * 5); std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f}; std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_test_5) TEST_CASE(gather_test_5)
...@@ -164,7 +164,7 @@ TEST_CASE(gather_test_5) ...@@ -164,7 +164,7 @@ TEST_CASE(gather_test_5)
std::vector<float> res_data{}; std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f}; std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_test_6) TEST_CASE(gather_test_6)
...@@ -187,7 +187,7 @@ TEST_CASE(gather_test_6) ...@@ -187,7 +187,7 @@ TEST_CASE(gather_test_6)
std::vector<float> res_data{}; std::vector<float> res_data{};
std::vector<float> golden = {0.5f, 3.5f, 6.5f}; std::vector<float> golden = {0.5f, 3.5f, 6.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_test_7) TEST_CASE(gather_test_7)
...@@ -210,7 +210,7 @@ TEST_CASE(gather_test_7) ...@@ -210,7 +210,7 @@ TEST_CASE(gather_test_7)
std::vector<float> res_data{}; std::vector<float> res_data{};
std::vector<float> golden = {0.5f}; std::vector<float> golden = {0.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
} }
TEST_CASE(gather_dyn_test0) TEST_CASE(gather_dyn_test0)
...@@ -243,7 +243,7 @@ TEST_CASE(gather_dyn_test0) ...@@ -243,7 +243,7 @@ TEST_CASE(gather_dyn_test0)
std::vector<int> gold = {1, 2, 4, 5}; std::vector<int> gold = {1, 2, 4, 5};
std::vector<int> results_vector(2 * 1 * 2); std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
migraphx::shape sfinal{migraphx::shape::int32_type, {2, 1, 2}}; migraphx::shape sfinal{migraphx::shape::int32_type, {2, 1, 2}};
EXPECT(result.get_shape() == sfinal); EXPECT(result.get_shape() == sfinal);
} }
...@@ -280,7 +280,7 @@ TEST_CASE(gather_dyn_test1) ...@@ -280,7 +280,7 @@ TEST_CASE(gather_dyn_test1)
std::vector<int> results_vector(1 * 2 * 4); std::vector<int> results_vector(1 * 2 * 4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
migraphx::shape sfinal{migraphx::shape::int32_type, {1, 2, 4}}; migraphx::shape sfinal{migraphx::shape::int32_type, {1, 2, 4}};
EXPECT(result.get_shape() == sfinal); EXPECT(result.get_shape() == sfinal);
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -52,7 +52,7 @@ TEST_CASE(gathernd_test_1) ...@@ -52,7 +52,7 @@ TEST_CASE(gathernd_test_1)
std::vector<float> gold{0, 3}; std::vector<float> gold{0, 3};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_test_2) TEST_CASE(gathernd_test_2)
...@@ -77,7 +77,7 @@ TEST_CASE(gathernd_test_2) ...@@ -77,7 +77,7 @@ TEST_CASE(gathernd_test_2)
std::vector<float> gold{2, 3, 0, 1}; std::vector<float> gold{2, 3, 0, 1};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_test_3) TEST_CASE(gathernd_test_3)
...@@ -102,7 +102,7 @@ TEST_CASE(gathernd_test_3) ...@@ -102,7 +102,7 @@ TEST_CASE(gathernd_test_3)
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5}; std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_test_4) TEST_CASE(gathernd_test_4)
...@@ -128,7 +128,7 @@ TEST_CASE(gathernd_test_4) ...@@ -128,7 +128,7 @@ TEST_CASE(gathernd_test_4)
std::vector<float> gold{0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23}; std::vector<float> gold{0, 1, 2, 3, 4, 5, 18, 19, 20, 21, 22, 23};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_test_5) TEST_CASE(gathernd_test_5)
...@@ -155,7 +155,7 @@ TEST_CASE(gathernd_test_5) ...@@ -155,7 +155,7 @@ TEST_CASE(gathernd_test_5)
std::vector<float> gold{0, 4, 8, 11, 13, 15}; std::vector<float> gold{0, 4, 8, 11, 13, 15};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_test_6) TEST_CASE(gathernd_test_6)
...@@ -215,7 +215,7 @@ TEST_CASE(gathernd_dynamic0) ...@@ -215,7 +215,7 @@ TEST_CASE(gathernd_dynamic0)
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5}; std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_dynamic1) TEST_CASE(gathernd_dynamic1)
...@@ -251,7 +251,7 @@ TEST_CASE(gathernd_dynamic1) ...@@ -251,7 +251,7 @@ TEST_CASE(gathernd_dynamic1)
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5}; std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_dynamic2) TEST_CASE(gathernd_dynamic2)
...@@ -287,7 +287,7 @@ TEST_CASE(gathernd_dynamic2) ...@@ -287,7 +287,7 @@ TEST_CASE(gathernd_dynamic2)
std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5}; std::vector<float> gold{3, 4, 5, 0, 1, 2, 0, 1, 2, 3, 4, 5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_dynamic3) TEST_CASE(gathernd_dynamic3)
...@@ -323,7 +323,7 @@ TEST_CASE(gathernd_dynamic3) ...@@ -323,7 +323,7 @@ TEST_CASE(gathernd_dynamic3)
std::vector<float> res_data{}; std::vector<float> res_data{};
std::vector<float> gold{1, 0, 3, 4}; std::vector<float> gold{1, 0, 3, 4};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_dynamic4) TEST_CASE(gathernd_dynamic4)
...@@ -358,7 +358,7 @@ TEST_CASE(gathernd_dynamic4) ...@@ -358,7 +358,7 @@ TEST_CASE(gathernd_dynamic4)
std::vector<float> res_data{}; std::vector<float> res_data{};
std::vector<float> gold{5}; std::vector<float> gold{5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_negative_index_test_1) TEST_CASE(gathernd_negative_index_test_1)
...@@ -383,7 +383,7 @@ TEST_CASE(gathernd_negative_index_test_1) ...@@ -383,7 +383,7 @@ TEST_CASE(gathernd_negative_index_test_1)
std::vector<float> gold{2, 3, 0, 1}; std::vector<float> gold{2, 3, 0, 1};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res_data, gold)); EXPECT(migraphx::verify::verify_rms_range(res_data, gold));
} }
TEST_CASE(gathernd_negative_index_test_2) TEST_CASE(gathernd_negative_index_test_2)
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -40,14 +40,14 @@ TEST_CASE(im2col_3x3_no_pad_identity_test) ...@@ -40,14 +40,14 @@ TEST_CASE(im2col_3x3_no_pad_identity_test)
std::size_t channels = 1; std::size_t channels = 1;
std::vector<int32_t> weights(channels * f[0] * f[1]); std::vector<int32_t> weights(channels * f[0] * f[1]);
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> gold(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(gold.begin(), gold.end(), 0);
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = mm->add_literal(migraphx::literal{s_image, input}); auto l_image = mm->add_literal(migraphx::literal{s_image, gold});
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("im2col", migraphx::make_op("im2col",
...@@ -61,7 +61,7 @@ TEST_CASE(im2col_3x3_no_pad_identity_test) ...@@ -61,7 +61,7 @@ TEST_CASE(im2col_3x3_no_pad_identity_test)
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, input)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(im2col_3x3_no_pad_test) TEST_CASE(im2col_3x3_no_pad_test)
...@@ -91,14 +91,14 @@ TEST_CASE(im2col_3x3_no_pad_test) ...@@ -91,14 +91,14 @@ TEST_CASE(im2col_3x3_no_pad_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int> correct = {0, 1, 2, 4, 5, 6, 8, 9, 10, 1, 2, 3, 5, 6, 7, 9, 10, 11, std::vector<int> gold = {0, 1, 2, 4, 5, 6, 8, 9, 10, 1, 2, 3, 5, 6, 7, 9, 10, 11,
4, 5, 6, 8, 9, 10, 12, 13, 14, 5, 6, 7, 9, 10, 11, 13, 14, 15}; 4, 5, 6, 8, 9, 10, 12, 13, 14, 5, 6, 7, 9, 10, 11, 13, 14, 15};
std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1;
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(im2col_3x3_stride_2_no_pad_test) TEST_CASE(im2col_3x3_stride_2_no_pad_test)
...@@ -128,15 +128,15 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test) ...@@ -128,15 +128,15 @@ TEST_CASE(im2col_3x3_stride_2_no_pad_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int> correct = {0, 1, 2, 6, 7, 8, 12, 13, 14, 2, 3, 4, std::vector<int> gold = {0, 1, 2, 6, 7, 8, 12, 13, 14, 2, 3, 4,
8, 9, 10, 14, 15, 16, 12, 13, 14, 18, 19, 20, 8, 9, 10, 14, 15, 16, 12, 13, 14, 18, 19, 20,
24, 25, 26, 14, 15, 16, 20, 21, 22, 26, 27, 28}; 24, 25, 26, 14, 15, 16, 20, 21, 22, 26, 27, 28};
std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1;
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(im2col_3x3_with_channels_identity_test) TEST_CASE(im2col_3x3_with_channels_identity_test)
...@@ -149,14 +149,14 @@ TEST_CASE(im2col_3x3_with_channels_identity_test) ...@@ -149,14 +149,14 @@ TEST_CASE(im2col_3x3_with_channels_identity_test)
std::size_t channels = 2; std::size_t channels = 2;
std::vector<int32_t> weights(channels * f[0] * f[1]); std::vector<int32_t> weights(channels * f[0] * f[1]);
std::vector<int32_t> input(channels * size[0] * size[1]); std::vector<int32_t> gold(channels * size[0] * size[1]);
std::iota(input.begin(), input.end(), 0); std::iota(gold.begin(), gold.end(), 0);
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}}; migraphx::shape s_image{migraphx::shape::int32_type, {1, channels, size[0], size[1]}};
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}}; migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_image = mm->add_literal(migraphx::literal{s_image, input}); auto l_image = mm->add_literal(migraphx::literal{s_image, gold});
auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights}); auto l_weights = mm->add_literal(migraphx::literal{s_weights, weights});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("im2col", migraphx::make_op("im2col",
...@@ -170,7 +170,7 @@ TEST_CASE(im2col_3x3_with_channels_identity_test) ...@@ -170,7 +170,7 @@ TEST_CASE(im2col_3x3_with_channels_identity_test)
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, input)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(im2col_3x3_with_padding_test) TEST_CASE(im2col_3x3_with_padding_test)
...@@ -200,12 +200,12 @@ TEST_CASE(im2col_3x3_with_padding_test) ...@@ -200,12 +200,12 @@ TEST_CASE(im2col_3x3_with_padding_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int> correct = {0, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, std::vector<int> gold = {0, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0,
0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0}; 0, 0, 1, 0, 2, 3, 0, 0, 0, 0, 1, 0, 2, 3, 0, 0, 0, 0};
std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1; std::size_t col_height = (size[0] - f[0] + 2 * padding[0]) / stride[0] + 1;
std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1; std::size_t col_width = (size[1] - f[1] + 2 * padding[1]) / stride[1] + 1;
std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width); std::vector<float> results_vector(channels * f[0] * f[1] * col_height * col_width);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,8 +45,8 @@ TEST_CASE(isnan_test) ...@@ -45,8 +45,8 @@ TEST_CASE(isnan_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 1, 1, 0, 0}; std::vector<float> gold = {0, 0, 1, 1, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
// half test // half test
...@@ -64,8 +64,8 @@ TEST_CASE(isnan_test) ...@@ -64,8 +64,8 @@ TEST_CASE(isnan_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 1, 1, 0, 0}; std::vector<float> gold = {0, 0, 1, 1, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -86,6 +86,6 @@ TEST_CASE(isnan_dyn_test) ...@@ -86,6 +86,6 @@ TEST_CASE(isnan_dyn_test)
auto result = p.eval(params0).back(); auto result = p.eval(params0).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 1, 1, 0, 0}; std::vector<float> gold = {0, 0, 1, 1, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -42,5 +42,5 @@ TEST_CASE(leaky_relu_test) ...@@ -42,5 +42,5 @@ TEST_CASE(leaky_relu_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.01f, 0.f, 1.f}; std::vector<float> gold = {-0.01f, 0.f, 1.f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -45,7 +45,7 @@ TEST_CASE(log_test) ...@@ -45,7 +45,7 @@ TEST_CASE(log_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(log_dyn_test) TEST_CASE(log_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(log_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(log_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return logf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -49,7 +49,7 @@ TEST_CASE(logical_and_test) ...@@ -49,7 +49,7 @@ TEST_CASE(logical_and_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 and n2; return n1 and n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(logical_and_dyn_test) TEST_CASE(logical_and_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(logical_and_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(logical_and_dyn_test)
right_data.begin(), right_data.begin(),
gold.begin(), gold.begin(),
[](bool n1, bool n2) -> bool { return n1 and n2; }); [](bool n1, bool n2) -> bool { return n1 and n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/program.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
...@@ -49,7 +49,7 @@ TEST_CASE(logical_or_test) ...@@ -49,7 +49,7 @@ TEST_CASE(logical_or_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 or n2; return n1 or n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(logical_or_dyn_test) TEST_CASE(logical_or_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(logical_or_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(logical_or_dyn_test)
right_data.begin(), right_data.begin(),
gold.begin(), gold.begin(),
[](bool n1, bool n2) -> bool { return n1 or n2; }); [](bool n1, bool n2) -> bool { return n1 or n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment