Unverified Commit 69d8d789 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add options to set tolerances inside MIGraphX driver (#2213)

MIGraphX verification by default uses normalized RMS error as the basis for the verification.  This change adds some logic to allow migraphx to do "np.allclose" type of elementwise verification using atol and rtol.

Commit also includes changes to consistently pass "gold" or "expected" results as the second argument for "verify_range()" calls.  Default RMS tolerance inside driver is set to 0.001 which IMO is high for FP32 compared to what we had earlier. Need better defaults
parent e12032fb
...@@ -64,7 +64,7 @@ TEST_CASE(scatter_ax0_test) ...@@ -64,7 +64,7 @@ TEST_CASE(scatter_ax0_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -78,7 +78,7 @@ TEST_CASE(scatter_ax_neg_test) ...@@ -78,7 +78,7 @@ TEST_CASE(scatter_ax_neg_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -91,7 +91,7 @@ TEST_CASE(scatter_ax1_test) ...@@ -91,7 +91,7 @@ TEST_CASE(scatter_ax1_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0}; std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -128,7 +128,7 @@ TEST_CASE(scatter_reduction1_test) ...@@ -128,7 +128,7 @@ TEST_CASE(scatter_reduction1_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none = {1.0, 1.1, 3.0, 2.1, 5.0}; std::vector<float> gold_none = {1.0, 1.1, 3.0, 2.1, 5.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_none)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_none));
} }
} }
...@@ -142,7 +142,7 @@ TEST_CASE(scatter_reduction2_test) ...@@ -142,7 +142,7 @@ TEST_CASE(scatter_reduction2_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0}; std::vector<float> gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_mul)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_mul));
} }
} }
TEST_CASE(scatter_reduction3_test) TEST_CASE(scatter_reduction3_test)
...@@ -155,7 +155,7 @@ TEST_CASE(scatter_reduction3_test) ...@@ -155,7 +155,7 @@ TEST_CASE(scatter_reduction3_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_add = {1.0, 3.1, 3.0, 6.1, 5.0}; std::vector<float> gold_add = {1.0, 3.1, 3.0, 6.1, 5.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_add)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_add));
} }
} }
...@@ -184,7 +184,7 @@ TEST_CASE(scatter_reduction_3x3_test) ...@@ -184,7 +184,7 @@ TEST_CASE(scatter_reduction_3x3_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0}; std::vector<float> gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_a2)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_a2));
} }
} }
...@@ -221,7 +221,7 @@ TEST_CASE(scatter_reduction_3x3_xpose1_test) ...@@ -221,7 +221,7 @@ TEST_CASE(scatter_reduction_3x3_xpose1_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0}; std::vector<float> gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_none2)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_none2));
} }
} }
...@@ -236,7 +236,7 @@ TEST_CASE(scatter_reduction_3x3_xpose2_test) ...@@ -236,7 +236,7 @@ TEST_CASE(scatter_reduction_3x3_xpose2_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0}; std::vector<float> gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_a3)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_a3));
} }
} }
...@@ -250,6 +250,6 @@ TEST_CASE(scatter_reduction_3x3_xpose3_test) ...@@ -250,6 +250,6 @@ TEST_CASE(scatter_reduction_3x3_xpose3_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0}; std::vector<float> gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_mul2)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_mul2));
} }
} }
...@@ -57,7 +57,7 @@ TEST_CASE(scatternd_add_reduction_test) ...@@ -57,7 +57,7 @@ TEST_CASE(scatternd_add_reduction_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 3, 3, 5, 6, 6, 7, 9}; std::vector<float> gold{1, 3, 3, 5, 6, 6, 7, 9};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_reduction_dyn_test) TEST_CASE(scatternd_reduction_dyn_test)
...@@ -102,5 +102,5 @@ TEST_CASE(scatternd_reduction_dyn_test) ...@@ -102,5 +102,5 @@ TEST_CASE(scatternd_reduction_dyn_test)
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
9, 8, 7, 6, 6, 5, 4, 3, 4, 5, 6, 7, 9, 10, 11, 12, 9, 8, 7, 6, 6, 5, 4, 3, 4, 5, 6, 7, 9, 10, 11, 12,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8}; 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -57,5 +57,5 @@ TEST_CASE(scatternd_mul_reduction_test) ...@@ -57,5 +57,5 @@ TEST_CASE(scatternd_mul_reduction_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 22, 3, 40, 45, 6, 7, 96}; std::vector<float> gold{1, 22, 3, 40, 45, 6, 7, 96};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -56,7 +56,7 @@ TEST_CASE(scatternd_shapes_test_1) ...@@ -56,7 +56,7 @@ TEST_CASE(scatternd_shapes_test_1)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 11, 0, 10, 9, 0, 0, 12}; std::vector<float> gold{0, 11, 0, 10, 9, 0, 0, 12};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_shapes_test_2) TEST_CASE(scatternd_shapes_test_2)
...@@ -86,7 +86,7 @@ TEST_CASE(scatternd_shapes_test_2) ...@@ -86,7 +86,7 @@ TEST_CASE(scatternd_shapes_test_2)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 2, 4}; std::vector<float> gold{5, 6, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_shapes_test_3) TEST_CASE(scatternd_shapes_test_3)
...@@ -117,7 +117,7 @@ TEST_CASE(scatternd_shapes_test_3) ...@@ -117,7 +117,7 @@ TEST_CASE(scatternd_shapes_test_3)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10}; std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_1) TEST_CASE(scatternd_test_1)
...@@ -147,7 +147,7 @@ TEST_CASE(scatternd_test_1) ...@@ -147,7 +147,7 @@ TEST_CASE(scatternd_test_1)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 11, 3, 10, 9, 6, 7, 12}; std::vector<float> gold{1, 11, 3, 10, 9, 6, 7, 12};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_2) TEST_CASE(scatternd_test_2)
...@@ -177,7 +177,7 @@ TEST_CASE(scatternd_test_2) ...@@ -177,7 +177,7 @@ TEST_CASE(scatternd_test_2)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 3, 4}; std::vector<float> gold{5, 6, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_3) TEST_CASE(scatternd_test_3)
...@@ -207,7 +207,7 @@ TEST_CASE(scatternd_test_3) ...@@ -207,7 +207,7 @@ TEST_CASE(scatternd_test_3)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10}; std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_4) TEST_CASE(scatternd_test_4)
...@@ -242,7 +242,7 @@ TEST_CASE(scatternd_test_4) ...@@ -242,7 +242,7 @@ TEST_CASE(scatternd_test_4)
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8}; 4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_5) TEST_CASE(scatternd_test_5)
...@@ -273,5 +273,5 @@ TEST_CASE(scatternd_test_5) ...@@ -273,5 +273,5 @@ TEST_CASE(scatternd_test_5)
std::vector<float> gold(32, 0); std::vector<float> gold(32, 0);
std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin()); std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin());
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -74,7 +74,7 @@ TEST_CASE(select_module_add_test) ...@@ -74,7 +74,7 @@ TEST_CASE(select_module_add_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 14, 5, 10, 5, 14, 14, 2}; std::vector<float> gold{2, 14, 5, 10, 5, 14, 14, 2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(select_module_reduce_test0) TEST_CASE(select_module_reduce_test0)
...@@ -120,7 +120,7 @@ TEST_CASE(select_module_reduce_test0) ...@@ -120,7 +120,7 @@ TEST_CASE(select_module_reduce_test0)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4}; std::vector<float> gold{-5, 12, 7, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(select_module_reduce_test1) TEST_CASE(select_module_reduce_test1)
...@@ -166,7 +166,7 @@ TEST_CASE(select_module_reduce_test1) ...@@ -166,7 +166,7 @@ TEST_CASE(select_module_reduce_test1)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4, -5, 12, 7, 4}; std::vector<float> gold{-5, 12, 7, 4, -5, 12, 7, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(select_module_not_found_error) TEST_CASE(select_module_not_found_error)
......
...@@ -44,7 +44,7 @@ TEST_CASE(sigmoid_test) ...@@ -44,7 +44,7 @@ TEST_CASE(sigmoid_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)}; std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sigmoid_dyn_test) TEST_CASE(sigmoid_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(sigmoid_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(sigmoid_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)}; std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(sign_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sign_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sign_dyn_test) TEST_CASE(sign_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(sign_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(sign_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(sin_test) ...@@ -45,7 +45,7 @@ TEST_CASE(sin_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sin_dyn_test) TEST_CASE(sin_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(sin_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(sin_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(sinh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(sinh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sinh_dynamic_test) TEST_CASE(sinh_dynamic_test)
...@@ -67,5 +67,5 @@ TEST_CASE(sinh_dynamic_test) ...@@ -67,5 +67,5 @@ TEST_CASE(sinh_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sinhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -48,7 +48,7 @@ TEST_CASE(slice_test_1) ...@@ -48,7 +48,7 @@ TEST_CASE(slice_test_1)
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -72,7 +72,7 @@ TEST_CASE(slice_test_2) ...@@ -72,7 +72,7 @@ TEST_CASE(slice_test_2)
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -99,7 +99,7 @@ TEST_CASE(slice_var_inputs_static0) ...@@ -99,7 +99,7 @@ TEST_CASE(slice_var_inputs_static0)
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2); std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_static1) TEST_CASE(slice_var_inputs_static1)
...@@ -125,7 +125,7 @@ TEST_CASE(slice_var_inputs_static1) ...@@ -125,7 +125,7 @@ TEST_CASE(slice_var_inputs_static1)
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2); std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_static2) TEST_CASE(slice_var_inputs_static2)
...@@ -154,7 +154,7 @@ TEST_CASE(slice_var_inputs_static2) ...@@ -154,7 +154,7 @@ TEST_CASE(slice_var_inputs_static2)
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2); std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_var_inputs_dyn) TEST_CASE(slice_var_inputs_dyn)
...@@ -182,7 +182,7 @@ TEST_CASE(slice_var_inputs_dyn) ...@@ -182,7 +182,7 @@ TEST_CASE(slice_var_inputs_dyn)
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(slice_dyn_test0) TEST_CASE(slice_dyn_test0)
...@@ -213,7 +213,7 @@ TEST_CASE(slice_dyn_test0) ...@@ -213,7 +213,7 @@ TEST_CASE(slice_dyn_test0)
std::vector<int> results_vector(2 * 1 * 2); std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -244,6 +244,6 @@ TEST_CASE(slice_dyn_test1) ...@@ -244,6 +244,6 @@ TEST_CASE(slice_dyn_test1)
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10}; std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2); std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(result.get_shape() == sresult); EXPECT(result.get_shape() == sresult);
} }
...@@ -35,7 +35,7 @@ TEST_CASE(softmax_simple_test) ...@@ -35,7 +35,7 @@ TEST_CASE(softmax_simple_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> a = {0.25, 0.75}; std::vector<float> a = {0.25, 0.75};
std::vector<float> s = {0.377541, 0.622459}; std::vector<float> gold = {0.377541, 0.622459};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al); mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
...@@ -43,7 +43,7 @@ TEST_CASE(softmax_simple_test) ...@@ -43,7 +43,7 @@ TEST_CASE(softmax_simple_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(2); std::vector<float> results_vector(2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(softmax_test) TEST_CASE(softmax_test)
...@@ -76,7 +76,7 @@ TEST_CASE(softmax_test) ...@@ -76,7 +76,7 @@ TEST_CASE(softmax_test)
2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01, 2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01,
-6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01}; -6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
std::vector<float> s = { std::vector<float> gold = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741, 0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758, 0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111, 0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
...@@ -103,7 +103,7 @@ TEST_CASE(softmax_test) ...@@ -103,7 +103,7 @@ TEST_CASE(softmax_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(softmax_dyn_test) TEST_CASE(softmax_dyn_test)
...@@ -147,7 +147,7 @@ TEST_CASE(softmax_dyn_test) ...@@ -147,7 +147,7 @@ TEST_CASE(softmax_dyn_test)
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector(120); std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> s = { std::vector<float> gold = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741, 0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758, 0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111, 0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
...@@ -166,5 +166,5 @@ TEST_CASE(softmax_dyn_test) ...@@ -166,5 +166,5 @@ TEST_CASE(softmax_dyn_test)
0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739, 0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723, 0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
0.42914796}; 0.42914796};
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(sqdiff_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sqdiff_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4}; std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sqdiff_dyn_test) TEST_CASE(sqdiff_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(sqdiff_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(sqdiff_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {4, 4, 4}; std::vector<float> gold = {4, 4, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(sqrt_test) ...@@ -45,7 +45,7 @@ TEST_CASE(sqrt_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sqrt_dynamic_test) TEST_CASE(sqrt_dynamic_test)
...@@ -68,5 +68,5 @@ TEST_CASE(sqrt_dynamic_test) ...@@ -68,5 +68,5 @@ TEST_CASE(sqrt_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return sqrtf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(sub_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sub_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2}; std::vector<float> gold = {-2, -2, -2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sub_dyn_test) TEST_CASE(sub_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(sub_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(sub_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2, -2, -2}; std::vector<float> gold = {-2, -2, -2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(tan_test) ...@@ -45,7 +45,7 @@ TEST_CASE(tan_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(tan_dynamic_test) TEST_CASE(tan_dynamic_test)
...@@ -68,5 +68,5 @@ TEST_CASE(tan_dynamic_test) ...@@ -68,5 +68,5 @@ TEST_CASE(tan_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(tanh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(tanh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(tanh_dynamic_test) TEST_CASE(tanh_dynamic_test)
...@@ -68,5 +68,5 @@ TEST_CASE(tanh_dynamic_test) ...@@ -68,5 +68,5 @@ TEST_CASE(tanh_dynamic_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return tanhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -59,7 +59,7 @@ TEST_CASE(transpose_test) ...@@ -59,7 +59,7 @@ TEST_CASE(transpose_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -86,5 +86,5 @@ TEST_CASE(transpose_dyn_test) ...@@ -86,5 +86,5 @@ TEST_CASE(transpose_dyn_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -54,7 +54,7 @@ TEST_CASE(where_test) ...@@ -54,7 +54,7 @@ TEST_CASE(where_test)
for(int i = 0; i < gold.size(); ++i) for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i]; gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold));
} }
TEST_CASE(where_dyn_test) TEST_CASE(where_dyn_test)
...@@ -85,7 +85,7 @@ TEST_CASE(where_dyn_test) ...@@ -85,7 +85,7 @@ TEST_CASE(where_dyn_test)
std::vector<float> results_vector(3 * 3); std::vector<float> results_vector(3 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1}; std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(where_broadcasted_inputs_test) TEST_CASE(where_broadcasted_inputs_test)
...@@ -113,5 +113,5 @@ TEST_CASE(where_broadcasted_inputs_test) ...@@ -113,5 +113,5 @@ TEST_CASE(where_broadcasted_inputs_test)
for(int i = 0; i < gold.size(); ++i) for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i]; gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold));
} }
...@@ -140,8 +140,17 @@ TEST_CASE(handling_tensors) ...@@ -140,8 +140,17 @@ TEST_CASE(handling_tensors)
-0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193, -0.06269585, 0.18658121, -0.03944227, 0.0111798, -0.17731084, 0.11789055, -0.09982193,
0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292}; 0.08142821, 0.0729029, 0.11303909, 0.12735154, 0.03885292};
// Create the arguments in a parameter_map
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data());
params["W"] = migraphx::argument(weights_shape, c.data());
// Evaluate and confirm the result
auto result = p.eval(params).back();
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// Solution vector // Solution vector
std::vector<float> sol = {-0.20817225, std::vector<float> gold = {-0.20817225,
0.87965256, 0.87965256,
0.14958936, 0.14958936,
-1.24887264, -1.24887264,
...@@ -158,17 +167,7 @@ TEST_CASE(handling_tensors) ...@@ -158,17 +167,7 @@ TEST_CASE(handling_tensors)
-0.16138598, -0.16138598,
0.79344082}; 0.79344082};
// Create the arguments in a parameter_map EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
migraphx::parameter_map params;
params["X"] = migraphx::argument(input_shape, a.data());
params["W"] = migraphx::argument(weights_shape, c.data());
// Evaluate and confirm the result
auto result = p.eval(params).back();
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, sol));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -197,7 +197,7 @@ TEST_CASE(literal_rewrite_pooling_test) ...@@ -197,7 +197,7 @@ TEST_CASE(literal_rewrite_pooling_test)
auto result1 = p1.eval({}).back(); auto result1 = p1.eval({}).back();
auto result2 = p2.eval({}).back(); auto result2 = p2.eval({}).back();
visit_all(result1, result2)( visit_all(result1, result2)(
[&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_range(r1, r2)); }); [&](auto r1, auto r2) { EXPECT(migraphx::verify::verify_rms_range(r1, r2)); });
}; };
test_rewrite_pooling(migraphx::op::pooling_mode::max, test_rewrite_pooling(migraphx::op::pooling_mode::max,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment