Unverified Commit 69d8d789 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add options to set tolerances inside MIGraphX driver (#2213)

MIGraphX verification by default uses normalized RMS error as the basis for the verification.  This change adds some logic to allow migraphx to do "np.allclose" type of elementwise verification using atol and rtol.

Commit also includes changes to consistently pass "gold" or "expected" results as the second argument for "verify_range()" calls.  Default RMS tolerance inside driver is set to 0.001 which IMO is high for FP32 compared to what we had earlier. Need better defaults
parent e12032fb
...@@ -64,7 +64,7 @@ TEST_CASE(scatter_ax0_test) ...@@ -64,7 +64,7 @@ TEST_CASE(scatter_ax0_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -78,7 +78,7 @@ TEST_CASE(scatter_ax_neg_test) ...@@ -78,7 +78,7 @@ TEST_CASE(scatter_ax_neg_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2}; std::vector<float> gold = {2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -91,7 +91,7 @@ TEST_CASE(scatter_ax1_test) ...@@ -91,7 +91,7 @@ TEST_CASE(scatter_ax1_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0}; std::vector<float> gold = {1.1, 1.0, 1.2, 2.0, 2.2, 2.1, 0.0, 0.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -128,7 +128,7 @@ TEST_CASE(scatter_reduction1_test) ...@@ -128,7 +128,7 @@ TEST_CASE(scatter_reduction1_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none = {1.0, 1.1, 3.0, 2.1, 5.0}; std::vector<float> gold_none = {1.0, 1.1, 3.0, 2.1, 5.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_none)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_none));
} }
} }
...@@ -142,7 +142,7 @@ TEST_CASE(scatter_reduction2_test) ...@@ -142,7 +142,7 @@ TEST_CASE(scatter_reduction2_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0}; std::vector<float> gold_mul = {1.0, 2.2, 3.0, 8.4, 5.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_mul)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_mul));
} }
} }
TEST_CASE(scatter_reduction3_test) TEST_CASE(scatter_reduction3_test)
...@@ -155,7 +155,7 @@ TEST_CASE(scatter_reduction3_test) ...@@ -155,7 +155,7 @@ TEST_CASE(scatter_reduction3_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_add = {1.0, 3.1, 3.0, 6.1, 5.0}; std::vector<float> gold_add = {1.0, 3.1, 3.0, 6.1, 5.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_add)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_add));
} }
} }
...@@ -184,7 +184,7 @@ TEST_CASE(scatter_reduction_3x3_test) ...@@ -184,7 +184,7 @@ TEST_CASE(scatter_reduction_3x3_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0}; std::vector<float> gold_a2 = {4.1, 4.0, 4.2, 10.0, 10.2, 10.1, 3.0, 3.0, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_a2)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_a2));
} }
} }
...@@ -221,7 +221,7 @@ TEST_CASE(scatter_reduction_3x3_xpose1_test) ...@@ -221,7 +221,7 @@ TEST_CASE(scatter_reduction_3x3_xpose1_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0}; std::vector<float> gold_none2 = {1.1, 7.0, 3.0, 1.0, 7.2, 3.0, 1.2, 7.1, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_none2)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_none2));
} }
} }
...@@ -236,7 +236,7 @@ TEST_CASE(scatter_reduction_3x3_xpose2_test) ...@@ -236,7 +236,7 @@ TEST_CASE(scatter_reduction_3x3_xpose2_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0}; std::vector<float> gold_a3 = {4.1, 10.0, 3.0, 4.0, 10.2, 3.0, 4.2, 10.1, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_a3)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_a3));
} }
} }
...@@ -250,6 +250,6 @@ TEST_CASE(scatter_reduction_3x3_xpose3_test) ...@@ -250,6 +250,6 @@ TEST_CASE(scatter_reduction_3x3_xpose3_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0}; std::vector<float> gold_mul2 = {3.3, 21.0, 3.0, 3.0, 21.6, 3.0, 3.6, 21.3, 3.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold_mul2)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_mul2));
} }
} }
...@@ -57,7 +57,7 @@ TEST_CASE(scatternd_add_reduction_test) ...@@ -57,7 +57,7 @@ TEST_CASE(scatternd_add_reduction_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 3, 3, 5, 6, 6, 7, 9}; std::vector<float> gold{1, 3, 3, 5, 6, 6, 7, 9};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_reduction_dyn_test) TEST_CASE(scatternd_reduction_dyn_test)
...@@ -102,5 +102,5 @@ TEST_CASE(scatternd_reduction_dyn_test) ...@@ -102,5 +102,5 @@ TEST_CASE(scatternd_reduction_dyn_test)
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
9, 8, 7, 6, 6, 5, 4, 3, 4, 5, 6, 7, 9, 10, 11, 12, 9, 8, 7, 6, 6, 5, 4, 3, 4, 5, 6, 7, 9, 10, 11, 12,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8}; 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -57,5 +57,5 @@ TEST_CASE(scatternd_mul_reduction_test) ...@@ -57,5 +57,5 @@ TEST_CASE(scatternd_mul_reduction_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 22, 3, 40, 45, 6, 7, 96}; std::vector<float> gold{1, 22, 3, 40, 45, 6, 7, 96};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -56,7 +56,7 @@ TEST_CASE(scatternd_shapes_test_1) ...@@ -56,7 +56,7 @@ TEST_CASE(scatternd_shapes_test_1)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 11, 0, 10, 9, 0, 0, 12}; std::vector<float> gold{0, 11, 0, 10, 9, 0, 0, 12};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_shapes_test_2) TEST_CASE(scatternd_shapes_test_2)
...@@ -86,7 +86,7 @@ TEST_CASE(scatternd_shapes_test_2) ...@@ -86,7 +86,7 @@ TEST_CASE(scatternd_shapes_test_2)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 2, 4}; std::vector<float> gold{5, 6, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_shapes_test_3) TEST_CASE(scatternd_shapes_test_3)
...@@ -117,7 +117,7 @@ TEST_CASE(scatternd_shapes_test_3) ...@@ -117,7 +117,7 @@ TEST_CASE(scatternd_shapes_test_3)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10}; std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_1) TEST_CASE(scatternd_test_1)
...@@ -147,7 +147,7 @@ TEST_CASE(scatternd_test_1) ...@@ -147,7 +147,7 @@ TEST_CASE(scatternd_test_1)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 11, 3, 10, 9, 6, 7, 12}; std::vector<float> gold{1, 11, 3, 10, 9, 6, 7, 12};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_2) TEST_CASE(scatternd_test_2)
...@@ -177,7 +177,7 @@ TEST_CASE(scatternd_test_2) ...@@ -177,7 +177,7 @@ TEST_CASE(scatternd_test_2)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{5, 6, 3, 4}; std::vector<float> gold{5, 6, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_3) TEST_CASE(scatternd_test_3)
...@@ -207,7 +207,7 @@ TEST_CASE(scatternd_test_3) ...@@ -207,7 +207,7 @@ TEST_CASE(scatternd_test_3)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10}; std::vector<float> gold{9, 2, 3, 4, 5, 6, 7, 10};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_4) TEST_CASE(scatternd_test_4)
...@@ -242,7 +242,7 @@ TEST_CASE(scatternd_test_4) ...@@ -242,7 +242,7 @@ TEST_CASE(scatternd_test_4)
7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8}; 4, 4, 4, 4, 8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(scatternd_test_5) TEST_CASE(scatternd_test_5)
...@@ -273,5 +273,5 @@ TEST_CASE(scatternd_test_5) ...@@ -273,5 +273,5 @@ TEST_CASE(scatternd_test_5)
std::vector<float> gold(32, 0); std::vector<float> gold(32, 0);
std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin()); std::copy(data_vec.begin(), data_vec.begin() + 16, gold.begin());
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -74,7 +74,7 @@ TEST_CASE(select_module_add_test) ...@@ -74,7 +74,7 @@ TEST_CASE(select_module_add_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2, 14, 5, 10, 5, 14, 14, 2}; std::vector<float> gold{2, 14, 5, 10, 5, 14, 14, 2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(select_module_reduce_test0) TEST_CASE(select_module_reduce_test0)
...@@ -120,7 +120,7 @@ TEST_CASE(select_module_reduce_test0) ...@@ -120,7 +120,7 @@ TEST_CASE(select_module_reduce_test0)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4}; std::vector<float> gold{-5, 12, 7, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(select_module_reduce_test1) TEST_CASE(select_module_reduce_test1)
...@@ -166,7 +166,7 @@ TEST_CASE(select_module_reduce_test1) ...@@ -166,7 +166,7 @@ TEST_CASE(select_module_reduce_test1)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-5, 12, 7, 4, -5, 12, 7, 4}; std::vector<float> gold{-5, 12, 7, 4, -5, 12, 7, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(select_module_not_found_error) TEST_CASE(select_module_not_found_error)
......
...@@ -44,7 +44,7 @@ TEST_CASE(sigmoid_test) ...@@ -44,7 +44,7 @@ TEST_CASE(sigmoid_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)}; std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sigmoid_dyn_test) TEST_CASE(sigmoid_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(sigmoid_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(sigmoid_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)}; std::vector<float> gold{sigmoid(-1), sigmoid(2), sigmoid(-3), sigmoid(4)};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(sign_test) ...@@ -43,7 +43,7 @@ TEST_CASE(sign_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(sign_dyn_test) TEST_CASE(sign_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(sign_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(sign_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, -1.0, -1.0, 0.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment