"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "873f6c0c1f7ff9d12880a8e110e426577e1a2ca9"
Unverified Commit 69d8d789 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add options to set tolerances inside MIGraphX driver (#2213)

MIGraphX verification by default uses normalized RMS error as the basis for the verification.  This change adds some logic to allow migraphx to do "np.allclose" type of elementwise verification using atol and rtol.

Commit also includes changes to consistently pass "gold" or "expected" results as the second argument for "verify_range()" calls.  Default RMS tolerance inside driver is set to 0.001 which IMO is high for FP32 compared to what we had earlier. Need better defaults
parent e12032fb
...@@ -62,7 +62,7 @@ TEST_CASE(nms_dyn_out_test) ...@@ -62,7 +62,7 @@ TEST_CASE(nms_dyn_out_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_dyn_batch_test) TEST_CASE(nms_dyn_batch_test)
...@@ -108,7 +108,7 @@ TEST_CASE(nms_dyn_batch_test) ...@@ -108,7 +108,7 @@ TEST_CASE(nms_dyn_batch_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 1, 0, 3, 1, 0, 0, 1, 0, 5}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 1, 0, 3, 1, 0, 0, 1, 0, 5};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_dyn_boxes_test) TEST_CASE(nms_dyn_boxes_test)
...@@ -151,7 +151,7 @@ TEST_CASE(nms_dyn_boxes_test) ...@@ -151,7 +151,7 @@ TEST_CASE(nms_dyn_boxes_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_dyn_classes_test) TEST_CASE(nms_dyn_classes_test)
...@@ -195,7 +195,7 @@ TEST_CASE(nms_dyn_classes_test) ...@@ -195,7 +195,7 @@ TEST_CASE(nms_dyn_classes_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 1, 3, 0, 1, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 1, 3, 0, 1, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_not_center_test) TEST_CASE(nms_not_center_test)
...@@ -231,7 +231,7 @@ TEST_CASE(nms_not_center_test) ...@@ -231,7 +231,7 @@ TEST_CASE(nms_not_center_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_test) TEST_CASE(nms_test)
...@@ -265,7 +265,7 @@ TEST_CASE(nms_test) ...@@ -265,7 +265,7 @@ TEST_CASE(nms_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_transpose1_test) TEST_CASE(nms_transpose1_test)
...@@ -303,7 +303,7 @@ TEST_CASE(nms_transpose1_test) ...@@ -303,7 +303,7 @@ TEST_CASE(nms_transpose1_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_transpose2_test) TEST_CASE(nms_transpose2_test)
...@@ -341,5 +341,5 @@ TEST_CASE(nms_transpose2_test) ...@@ -341,5 +341,5 @@ TEST_CASE(nms_transpose2_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
...@@ -46,5 +46,5 @@ TEST_CASE(nonzero_test) ...@@ -46,5 +46,5 @@ TEST_CASE(nonzero_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0}; 1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
...@@ -44,7 +44,7 @@ TEST_CASE(not_test_int32) ...@@ -44,7 +44,7 @@ TEST_CASE(not_test_int32)
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold{1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(not_test_bool) TEST_CASE(not_test_bool)
...@@ -62,7 +62,7 @@ TEST_CASE(not_test_bool) ...@@ -62,7 +62,7 @@ TEST_CASE(not_test_bool)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(data.size()); std::vector<bool> gold(data.size());
std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; }); std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(not_dyn_test) TEST_CASE(not_dyn_test)
...@@ -83,5 +83,5 @@ TEST_CASE(not_dyn_test) ...@@ -83,5 +83,5 @@ TEST_CASE(not_dyn_test)
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold{1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -42,7 +42,7 @@ TEST_CASE(pad_test) ...@@ -42,7 +42,7 @@ TEST_CASE(pad_test)
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}; std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_test_asym) TEST_CASE(pad_test_asym)
...@@ -57,7 +57,7 @@ TEST_CASE(pad_test_asym) ...@@ -57,7 +57,7 @@ TEST_CASE(pad_test_asym)
std::vector<float> results_vector(9); std::vector<float> results_vector(9);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0}; std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_test_highest_half) TEST_CASE(pad_test_highest_half)
...@@ -76,7 +76,7 @@ TEST_CASE(pad_test_highest_half) ...@@ -76,7 +76,7 @@ TEST_CASE(pad_test_highest_half)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
const float x = std::numeric_limits<migraphx::half>::max(); const float x = std::numeric_limits<migraphx::half>::max();
std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x}; std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_test_lowest_half) TEST_CASE(pad_test_lowest_half)
...@@ -95,7 +95,7 @@ TEST_CASE(pad_test_lowest_half) ...@@ -95,7 +95,7 @@ TEST_CASE(pad_test_lowest_half)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
const float x = std::numeric_limits<migraphx::half>::lowest(); const float x = std::numeric_limits<migraphx::half>::lowest();
std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x}; std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_dyn_test) TEST_CASE(pad_dyn_test)
...@@ -115,5 +115,5 @@ TEST_CASE(pad_dyn_test) ...@@ -115,5 +115,5 @@ TEST_CASE(pad_dyn_test)
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}; std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -47,5 +47,5 @@ TEST_CASE(pointwise_test) ...@@ -47,5 +47,5 @@ TEST_CASE(pointwise_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -51,7 +51,7 @@ TEST_CASE(avgpool_rank3_test) ...@@ -51,7 +51,7 @@ TEST_CASE(avgpool_rank3_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35}; std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_test) TEST_CASE(avgpool_dyn_test)
...@@ -77,7 +77,7 @@ TEST_CASE(avgpool_dyn_test) ...@@ -77,7 +77,7 @@ TEST_CASE(avgpool_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35}; std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_pad_test) TEST_CASE(avgpool_dyn_pad_test)
...@@ -105,7 +105,7 @@ TEST_CASE(avgpool_dyn_pad_test) ...@@ -105,7 +105,7 @@ TEST_CASE(avgpool_dyn_pad_test)
std::vector<float> gold{ std::vector<float> gold{
0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6}; 0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_auto_pad_test) TEST_CASE(avgpool_dyn_auto_pad_test)
...@@ -141,7 +141,7 @@ TEST_CASE(avgpool_dyn_auto_pad_test) ...@@ -141,7 +141,7 @@ TEST_CASE(avgpool_dyn_auto_pad_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5, 2.5, 3.5, 3.5}; std::vector<float> gold{2.5, 2.5, 3.5, 3.5};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_auto_pad_1d_test) TEST_CASE(avgpool_dyn_auto_pad_1d_test)
...@@ -175,7 +175,7 @@ TEST_CASE(avgpool_dyn_auto_pad_1d_test) ...@@ -175,7 +175,7 @@ TEST_CASE(avgpool_dyn_auto_pad_1d_test)
0.8, 0.65, 0.7, 0.5, 0.8, 0.65, 0.7, 0.5,
0.1, 0.4, 0.4, 0.35}; 0.1, 0.4, 0.4, 0.35};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_pad_ceil_test) TEST_CASE(avgpool_dyn_pad_ceil_test)
...@@ -212,7 +212,7 @@ TEST_CASE(avgpool_dyn_pad_ceil_test) ...@@ -212,7 +212,7 @@ TEST_CASE(avgpool_dyn_pad_ceil_test)
2.0, 2.5, 2.5, 3.0, 2.0, 2.5, 2.5, 3.0,
3.0, 3.5, 3.5, 4.0}; 3.0, 3.5, 3.5, 4.0};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_rank3_stride2_test) TEST_CASE(avgpool_rank3_stride2_test)
...@@ -245,7 +245,7 @@ TEST_CASE(avgpool_rank3_stride2_test) ...@@ -245,7 +245,7 @@ TEST_CASE(avgpool_rank3_stride2_test)
-0.3442, 1.22005, 1.5295, -0.3442, 1.22005, 1.5295,
0.9965, 0.7854, -0.2915}; 0.9965, 0.7854, -0.2915};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_rank5_test) TEST_CASE(avgpool_rank5_test)
...@@ -281,7 +281,7 @@ TEST_CASE(avgpool_rank5_test) ...@@ -281,7 +281,7 @@ TEST_CASE(avgpool_rank5_test)
-0.078375, 0.139375, 0.46075, 0.0285, -0.188125, -0.085, 0.378125, -0.085375, -0.078375, 0.139375, 0.46075, 0.0285, -0.188125, -0.085, 0.378125, -0.085375,
-0.04, 0.304125, 0.40775, 0.2835, 0.112375, -0.073375, 0.4355, -0.187, -0.04, 0.304125, 0.40775, 0.2835, 0.112375, -0.073375, 0.4355, -0.187,
-0.392625, -0.258375, -0.485875, -0.0345, 0.16125, -0.131875, -0.228375, 0.068625}; -0.392625, -0.258375, -0.485875, -0.0345, 0.16125, -0.131875, -0.228375, 0.068625};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalavgpool_test) TEST_CASE(globalavgpool_test)
...@@ -302,7 +302,7 @@ TEST_CASE(globalavgpool_test) ...@@ -302,7 +302,7 @@ TEST_CASE(globalavgpool_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375}; std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalavgpool_dyn_test) TEST_CASE(globalavgpool_dyn_test)
...@@ -325,7 +325,7 @@ TEST_CASE(globalavgpool_dyn_test) ...@@ -325,7 +325,7 @@ TEST_CASE(globalavgpool_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375}; std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globallppool_test) TEST_CASE(globallppool_test)
...@@ -347,7 +347,7 @@ TEST_CASE(globallppool_test) ...@@ -347,7 +347,7 @@ TEST_CASE(globallppool_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815}; std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globallppool_dyn_test) TEST_CASE(globallppool_dyn_test)
...@@ -371,7 +371,7 @@ TEST_CASE(globallppool_dyn_test) ...@@ -371,7 +371,7 @@ TEST_CASE(globallppool_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815}; std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalmaxpool_test) TEST_CASE(globalmaxpool_test)
...@@ -392,7 +392,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -392,7 +392,7 @@ TEST_CASE(globalmaxpool_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7}; std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalmaxpool_dyn_test) TEST_CASE(globalmaxpool_dyn_test)
...@@ -416,7 +416,7 @@ TEST_CASE(globalmaxpool_dyn_test) ...@@ -416,7 +416,7 @@ TEST_CASE(globalmaxpool_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7}; std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(lppool_l1_norm_test) TEST_CASE(lppool_l1_norm_test)
...@@ -440,7 +440,7 @@ TEST_CASE(lppool_l1_norm_test) ...@@ -440,7 +440,7 @@ TEST_CASE(lppool_l1_norm_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5, 0.6, 0.5, 1.3, 1.4, 1.0, 0.8, 0.8, 0.7}; std::vector<float> gold{0.5, 0.6, 0.5, 1.3, 1.4, 1.0, 0.8, 0.8, 0.7};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
// TODO: this tests compliance with a oneDNN rule and a feature that's commented out // TODO: this tests compliance with a oneDNN rule and a feature that's commented out
...@@ -493,7 +493,7 @@ TEST_CASE(lppool_l2_norm_test) ...@@ -493,7 +493,7 @@ TEST_CASE(lppool_l2_norm_test)
0.7071067811865475, 0.7071067811865475,
0.7071067811865475, 0.7071067811865475,
0.6082762530298219}; 0.6082762530298219};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(lppool_dyn_test) TEST_CASE(lppool_dyn_test)
...@@ -526,7 +526,7 @@ TEST_CASE(lppool_dyn_test) ...@@ -526,7 +526,7 @@ TEST_CASE(lppool_dyn_test)
0.7071067811865475, 0.7071067811865475,
0.7071067811865475, 0.7071067811865475,
0.6082762530298219}; 0.6082762530298219};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
...@@ -565,12 +565,6 @@ TEST_CASE(maxpool_test) ...@@ -565,12 +565,6 @@ TEST_CASE(maxpool_test)
-0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746, -0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746,
-0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223, -0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223,
-0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682}; -0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682};
std::vector<float> c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
...@@ -583,7 +577,15 @@ TEST_CASE(maxpool_test) ...@@ -583,7 +577,15 @@ TEST_CASE(maxpool_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(36); std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, c)); std::vector<float> gold = {
1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_pad_test) TEST_CASE(maxpool_pad_test)
...@@ -591,7 +593,6 @@ TEST_CASE(maxpool_pad_test) ...@@ -591,7 +593,6 @@ TEST_CASE(maxpool_pad_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> a = {-6, -5, -4, -3, -5, -1, 0, 1, 2, 3, 4, 5}; std::vector<float> a = {-6, -5, -4, -3, -5, -1, 0, 1, 2, 3, 4, 5};
std::vector<float> c = {-4, -3, -4, -1, 2, 3, 4, 5};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 2}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
...@@ -611,8 +612,8 @@ TEST_CASE(maxpool_pad_test) ...@@ -611,8 +612,8 @@ TEST_CASE(maxpool_pad_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(8); std::vector<float> results_vector(8);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-4, -3, -4, -1, 2, 3, 4, 5};
EXPECT(migraphx::verify::verify_range(results_vector, c)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank3_test0) TEST_CASE(maxpool_rank3_test0)
...@@ -635,7 +636,7 @@ TEST_CASE(maxpool_rank3_test0) ...@@ -635,7 +636,7 @@ TEST_CASE(maxpool_rank3_test0)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6}; std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank3_test1) TEST_CASE(maxpool_rank3_test1)
...@@ -660,7 +661,7 @@ TEST_CASE(maxpool_rank3_test1) ...@@ -660,7 +661,7 @@ TEST_CASE(maxpool_rank3_test1)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603}; std::vector<float> gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank3_ceil_test) TEST_CASE(maxpool_rank3_ceil_test)
...@@ -694,7 +695,7 @@ TEST_CASE(maxpool_rank3_ceil_test) ...@@ -694,7 +695,7 @@ TEST_CASE(maxpool_rank3_ceil_test)
0.6022, 1.1925, 0.5493, -0.8039, 0.6022, 1.1925, 0.5493, -0.8039,
0.9907, 1.5001, -1.1603, 1.2556}; 0.9907, 1.5001, -1.1603, 1.2556};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank5_test) TEST_CASE(maxpool_rank5_test)
...@@ -727,7 +728,7 @@ TEST_CASE(maxpool_rank5_test) ...@@ -727,7 +728,7 @@ TEST_CASE(maxpool_rank5_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5064, 1.3655, 0.9035, 2.6859}; std::vector<float> gold{1.5064, 1.3655, 0.9035, 2.6859};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_dyn_test) TEST_CASE(maxpool_dyn_test)
...@@ -752,5 +753,5 @@ TEST_CASE(maxpool_dyn_test) ...@@ -752,5 +753,5 @@ TEST_CASE(maxpool_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6}; std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -46,7 +46,7 @@ TEST_CASE(pow_test) ...@@ -46,7 +46,7 @@ TEST_CASE(pow_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pow_dyn_test) TEST_CASE(pow_dyn_test)
...@@ -70,5 +70,5 @@ TEST_CASE(pow_dyn_test) ...@@ -70,5 +70,5 @@ TEST_CASE(pow_dyn_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(prelu_test) ...@@ -43,7 +43,7 @@ TEST_CASE(prelu_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f}; std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(prelu_dyn_test) TEST_CASE(prelu_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(prelu_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(prelu_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f}; std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -47,25 +47,25 @@ TEST_CASE(quant_conv2d_padding_stride_test) ...@@ -47,25 +47,25 @@ TEST_CASE(quant_conv2d_padding_stride_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int32_t> s = {4521, std::vector<int32_t> gold = {4521,
7014, 7014,
7830, 7830,
11952, 11952,
10515, 10515,
16734, 16734,
19737, 19737,
30906, 30906,
13161, 13161,
19542, 19542,
19494, 19494,
28800, 28800,
34707, 34707,
52590, 52590,
54729, 54729,
82746}; 82746};
std::vector<int32_t> results_vector; std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(quant_conv2d_padding_test) TEST_CASE(quant_conv2d_padding_test)
...@@ -83,8 +83,8 @@ TEST_CASE(quant_conv2d_padding_test) ...@@ -83,8 +83,8 @@ TEST_CASE(quant_conv2d_padding_test)
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl); migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int32_t> s = { std::vector<int32_t> gold = {
4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007, 4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007,
7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826, 7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396, 30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
...@@ -93,7 +93,7 @@ TEST_CASE(quant_conv2d_padding_test) ...@@ -93,7 +93,7 @@ TEST_CASE(quant_conv2d_padding_test)
std::vector<int32_t> results_vector; std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(quant_conv2d_test) TEST_CASE(quant_conv2d_test)
...@@ -114,24 +114,24 @@ TEST_CASE(quant_conv2d_test) ...@@ -114,24 +114,24 @@ TEST_CASE(quant_conv2d_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int32_t> s = {10197, std::vector<int32_t> gold = {10197,
10548, 10548,
11601, 11601,
11952, 11952,
25506, 25506,
26586, 26586,
29826, 29826,
30906, 30906,
27045, 27045,
27396, 27396,
28449, 28449,
28800, 28800,
77346, 77346,
78426, 78426,
81666, 81666,
82746}; 82746};
std::vector<int32_t> results_vector; std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -68,7 +68,9 @@ TEST_CASE(random_uniform_test) ...@@ -68,7 +68,9 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100)); EXPECT(migraphx::verify::verify_range_with_tolerance(result_vec,
migraphx::verify::expected{rand_samples},
migraphx::verify::tolerance{0.00001}));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
...@@ -102,9 +104,9 @@ TEST_CASE(random_uniform_int_test) ...@@ -102,9 +104,9 @@ TEST_CASE(random_uniform_int_test)
// Compare result with the STL's mt19937 generator // Compare result with the STL's mt19937 generator
std::mt19937 gen(seed); std::mt19937 gen(seed);
std::uniform_int_distribution<uint16_t> dis; std::uniform_int_distribution<uint16_t> dis;
std::vector<uint16_t> rand_samples(sample_size); std::vector<uint16_t> gold_rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(gold_rand_samples.begin(), gold_rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold_rand_samples));
} }
TEST_CASE(random_uniform_dyn_test) TEST_CASE(random_uniform_dyn_test)
...@@ -141,9 +143,9 @@ TEST_CASE(random_uniform_dyn_test) ...@@ -141,9 +143,9 @@ TEST_CASE(random_uniform_dyn_test)
// Compare result with the STL's mt19937 generator // Compare result with the STL's mt19937 generator
std::mt19937 gen(seed); std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> gold_rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(gold_rand_samples.begin(), gold_rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold_rand_samples));
} }
TEST_CASE(random_uniform_and_seed_test) TEST_CASE(random_uniform_and_seed_test)
......
...@@ -43,7 +43,7 @@ TEST_CASE(recip_test) ...@@ -43,7 +43,7 @@ TEST_CASE(recip_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 10.0f, 2.0f}; std::vector<float> gold = {-2.0f, 10.0f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(recip_dyn_test) TEST_CASE(recip_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(recip_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(recip_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 10.0f, 2.0f}; std::vector<float> gold = {-2.0f, 10.0f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -64,7 +64,7 @@ TEST_CASE(reduce_max_dynamic_axis0) ...@@ -64,7 +64,7 @@ TEST_CASE(reduce_max_dynamic_axis0)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {6, 7, 8, 9, 10}; std::vector<float> gold = {6, 7, 8, 9, 10};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reduce_max_axis01) TEST_CASE(reduce_max_axis01)
......
...@@ -42,7 +42,7 @@ TEST_CASE(relu_test) ...@@ -42,7 +42,7 @@ TEST_CASE(relu_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.f, 0.f, 1.f}; std::vector<float> gold = {0.f, 0.f, 1.f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(relu_dyn_test) TEST_CASE(relu_dyn_test)
...@@ -63,5 +63,5 @@ TEST_CASE(relu_dyn_test) ...@@ -63,5 +63,5 @@ TEST_CASE(relu_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.f, 0.f, 1.f}; std::vector<float> gold = {0.f, 0.f, 1.f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -44,7 +44,7 @@ TEST_CASE(reshape_lazy_test0) ...@@ -44,7 +44,7 @@ TEST_CASE(reshape_lazy_test0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
} }
TEST_CASE(reshape_lazy_test1) TEST_CASE(reshape_lazy_test1)
...@@ -61,7 +61,7 @@ TEST_CASE(reshape_lazy_test1) ...@@ -61,7 +61,7 @@ TEST_CASE(reshape_lazy_test1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
} }
TEST_CASE(reshape_lazy_test2) TEST_CASE(reshape_lazy_test2)
...@@ -78,7 +78,7 @@ TEST_CASE(reshape_lazy_test2) ...@@ -78,7 +78,7 @@ TEST_CASE(reshape_lazy_test2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
} }
TEST_CASE(reshape_lazy_dyn_test) TEST_CASE(reshape_lazy_dyn_test)
...@@ -99,58 +99,58 @@ TEST_CASE(reshape_lazy_dyn_test) ...@@ -99,58 +99,58 @@ TEST_CASE(reshape_lazy_dyn_test)
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, data));
} }
TEST_CASE(reshape_test0) TEST_CASE(reshape_test0)
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}}; migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24); std::vector<float> gold(24);
std::iota(data.begin(), data.end(), -3); std::iota(gold.begin(), gold.end(), -3);
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, gold});
std::vector<int64_t> new_shape = {8, 3, 1, 1}; std::vector<int64_t> new_shape = {8, 3, 1, 1};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reshape_test1) TEST_CASE(reshape_test1)
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}}; migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24); std::vector<float> gold(24);
std::iota(data.begin(), data.end(), -3); std::iota(gold.begin(), gold.end(), -3);
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, gold});
std::vector<int64_t> new_shape = {1, 3, 4, 2}; std::vector<int64_t> new_shape = {1, 3, 4, 2};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reshape_test2) TEST_CASE(reshape_test2)
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}}; migraphx::shape a_shape{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24); std::vector<float> gold(24);
std::iota(data.begin(), data.end(), -3); std::iota(gold.begin(), gold.end(), -3);
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_literal(migraphx::literal{a_shape, data}); auto l = mm->add_literal(migraphx::literal{a_shape, gold});
std::vector<int64_t> new_shape = {1, 2, 3, 4}; std::vector<int64_t> new_shape = {1, 2, 3, 4};
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l); mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), l);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reshape_dyn_test) TEST_CASE(reshape_dyn_test)
...@@ -163,13 +163,13 @@ TEST_CASE(reshape_dyn_test) ...@@ -163,13 +163,13 @@ TEST_CASE(reshape_dyn_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input); mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
std::vector<float> data(48); std::vector<float> gold(48);
std::iota(data.begin(), data.end(), -3); std::iota(gold.begin(), gold.end(), -3);
migraphx::parameter_map params; migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}}; migraphx::shape input_fixed_shape{migraphx::shape::float_type, {2, 24, 1, 1}};
params["X"] = migraphx::argument(input_fixed_shape, data.data()); params["X"] = migraphx::argument(input_fixed_shape, gold.data());
auto result = p.eval(params).back(); auto result = p.eval(params).back();
std::vector<float> results_vector{}; std::vector<float> results_vector{};
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -44,9 +44,9 @@ TEST_CASE(reverse_test_axis0) ...@@ -44,9 +44,9 @@ TEST_CASE(reverse_test_axis0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data; std::vector<float> gold = data;
std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16); std::swap_ranges(gold.begin(), gold.begin() + 16, gold.begin() + 16);
EXPECT(migraphx::verify::verify_range(results_vector, target_data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reverse_test_axis1) TEST_CASE(reverse_test_axis1)
...@@ -63,10 +63,10 @@ TEST_CASE(reverse_test_axis1) ...@@ -63,10 +63,10 @@ TEST_CASE(reverse_test_axis1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data; std::vector<float> gold = data;
std::reverse(target_data.begin(), target_data.begin() + 16); std::reverse(gold.begin(), gold.begin() + 16);
std::reverse(target_data.end() - 16, target_data.end()); std::reverse(gold.end() - 16, gold.end());
EXPECT(migraphx::verify::verify_range(results_vector, target_data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(reverse_test_axis10) TEST_CASE(reverse_test_axis10)
...@@ -83,9 +83,9 @@ TEST_CASE(reverse_test_axis10) ...@@ -83,9 +83,9 @@ TEST_CASE(reverse_test_axis10)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> target_data = data; std::vector<float> gold = data;
std::reverse(target_data.begin(), target_data.begin() + 16); std::reverse(gold.begin(), gold.begin() + 16);
std::reverse(target_data.end() - 16, target_data.end()); std::reverse(gold.end() - 16, gold.end());
std::swap_ranges(target_data.begin(), target_data.begin() + 16, target_data.begin() + 16); std::swap_ranges(gold.begin(), gold.begin() + 16, gold.begin() + 16);
EXPECT(migraphx::verify::verify_range(results_vector, target_data)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -145,8 +145,8 @@ TEST_CASE(rnn_forward) ...@@ -145,8 +145,8 @@ TEST_CASE(rnn_forward)
-0.16477929, -0.16477929,
-0.11893477}; -0.11893477};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
{ {
...@@ -206,8 +206,8 @@ TEST_CASE(rnn_forward) ...@@ -206,8 +206,8 @@ TEST_CASE(rnn_forward)
0.44193283, 0.44193283,
-0.16477929, -0.16477929,
-0.11893477}; -0.11893477};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
{ {
...@@ -266,8 +266,8 @@ TEST_CASE(rnn_forward) ...@@ -266,8 +266,8 @@ TEST_CASE(rnn_forward)
0}; 0};
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736}; 0.034457, 0.191679, -0.394683, -0.308897, -0.371446, 0.317082, 0.131042, -0.18736};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 3 args // 3 args
...@@ -297,7 +297,7 @@ TEST_CASE(rnn_forward) ...@@ -297,7 +297,7 @@ TEST_CASE(rnn_forward)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// seq_len = 1 // seq_len = 1
...@@ -344,7 +344,7 @@ TEST_CASE(rnn_forward) ...@@ -344,7 +344,7 @@ TEST_CASE(rnn_forward)
0.31708236, 0.31708236,
0.13104209, 0.13104209,
-0.18736027}; -0.18736027};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -438,7 +438,7 @@ TEST_CASE(rnn_reverse) ...@@ -438,7 +438,7 @@ TEST_CASE(rnn_reverse)
0.46251031, 0.46251031,
-0.20639211, -0.20639211,
0.37488942}; 0.37488942};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// rnn last output as program output // rnn last output as program output
...@@ -481,7 +481,7 @@ TEST_CASE(rnn_reverse) ...@@ -481,7 +481,7 @@ TEST_CASE(rnn_reverse)
0.44124447, 0.44124447,
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// rnn hidden states and last hidden state output as program outputs // rnn hidden states and last hidden state output as program outputs
...@@ -544,8 +544,8 @@ TEST_CASE(rnn_reverse) ...@@ -544,8 +544,8 @@ TEST_CASE(rnn_reverse)
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// rnn hidden states and last hidden state output as program outputs // rnn hidden states and last hidden state output as program outputs
...@@ -606,8 +606,8 @@ TEST_CASE(rnn_reverse) ...@@ -606,8 +606,8 @@ TEST_CASE(rnn_reverse)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
-0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889}; -0.293853, 0.167968, 0.51076, 0.402587, -0.0070999, 0.46251, -0.206392, 0.374889};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
} }
...@@ -718,8 +718,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -718,8 +718,8 @@ TEST_CASE(rnn_bidirectional)
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// last rnn output for program output // last rnn output for program output
...@@ -784,8 +784,8 @@ TEST_CASE(rnn_bidirectional) ...@@ -784,8 +784,8 @@ TEST_CASE(rnn_bidirectional)
0.143656, 0.143656,
0.148037}; 0.148037};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// 4 args // 4 args
...@@ -835,7 +835,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -835,7 +835,7 @@ TEST_CASE(rnn_bidirectional)
0.14365635, 0.14365635,
0.14803654}; 0.14803654};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// 3 args // 3 args
...@@ -870,7 +870,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -870,7 +870,7 @@ TEST_CASE(rnn_bidirectional)
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0., 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.}; 0., 0., 0., 0., 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
} }
// concatenation of hidden state for program output // concatenation of hidden state for program output
...@@ -923,7 +923,7 @@ TEST_CASE(rnn_bidirectional) ...@@ -923,7 +923,7 @@ TEST_CASE(rnn_bidirectional)
-0.20639211, -0.20639211,
0.37488942}; 0.37488942};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1008,7 +1008,10 @@ TEST_CASE(rnn_fp16) ...@@ -1008,7 +1008,10 @@ TEST_CASE(rnn_fp16)
std::vector<float> last_output_data_gold{ std::vector<float> last_output_data_gold{
0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold, 5e4)); EXPECT(migraphx::verify::verify_range_with_tolerance(
last_output_data,
migraphx::verify::expected{last_output_data_gold},
migraphx::verify::tolerance{0.005}));
} }
TEST_CASE(gru_forward) TEST_CASE(gru_forward)
...@@ -1106,7 +1109,7 @@ TEST_CASE(gru_forward) ...@@ -1106,7 +1109,7 @@ TEST_CASE(gru_forward)
0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787, 0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 0.23280787,
-0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// last output for output // last output for output
...@@ -1152,7 +1155,7 @@ TEST_CASE(gru_forward) ...@@ -1152,7 +1155,7 @@ TEST_CASE(gru_forward)
0.51757574, 0.51757574,
0.50380427}; 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// two rnn_last_hs_output operators after gru // two rnn_last_hs_output operators after gru
...@@ -1199,7 +1202,7 @@ TEST_CASE(gru_forward) ...@@ -1199,7 +1202,7 @@ TEST_CASE(gru_forward)
0.51757574, 0.51757574,
0.50380427}; 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
...@@ -1245,7 +1248,7 @@ TEST_CASE(gru_forward) ...@@ -1245,7 +1248,7 @@ TEST_CASE(gru_forward)
0.6014447, 0.6014447,
0.43445644}; 0.43445644};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1330,7 +1333,7 @@ TEST_CASE(gru_forward_args) ...@@ -1330,7 +1333,7 @@ TEST_CASE(gru_forward_args)
-0.232523, 0.00214573, 0.231693, -0.160475, -0.518952, -0.232523, 0.00214573, 0.231693, -0.160475, -0.518952,
0.0467166, 0.12327, -0.374162, 0.137778, 0.251976}; 0.0467166, 0.12327, -0.374162, 0.137778, 0.251976};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (bias is used) // 4 args (bias is used)
...@@ -1373,7 +1376,7 @@ TEST_CASE(gru_forward_args) ...@@ -1373,7 +1376,7 @@ TEST_CASE(gru_forward_args)
-0.416866, 0.377186, 0.32922, 0.162214, -0.519973, -0.416866, 0.377186, 0.32922, 0.162214, -0.519973,
-0.140072, 0.465076, -0.229563, 0.500164, 0.195166}; -0.140072, 0.465076, -0.229563, 0.500164, 0.195166};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (ih is used) // 4 args (ih is used)
...@@ -1417,7 +1420,7 @@ TEST_CASE(gru_forward_args) ...@@ -1417,7 +1420,7 @@ TEST_CASE(gru_forward_args)
-0.197, 0.0885705, 0.269396, -0.0414511, -0.515137, -0.197, 0.0885705, 0.269396, -0.0414511, -0.515137,
-0.03075, 0.158326, -0.296488, 0.177983, 0.519498}; -0.03075, 0.158326, -0.296488, 0.177983, 0.519498};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1519,7 +1522,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1519,7 +1522,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.51757574, 0.51757574,
0.50380427}; 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (sigmoid) specified // 1 activation function (sigmoid) specified
...@@ -1560,7 +1563,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1560,7 +1563,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663, 0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663,
0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278}; 0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (tanh) specified // 1 activation function (tanh) specified
...@@ -1605,7 +1608,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1605,7 +1608,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.65615714, 0.65615714,
0.53612584}; 0.53612584};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// seq length of 1 // seq length of 1
...@@ -1655,7 +1658,7 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1655,7 +1658,7 @@ TEST_CASE(gru_forward_actv_funcs)
0.6104771, 0.6104771,
0.79759157}; 0.79759157};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -1771,8 +1774,8 @@ TEST_CASE(gru_reverse) ...@@ -1771,8 +1774,8 @@ TEST_CASE(gru_reverse)
0.55703, 0.55703,
0.54711}; 0.54711};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// variable input sequence length // variable input sequence length
...@@ -1832,8 +1835,8 @@ TEST_CASE(gru_reverse) ...@@ -1832,8 +1835,8 @@ TEST_CASE(gru_reverse)
0.558397, 0.558397,
0.664423}; 0.664423};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
...@@ -1879,7 +1882,7 @@ TEST_CASE(gru_reverse) ...@@ -1879,7 +1882,7 @@ TEST_CASE(gru_reverse)
0.646604, 0.646604,
0.463943}; 0.463943};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// no activation function specified, so default is used. // no activation function specified, so default is used.
...@@ -1918,7 +1921,7 @@ TEST_CASE(gru_reverse) ...@@ -1918,7 +1921,7 @@ TEST_CASE(gru_reverse)
-0.329512, 0.476095, 0.284044, 0.392077, -0.369226, -0.329512, 0.476095, 0.284044, 0.392077, -0.369226,
-0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; -0.3275, -0.027301, 0.143774, 0.655686, 0.782831};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// seq length of 1 // seq length of 1
...@@ -1968,7 +1971,7 @@ TEST_CASE(gru_reverse) ...@@ -1968,7 +1971,7 @@ TEST_CASE(gru_reverse)
0.610477, 0.610477,
0.797592}; 0.797592};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2099,8 +2102,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2099,8 +2102,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; 0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// same input sequence length, but shorter than max squence length // same input sequence length, but shorter than max squence length
...@@ -2168,8 +2171,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2168,8 +2171,8 @@ TEST_CASE(gru_bidirectional)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; 0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// variable input sequence lengths // variable input sequence lengths
...@@ -2227,8 +2230,8 @@ TEST_CASE(gru_bidirectional) ...@@ -2227,8 +2230,8 @@ TEST_CASE(gru_bidirectional)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
0.182457, 0.304506, 0.313825, 0.397697, 0.300873}; 0.182457, 0.304506, 0.313825, 0.397697, 0.300873};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
} }
// last output for output, linear_before_reset = 0 // last output for output, linear_before_reset = 0
...@@ -2268,7 +2271,7 @@ TEST_CASE(gru_bidirectional) ...@@ -2268,7 +2271,7 @@ TEST_CASE(gru_bidirectional)
-0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289, -0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289,
-0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474}; -0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2370,7 +2373,7 @@ TEST_CASE(gru_bidirectional_args) ...@@ -2370,7 +2373,7 @@ TEST_CASE(gru_bidirectional_args)
0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407, 0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407,
0.198708, 0.0695644, 0.211621, 0.00246037}; 0.198708, 0.0695644, 0.211621, 0.00246037};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (bias is used) // 4 args (bias is used)
...@@ -2421,7 +2424,7 @@ TEST_CASE(gru_bidirectional_args) ...@@ -2421,7 +2424,7 @@ TEST_CASE(gru_bidirectional_args)
0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008, 0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008,
0.248674, -0.0295413, 0.291437, -0.165005}; 0.248674, -0.0295413, 0.291437, -0.165005};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 args (ih is used) // 4 args (ih is used)
...@@ -2469,7 +2472,7 @@ TEST_CASE(gru_bidirectional_args) ...@@ -2469,7 +2472,7 @@ TEST_CASE(gru_bidirectional_args)
0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354, 0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354,
0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917, 0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917,
-0.0339407, 0.413089, 0.721238, 0.431879}; -0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2583,7 +2586,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2583,7 +2586,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305,
0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; 0.381317, 0.468983, 0.230557, 0.348021, 0.180229};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (sigmoid) specified // 1 activation function (sigmoid) specified
...@@ -2626,7 +2629,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2626,7 +2629,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275, 0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275,
0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646, 0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646,
0.132732, 0.477083, 0.802206, 0.626802}; 0.132732, 0.477083, 0.802206, 0.626802};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 1 activation function (tanh) specified // 1 activation function (tanh) specified
...@@ -2670,7 +2673,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2670,7 +2673,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419, 0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419,
0.759629, 0.000258222, 0.350835, -0.682684}; 0.759629, 0.000258222, 0.350835, -0.682684};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 3 activation functions specified // 3 activation functions specified
...@@ -2710,7 +2713,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2710,7 +2713,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
1.15142, 0.457633, 0.300962, 0.361245, 0.666199, 1.15142, 0.457633, 0.300962, 0.361245, 0.666199,
0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473,
0.861394, 0.560799, -0.101768, 0.145142, 0.128956}; 0.861394, 0.560799, -0.101768, 0.145142, 0.128956};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// 4 activation functions all specified // 4 activation functions all specified
...@@ -2758,7 +2761,7 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2758,7 +2761,7 @@ TEST_CASE(gru_bidirectional_actv_funcs)
0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665,
0.079043, 0.322652, 0.752701, 0.243775}; 0.079043, 0.322652, 0.752701, 0.243775};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -2873,7 +2876,7 @@ TEST_CASE(gru_bidirectional_seq_1) ...@@ -2873,7 +2876,7 @@ TEST_CASE(gru_bidirectional_seq_1)
-0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078,
-0.144492, -0.0115366, 0.409153, 0.487015, 0.550755}; -0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
TEST_CASE(gru_fp16) TEST_CASE(gru_fp16)
...@@ -2983,7 +2986,8 @@ TEST_CASE(gru_fp16) ...@@ -2983,7 +2986,8 @@ TEST_CASE(gru_fp16)
-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873,
-0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4)); EXPECT(migraphx::verify::verify_range_with_tolerance(
hs_data, migraphx::verify::expected{hs_data_gold}, migraphx::verify::tolerance{0.005}));
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
...@@ -3114,7 +3118,7 @@ TEST_CASE(lstm_forward) ...@@ -3114,7 +3118,7 @@ TEST_CASE(lstm_forward)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607}; 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// forward, last_output as program output // forward, last_output as program output
...@@ -3167,7 +3171,7 @@ TEST_CASE(lstm_forward) ...@@ -3167,7 +3171,7 @@ TEST_CASE(lstm_forward)
0.0342236, 0.0342236,
-0.198664, -0.198664,
0.0702607}; 0.0702607};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// forward, last_cell_output as program output // forward, last_cell_output as program output
...@@ -3220,7 +3224,7 @@ TEST_CASE(lstm_forward) ...@@ -3220,7 +3224,7 @@ TEST_CASE(lstm_forward)
0.078598, 0.078598,
-0.64457, -0.64457,
0.119811}; 0.119811};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -3342,7 +3346,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3342,7 +3346,7 @@ TEST_CASE(lstm_forward_more)
0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774}; 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// forward, 8 args // forward, 8 args
...@@ -3391,7 +3395,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3391,7 +3395,7 @@ TEST_CASE(lstm_forward_more)
0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723}; 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
// forward, last_output as program output, sequence length shorter // forward, last_output as program output, sequence length shorter
...@@ -3453,7 +3457,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3453,7 +3457,7 @@ TEST_CASE(lstm_forward_more)
0.0342236, 0.0342236,
-0.198664, -0.198664,
0.0702607}; 0.0702607};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// seq_len = 1 // seq_len = 1
...@@ -3511,7 +3515,7 @@ TEST_CASE(lstm_forward_more) ...@@ -3511,7 +3515,7 @@ TEST_CASE(lstm_forward_more)
-0.121195, -0.121195,
-0.4065, -0.4065,
-0.252054}; -0.252054};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
} }
} }
...@@ -3641,7 +3645,7 @@ TEST_CASE(lstm_reverse) ...@@ -3641,7 +3645,7 @@ TEST_CASE(lstm_reverse)
0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, sequence lengths are the same, but less than max_seq_lens // reverse, sequence lengths are the same, but less than max_seq_lens
...@@ -3699,7 +3703,7 @@ TEST_CASE(lstm_reverse) ...@@ -3699,7 +3703,7 @@ TEST_CASE(lstm_reverse)
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0}; 0.0, 0.0};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// variable sequence lengths // variable sequence lengths
...@@ -3749,7 +3753,7 @@ TEST_CASE(lstm_reverse) ...@@ -3749,7 +3753,7 @@ TEST_CASE(lstm_reverse)
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0}; 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, last cell output as program output // reverse, 3 args, last cell output as program output
...@@ -3791,7 +3795,7 @@ TEST_CASE(lstm_reverse) ...@@ -3791,7 +3795,7 @@ TEST_CASE(lstm_reverse)
0.141613, 0.141613,
0.348002, 0.348002,
0.667298}; 0.667298};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, 0 actv function // reverse, 3 args, 0 actv function
...@@ -3830,7 +3834,7 @@ TEST_CASE(lstm_reverse) ...@@ -3830,7 +3834,7 @@ TEST_CASE(lstm_reverse)
0.141613, 0.141613,
0.348002, 0.348002,
0.667298}; 0.667298};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -3948,7 +3952,7 @@ TEST_CASE(lstm_reverse_actv) ...@@ -3948,7 +3952,7 @@ TEST_CASE(lstm_reverse_actv)
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, 2 actv functions // reverse, 3 args, 2 actv functions
...@@ -3989,7 +3993,7 @@ TEST_CASE(lstm_reverse_actv) ...@@ -3989,7 +3993,7 @@ TEST_CASE(lstm_reverse_actv)
0.233866, 0.233866,
0.48646, 0.48646,
0.481844}; 0.481844};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
...@@ -4035,7 +4039,7 @@ TEST_CASE(lstm_reverse_actv) ...@@ -4035,7 +4039,7 @@ TEST_CASE(lstm_reverse_actv)
0.070535, 0.070535,
0.327809, 0.327809,
0.407388}; 0.407388};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4162,7 +4166,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4162,7 +4166,7 @@ TEST_CASE(lstm_bidirectional)
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// last hidden state as program output // last hidden state as program output
...@@ -4205,7 +4209,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4205,7 +4209,7 @@ TEST_CASE(lstm_bidirectional)
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// last cell output as program output // last cell output as program output
...@@ -4248,7 +4252,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4248,7 +4252,7 @@ TEST_CASE(lstm_bidirectional)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, concatenation of hidden states as program output // 3 args, concatenation of hidden states as program output
...@@ -4291,7 +4295,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4291,7 +4295,7 @@ TEST_CASE(lstm_bidirectional)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// sequence length is 1, contenation of hidden state as program output // sequence length is 1, contenation of hidden state as program output
...@@ -4328,7 +4332,7 @@ TEST_CASE(lstm_bidirectional) ...@@ -4328,7 +4332,7 @@ TEST_CASE(lstm_bidirectional)
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4480,9 +4484,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens) ...@@ -4480,9 +4484,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242, 0.391174, 0.0308845, -0.561745, 0.0730323, -0.326822, 0.301121, 0.219523, 0.415242,
2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436}; 2.08242, 0.442513, 0.187127, 0.0577626, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_range(last_output_data, last_output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_range(last_cell_data, last_cell_data_gold)); EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold));
} }
// last cell output as program output // last cell output as program output
...@@ -4567,9 +4571,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens) ...@@ -4567,9 +4571,9 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_range(lho_data, lho_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_range(lco_data, lco_data_gold)); EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold));
} }
} }
...@@ -4654,7 +4658,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4654,7 +4658,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 1 actv func // 3 args, 1 actv func
...@@ -4694,7 +4698,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4694,7 +4698,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563, 0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563,
0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634, 0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 2 actv func // 3 args, 2 actv func
...@@ -4727,7 +4731,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4727,7 +4731,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 4 actv func // 3 args, 4 actv func
...@@ -4763,7 +4767,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4763,7 +4767,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186}; 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 5 actv func // 3 args, 5 actv func
...@@ -4799,7 +4803,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4799,7 +4803,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
// 3 args, 6 actv func // 3 args, 6 actv func
...@@ -4836,7 +4840,7 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -4836,7 +4840,7 @@ TEST_CASE(lstm_bidirectional_actv_func)
0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
} }
} }
...@@ -4981,5 +4985,5 @@ TEST_CASE(lstm_fp16) ...@@ -4981,5 +4985,5 @@ TEST_CASE(lstm_fp16)
0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607}; 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_range(hs_data, hs_data_gold, 5e4)); EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold, 5e4));
} }
...@@ -80,7 +80,7 @@ TEST_CASE(roialign_out_of_bound_test) ...@@ -80,7 +80,7 @@ TEST_CASE(roialign_out_of_bound_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0f, 0.0f, 0.0f}; std::vector<float> gold = {0.0f, 0.0f, 0.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -150,7 +150,7 @@ TEST_CASE(roialign_test) ...@@ -150,7 +150,7 @@ TEST_CASE(roialign_test)
0.256580025, 0.214098021, 0.279604018, 0.360000014, 0.436488032, 0.350427985, 0.256580025, 0.214098021, 0.279604018, 0.360000014, 0.436488032, 0.350427985,
0.288755983, 0.366139978, 0.234920025}; 0.288755983, 0.366139978, 0.234920025};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
{ {
...@@ -171,7 +171,7 @@ TEST_CASE(roialign_test) ...@@ -171,7 +171,7 @@ TEST_CASE(roialign_test)
0.929997, 0.66257, 0.561664, 0.481275, 0.495449, 0.666306, 0.663573, 0.372107, 0.929997, 0.66257, 0.561664, 0.481275, 0.495449, 0.666306, 0.663573, 0.372107,
0.205603, 0.192776, 0.247849}; 0.205603, 0.192776, 0.247849};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
{ {
...@@ -192,6 +192,6 @@ TEST_CASE(roialign_test) ...@@ -192,6 +192,6 @@ TEST_CASE(roialign_test)
0.44757, 0.351855, 0.342265, 0.244475, 0.274841, 0.553644, 0.607176, 0.202392, 0.44757, 0.351855, 0.342265, 0.244475, 0.274841, 0.553644, 0.607176, 0.202392,
0.07425, 0.066087, 0.126279}; 0.07425, 0.066087, 0.126279};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
} }
...@@ -43,7 +43,7 @@ TEST_CASE(round_test) ...@@ -43,7 +43,7 @@ TEST_CASE(round_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0}; std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(round_dyn_test) TEST_CASE(round_dyn_test)
...@@ -64,5 +64,5 @@ TEST_CASE(round_dyn_test) ...@@ -64,5 +64,5 @@ TEST_CASE(round_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0}; std::vector<float> gold = {1.0, 2.0, 2.0, -1.0, -2.0, -2.0, 0.0, 2.0, -2.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -42,7 +42,7 @@ TEST_CASE(rsqrt_test) ...@@ -42,7 +42,7 @@ TEST_CASE(rsqrt_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125}; std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(rsqrt_dyn_test) TEST_CASE(rsqrt_dyn_test)
...@@ -63,5 +63,5 @@ TEST_CASE(rsqrt_dyn_test) ...@@ -63,5 +63,5 @@ TEST_CASE(rsqrt_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.5, 0.25, 0.125}; std::vector<float> gold = {0.5, 0.25, 0.125};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -77,5 +77,5 @@ TEST_CASE(imagescaler_test) ...@@ -77,5 +77,5 @@ TEST_CASE(imagescaler_test)
0.53, 0.53,
0.73, 0.73,
0.93}; 0.93};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment