Commit c2bafa5d authored by Alan Turner's avatar Alan Turner
Browse files

Merge branch 'ck-flash-attn' of...

Merge branch 'ck-flash-attn' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into ck-flash-attn
parents 370b2cce 250d3c87
...@@ -49,7 +49,7 @@ TEST_CASE(logical_or_test) ...@@ -49,7 +49,7 @@ TEST_CASE(logical_or_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 or n2; return n1 or n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(logical_or_dyn_test) TEST_CASE(logical_or_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(logical_or_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(logical_or_dyn_test)
right_data.begin(), right_data.begin(),
gold.begin(), gold.begin(),
[](bool n1, bool n2) -> bool { return n1 or n2; }); [](bool n1, bool n2) -> bool { return n1 or n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -49,7 +49,7 @@ TEST_CASE(logical_xor_test) ...@@ -49,7 +49,7 @@ TEST_CASE(logical_xor_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 ^ n2; return n1 ^ n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(logical_xor_dyn_test) TEST_CASE(logical_xor_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(logical_xor_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(logical_xor_dyn_test)
right_data.begin(), right_data.begin(),
gold.begin(), gold.begin(),
[](bool n1, bool n2) -> bool { return n1 ^ n2; }); [](bool n1, bool n2) -> bool { return n1 ^ n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -61,7 +61,7 @@ TEST_CASE(logsoftmax_test_axis_0) ...@@ -61,7 +61,7 @@ TEST_CASE(logsoftmax_test_axis_0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_1) TEST_CASE(logsoftmax_test_axis_1)
...@@ -95,7 +95,7 @@ TEST_CASE(logsoftmax_test_axis_1) ...@@ -95,7 +95,7 @@ TEST_CASE(logsoftmax_test_axis_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_2) TEST_CASE(logsoftmax_test_axis_2)
...@@ -129,7 +129,7 @@ TEST_CASE(logsoftmax_test_axis_2) ...@@ -129,7 +129,7 @@ TEST_CASE(logsoftmax_test_axis_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_3) TEST_CASE(logsoftmax_test_axis_3)
...@@ -163,5 +163,5 @@ TEST_CASE(logsoftmax_test_axis_3) ...@@ -163,5 +163,5 @@ TEST_CASE(logsoftmax_test_axis_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
...@@ -43,5 +43,5 @@ TEST_CASE(lrn_test) ...@@ -43,5 +43,5 @@ TEST_CASE(lrn_test)
std::vector<float> results_vector(5); std::vector<float> results_vector(5);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075}; std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(max_test) ...@@ -45,7 +45,7 @@ TEST_CASE(max_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{7, 8, 9}; std::vector<float> gold{7, 8, 9};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(max_dyn_test) TEST_CASE(max_dyn_test)
...@@ -73,5 +73,5 @@ TEST_CASE(max_dyn_test) ...@@ -73,5 +73,5 @@ TEST_CASE(max_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{7, 8, 9}; std::vector<float> gold{7, 8, 9};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(min_test) ...@@ -45,7 +45,7 @@ TEST_CASE(min_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 4, 3}; std::vector<float> gold{1, 4, 3};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(min_dyn_test) TEST_CASE(min_dyn_test)
...@@ -73,5 +73,5 @@ TEST_CASE(min_dyn_test) ...@@ -73,5 +73,5 @@ TEST_CASE(min_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 4, 3}; std::vector<float> gold{1, 4, 3};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(mod_test) ...@@ -45,7 +45,7 @@ TEST_CASE(mod_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 2}; std::vector<float> gold{0, 0, 2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(mod_dyn_test) TEST_CASE(mod_dyn_test)
...@@ -73,7 +73,7 @@ TEST_CASE(mod_dyn_test) ...@@ -73,7 +73,7 @@ TEST_CASE(mod_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 2}; std::vector<float> gold{0, 0, 2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(mod_float_test) TEST_CASE(mod_float_test)
...@@ -92,5 +92,5 @@ TEST_CASE(mod_float_test) ...@@ -92,5 +92,5 @@ TEST_CASE(mod_float_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0f, 2.5f, 2.0f}; std::vector<float> gold{1.0f, 2.5f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -49,7 +49,7 @@ TEST_CASE(mul_test) ...@@ -49,7 +49,7 @@ TEST_CASE(mul_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float {
return n1 * n2; return n1 * n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(mul_dyn_test) TEST_CASE(mul_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(mul_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(mul_dyn_test)
y_data.begin(), y_data.begin(),
gold.begin(), gold.begin(),
[](float n1, float n2) -> float { return n1 * n2; }); [](float n1, float n2) -> float { return n1 * n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -78,5 +78,6 @@ TEST_CASE(multinomial_test) ...@@ -78,5 +78,6 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000)); EXPECT(migraphx::verify::verify_range_with_tolerance(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(neg_test) ...@@ -45,7 +45,7 @@ TEST_CASE(neg_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>()); std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>());
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(neg_dyn_test) TEST_CASE(neg_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(neg_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(neg_dyn_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = a; std::vector<float> gold = a;
std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>()); std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>());
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
...@@ -62,7 +62,7 @@ TEST_CASE(nms_dyn_out_test) ...@@ -62,7 +62,7 @@ TEST_CASE(nms_dyn_out_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_dyn_batch_test) TEST_CASE(nms_dyn_batch_test)
...@@ -108,7 +108,7 @@ TEST_CASE(nms_dyn_batch_test) ...@@ -108,7 +108,7 @@ TEST_CASE(nms_dyn_batch_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 1, 0, 3, 1, 0, 0, 1, 0, 5}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 1, 0, 3, 1, 0, 0, 1, 0, 5};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_dyn_boxes_test) TEST_CASE(nms_dyn_boxes_test)
...@@ -151,7 +151,7 @@ TEST_CASE(nms_dyn_boxes_test) ...@@ -151,7 +151,7 @@ TEST_CASE(nms_dyn_boxes_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_dyn_classes_test) TEST_CASE(nms_dyn_classes_test)
...@@ -195,7 +195,7 @@ TEST_CASE(nms_dyn_classes_test) ...@@ -195,7 +195,7 @@ TEST_CASE(nms_dyn_classes_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 1, 3, 0, 1, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 1, 3, 0, 1, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_not_center_test) TEST_CASE(nms_not_center_test)
...@@ -231,7 +231,7 @@ TEST_CASE(nms_not_center_test) ...@@ -231,7 +231,7 @@ TEST_CASE(nms_not_center_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_test) TEST_CASE(nms_test)
...@@ -265,7 +265,7 @@ TEST_CASE(nms_test) ...@@ -265,7 +265,7 @@ TEST_CASE(nms_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_transpose1_test) TEST_CASE(nms_transpose1_test)
...@@ -303,7 +303,7 @@ TEST_CASE(nms_transpose1_test) ...@@ -303,7 +303,7 @@ TEST_CASE(nms_transpose1_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
TEST_CASE(nms_transpose2_test) TEST_CASE(nms_transpose2_test)
...@@ -341,5 +341,5 @@ TEST_CASE(nms_transpose2_test) ...@@ -341,5 +341,5 @@ TEST_CASE(nms_transpose2_test)
std::vector<int64_t> result; std::vector<int64_t> result;
output.visit([&](auto out) { result.assign(out.begin(), out.end()); }); output.visit([&](auto out) { result.assign(out.begin(), out.end()); });
std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0}; std::vector<int64_t> gold = {0, 0, 3, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result, gold)); EXPECT(migraphx::verify::verify_rms_range(result, gold));
} }
...@@ -46,5 +46,5 @@ TEST_CASE(nonzero_test) ...@@ -46,5 +46,5 @@ TEST_CASE(nonzero_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0}; 1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
...@@ -44,7 +44,7 @@ TEST_CASE(not_test_int32) ...@@ -44,7 +44,7 @@ TEST_CASE(not_test_int32)
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold{1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(not_test_bool) TEST_CASE(not_test_bool)
...@@ -62,7 +62,7 @@ TEST_CASE(not_test_bool) ...@@ -62,7 +62,7 @@ TEST_CASE(not_test_bool)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(data.size()); std::vector<bool> gold(data.size());
std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; }); std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(not_dyn_test) TEST_CASE(not_dyn_test)
...@@ -83,5 +83,5 @@ TEST_CASE(not_dyn_test) ...@@ -83,5 +83,5 @@ TEST_CASE(not_dyn_test)
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold{1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -42,7 +42,7 @@ TEST_CASE(pad_test) ...@@ -42,7 +42,7 @@ TEST_CASE(pad_test)
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}; std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_test_asym) TEST_CASE(pad_test_asym)
...@@ -57,7 +57,7 @@ TEST_CASE(pad_test_asym) ...@@ -57,7 +57,7 @@ TEST_CASE(pad_test_asym)
std::vector<float> results_vector(9); std::vector<float> results_vector(9);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0}; std::vector<float> gold{1, 2, 0, 3, 4, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_test_highest_half) TEST_CASE(pad_test_highest_half)
...@@ -76,7 +76,7 @@ TEST_CASE(pad_test_highest_half) ...@@ -76,7 +76,7 @@ TEST_CASE(pad_test_highest_half)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
const float x = std::numeric_limits<migraphx::half>::max(); const float x = std::numeric_limits<migraphx::half>::max();
std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x}; std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_test_lowest_half) TEST_CASE(pad_test_lowest_half)
...@@ -95,7 +95,7 @@ TEST_CASE(pad_test_lowest_half) ...@@ -95,7 +95,7 @@ TEST_CASE(pad_test_lowest_half)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
const float x = std::numeric_limits<migraphx::half>::lowest(); const float x = std::numeric_limits<migraphx::half>::lowest();
std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x}; std::vector<float> gold{x, x, x, x, x, 1, 2, x, x, 3, 4, x, x, x, x, x};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pad_dyn_test) TEST_CASE(pad_dyn_test)
...@@ -115,5 +115,5 @@ TEST_CASE(pad_dyn_test) ...@@ -115,5 +115,5 @@ TEST_CASE(pad_dyn_test)
std::vector<float> results_vector(16); std::vector<float> results_vector(16);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}; std::vector<float> gold{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -47,5 +47,5 @@ TEST_CASE(pointwise_test) ...@@ -47,5 +47,5 @@ TEST_CASE(pointwise_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -51,7 +51,7 @@ TEST_CASE(avgpool_rank3_test) ...@@ -51,7 +51,7 @@ TEST_CASE(avgpool_rank3_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35}; std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_test) TEST_CASE(avgpool_dyn_test)
...@@ -77,7 +77,7 @@ TEST_CASE(avgpool_dyn_test) ...@@ -77,7 +77,7 @@ TEST_CASE(avgpool_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35}; std::vector<float> gold{0.25, 0.3, 0.25, 0.65, 0.7, 0.5, 0.4, 0.4, 0.35};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_pad_test) TEST_CASE(avgpool_dyn_pad_test)
...@@ -105,7 +105,7 @@ TEST_CASE(avgpool_dyn_pad_test) ...@@ -105,7 +105,7 @@ TEST_CASE(avgpool_dyn_pad_test)
std::vector<float> gold{ std::vector<float> gold{
0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6}; 0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_auto_pad_test) TEST_CASE(avgpool_dyn_auto_pad_test)
...@@ -141,7 +141,7 @@ TEST_CASE(avgpool_dyn_auto_pad_test) ...@@ -141,7 +141,7 @@ TEST_CASE(avgpool_dyn_auto_pad_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{2.5, 2.5, 3.5, 3.5}; std::vector<float> gold{2.5, 2.5, 3.5, 3.5};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_auto_pad_1d_test) TEST_CASE(avgpool_dyn_auto_pad_1d_test)
...@@ -175,7 +175,7 @@ TEST_CASE(avgpool_dyn_auto_pad_1d_test) ...@@ -175,7 +175,7 @@ TEST_CASE(avgpool_dyn_auto_pad_1d_test)
0.8, 0.65, 0.7, 0.5, 0.8, 0.65, 0.7, 0.5,
0.1, 0.4, 0.4, 0.35}; 0.1, 0.4, 0.4, 0.35};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_dyn_pad_ceil_test) TEST_CASE(avgpool_dyn_pad_ceil_test)
...@@ -212,7 +212,7 @@ TEST_CASE(avgpool_dyn_pad_ceil_test) ...@@ -212,7 +212,7 @@ TEST_CASE(avgpool_dyn_pad_ceil_test)
2.0, 2.5, 2.5, 3.0, 2.0, 2.5, 2.5, 3.0,
3.0, 3.5, 3.5, 4.0}; 3.0, 3.5, 3.5, 4.0};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_rank3_stride2_test) TEST_CASE(avgpool_rank3_stride2_test)
...@@ -245,7 +245,7 @@ TEST_CASE(avgpool_rank3_stride2_test) ...@@ -245,7 +245,7 @@ TEST_CASE(avgpool_rank3_stride2_test)
-0.3442, 1.22005, 1.5295, -0.3442, 1.22005, 1.5295,
0.9965, 0.7854, -0.2915}; 0.9965, 0.7854, -0.2915};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(avgpool_rank5_test) TEST_CASE(avgpool_rank5_test)
...@@ -281,7 +281,7 @@ TEST_CASE(avgpool_rank5_test) ...@@ -281,7 +281,7 @@ TEST_CASE(avgpool_rank5_test)
-0.078375, 0.139375, 0.46075, 0.0285, -0.188125, -0.085, 0.378125, -0.085375, -0.078375, 0.139375, 0.46075, 0.0285, -0.188125, -0.085, 0.378125, -0.085375,
-0.04, 0.304125, 0.40775, 0.2835, 0.112375, -0.073375, 0.4355, -0.187, -0.04, 0.304125, 0.40775, 0.2835, 0.112375, -0.073375, 0.4355, -0.187,
-0.392625, -0.258375, -0.485875, -0.0345, 0.16125, -0.131875, -0.228375, 0.068625}; -0.392625, -0.258375, -0.485875, -0.0345, 0.16125, -0.131875, -0.228375, 0.068625};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalavgpool_test) TEST_CASE(globalavgpool_test)
...@@ -302,7 +302,7 @@ TEST_CASE(globalavgpool_test) ...@@ -302,7 +302,7 @@ TEST_CASE(globalavgpool_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375}; std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalavgpool_dyn_test) TEST_CASE(globalavgpool_dyn_test)
...@@ -325,7 +325,7 @@ TEST_CASE(globalavgpool_dyn_test) ...@@ -325,7 +325,7 @@ TEST_CASE(globalavgpool_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375}; std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globallppool_test) TEST_CASE(globallppool_test)
...@@ -347,7 +347,7 @@ TEST_CASE(globallppool_test) ...@@ -347,7 +347,7 @@ TEST_CASE(globallppool_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815}; std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globallppool_dyn_test) TEST_CASE(globallppool_dyn_test)
...@@ -371,7 +371,7 @@ TEST_CASE(globallppool_dyn_test) ...@@ -371,7 +371,7 @@ TEST_CASE(globallppool_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815}; std::vector<float> gold{0.5477225575051662, 1.307669683062202, 0.9327379053088815};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalmaxpool_test) TEST_CASE(globalmaxpool_test)
...@@ -392,7 +392,7 @@ TEST_CASE(globalmaxpool_test) ...@@ -392,7 +392,7 @@ TEST_CASE(globalmaxpool_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7}; std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(globalmaxpool_dyn_test) TEST_CASE(globalmaxpool_dyn_test)
...@@ -416,7 +416,7 @@ TEST_CASE(globalmaxpool_dyn_test) ...@@ -416,7 +416,7 @@ TEST_CASE(globalmaxpool_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7}; std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(lppool_l1_norm_test) TEST_CASE(lppool_l1_norm_test)
...@@ -440,7 +440,7 @@ TEST_CASE(lppool_l1_norm_test) ...@@ -440,7 +440,7 @@ TEST_CASE(lppool_l1_norm_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.5, 0.6, 0.5, 1.3, 1.4, 1.0, 0.8, 0.8, 0.7}; std::vector<float> gold{0.5, 0.6, 0.5, 1.3, 1.4, 1.0, 0.8, 0.8, 0.7};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
// TODO: this tests compliance with a oneDNN rule and a feature that's commented out // TODO: this tests compliance with a oneDNN rule and a feature that's commented out
...@@ -493,7 +493,7 @@ TEST_CASE(lppool_l2_norm_test) ...@@ -493,7 +493,7 @@ TEST_CASE(lppool_l2_norm_test)
0.7071067811865475, 0.7071067811865475,
0.7071067811865475, 0.7071067811865475,
0.6082762530298219}; 0.6082762530298219};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(lppool_dyn_test) TEST_CASE(lppool_dyn_test)
...@@ -526,7 +526,7 @@ TEST_CASE(lppool_dyn_test) ...@@ -526,7 +526,7 @@ TEST_CASE(lppool_dyn_test)
0.7071067811865475, 0.7071067811865475,
0.7071067811865475, 0.7071067811865475,
0.6082762530298219}; 0.6082762530298219};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_test) TEST_CASE(maxpool_test)
...@@ -565,12 +565,6 @@ TEST_CASE(maxpool_test) ...@@ -565,12 +565,6 @@ TEST_CASE(maxpool_test)
-0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746, -0.88503587, 0.6629802, 1.47319221, -1.05829155, -0.97027361, -0.93187737, -1.39954746,
-0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223, -0.52359426, -0.14743951, 1.51522756, 0.2078452, -1.28156149, -1.19363916, -0.78680223,
-0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682}; -0.89094824, 1.30212069, -0.77974445, -0.58411664, 0.48764706, -0.67132682};
std::vector<float> c = {1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 6, 6}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
...@@ -583,7 +577,15 @@ TEST_CASE(maxpool_test) ...@@ -583,7 +577,15 @@ TEST_CASE(maxpool_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(36); std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, c)); std::vector<float> gold = {
1.33493888, 1.54562736, 1.22098756, 1.33493888, 1.18358743, 1.99097753,
1.00170159, 1.45862222, 1.39087725, 1.46645272, 1.18943918, -0.01443311,
1.47151589, 2.36277103, 2.24768877, 0.68883753, 0.82949388, 0.71550399,
1.95433736, 2.46601582, 1.53285873, 1.95433736, 1.06763375, 1.4545635,
1.33624589, 1.16736257, 0.6126079, 1.36892557, 2.40126371, 1.53441942,
0.52119428, 2.07681108, 0.88494766, 1.51522756, 0.54275119, 0.6629802};
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_pad_test) TEST_CASE(maxpool_pad_test)
...@@ -591,7 +593,6 @@ TEST_CASE(maxpool_pad_test) ...@@ -591,7 +593,6 @@ TEST_CASE(maxpool_pad_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<float> a = {-6, -5, -4, -3, -5, -1, 0, 1, 2, 3, 4, 5}; std::vector<float> a = {-6, -5, -4, -3, -5, -1, 0, 1, 2, 3, 4, 5};
std::vector<float> c = {-4, -3, -4, -1, 2, 3, 4, 5};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 2}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 2}};
auto al = mm->add_literal(migraphx::literal{a_shape, a}); auto al = mm->add_literal(migraphx::literal{a_shape, a});
mm->add_instruction(migraphx::make_op("pooling", mm->add_instruction(migraphx::make_op("pooling",
...@@ -611,8 +612,8 @@ TEST_CASE(maxpool_pad_test) ...@@ -611,8 +612,8 @@ TEST_CASE(maxpool_pad_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector(8); std::vector<float> results_vector(8);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-4, -3, -4, -1, 2, 3, 4, 5};
EXPECT(migraphx::verify::verify_range(results_vector, c)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank3_test0) TEST_CASE(maxpool_rank3_test0)
...@@ -635,7 +636,7 @@ TEST_CASE(maxpool_rank3_test0) ...@@ -635,7 +636,7 @@ TEST_CASE(maxpool_rank3_test0)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6}; std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank3_test1) TEST_CASE(maxpool_rank3_test1)
...@@ -660,7 +661,7 @@ TEST_CASE(maxpool_rank3_test1) ...@@ -660,7 +661,7 @@ TEST_CASE(maxpool_rank3_test1)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603}; std::vector<float> gold{0.4975, -0.0405, -0.6186, 0.6022, 0.5493, -0.8039, 1.5001, -1.1603};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank3_ceil_test) TEST_CASE(maxpool_rank3_ceil_test)
...@@ -694,7 +695,7 @@ TEST_CASE(maxpool_rank3_ceil_test) ...@@ -694,7 +695,7 @@ TEST_CASE(maxpool_rank3_ceil_test)
0.6022, 1.1925, 0.5493, -0.8039, 0.6022, 1.1925, 0.5493, -0.8039,
0.9907, 1.5001, -1.1603, 1.2556}; 0.9907, 1.5001, -1.1603, 1.2556};
// clang-format on // clang-format on
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_rank5_test) TEST_CASE(maxpool_rank5_test)
...@@ -727,7 +728,7 @@ TEST_CASE(maxpool_rank5_test) ...@@ -727,7 +728,7 @@ TEST_CASE(maxpool_rank5_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.5064, 1.3655, 0.9035, 2.6859}; std::vector<float> gold{1.5064, 1.3655, 0.9035, 2.6859};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(maxpool_dyn_test) TEST_CASE(maxpool_dyn_test)
...@@ -752,5 +753,5 @@ TEST_CASE(maxpool_dyn_test) ...@@ -752,5 +753,5 @@ TEST_CASE(maxpool_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6}; std::vector<float> gold{0.3, 0.4, 0.4, 0.8, 0.9, 0.9, 0.7, 0.7, 0.6};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -46,7 +46,7 @@ TEST_CASE(pow_test) ...@@ -46,7 +46,7 @@ TEST_CASE(pow_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pow_dyn_test) TEST_CASE(pow_dyn_test)
...@@ -70,5 +70,5 @@ TEST_CASE(pow_dyn_test) ...@@ -70,5 +70,5 @@ TEST_CASE(pow_dyn_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -43,7 +43,7 @@ TEST_CASE(prelu_test) ...@@ -43,7 +43,7 @@ TEST_CASE(prelu_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f}; std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(prelu_dyn_test) TEST_CASE(prelu_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(prelu_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(prelu_dyn_test)
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2.0f, 0.0f, 2.0f}; std::vector<float> gold = {-2.0f, 0.0f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -47,7 +47,7 @@ TEST_CASE(quant_conv2d_padding_stride_test) ...@@ -47,7 +47,7 @@ TEST_CASE(quant_conv2d_padding_stride_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int32_t> s = {4521, std::vector<int32_t> gold = {4521,
7014, 7014,
7830, 7830,
11952, 11952,
...@@ -65,7 +65,7 @@ TEST_CASE(quant_conv2d_padding_stride_test) ...@@ -65,7 +65,7 @@ TEST_CASE(quant_conv2d_padding_stride_test)
82746}; 82746};
std::vector<int32_t> results_vector; std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(quant_conv2d_padding_test) TEST_CASE(quant_conv2d_padding_test)
...@@ -84,7 +84,7 @@ TEST_CASE(quant_conv2d_padding_test) ...@@ -84,7 +84,7 @@ TEST_CASE(quant_conv2d_padding_test)
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl); migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}), al, cl);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int32_t> s = { std::vector<int32_t> gold = {
4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007, 4521, 6753, 7014, 4635, 6858, 10197, 10548, 6939, 7830, 11601, 11952, 7839, 5007,
7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826, 7383, 7590, 4953, 10515, 15987, 16734, 11277, 16821, 25506, 26586, 17874, 19737, 29826,
30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396, 30906, 20718, 13593, 20505, 21198, 14187, 13161, 19281, 19542, 12699, 18522, 27045, 27396,
...@@ -93,7 +93,7 @@ TEST_CASE(quant_conv2d_padding_test) ...@@ -93,7 +93,7 @@ TEST_CASE(quant_conv2d_padding_test)
std::vector<int32_t> results_vector; std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(quant_conv2d_test) TEST_CASE(quant_conv2d_test)
...@@ -114,7 +114,7 @@ TEST_CASE(quant_conv2d_test) ...@@ -114,7 +114,7 @@ TEST_CASE(quant_conv2d_test)
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<int32_t> s = {10197, std::vector<int32_t> gold = {10197,
10548, 10548,
11601, 11601,
11952, 11952,
...@@ -133,5 +133,5 @@ TEST_CASE(quant_conv2d_test) ...@@ -133,5 +133,5 @@ TEST_CASE(quant_conv2d_test)
std::vector<int32_t> results_vector; std::vector<int32_t> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -68,7 +68,9 @@ TEST_CASE(random_uniform_test) ...@@ -68,7 +68,9 @@ TEST_CASE(random_uniform_test)
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples, 100)); EXPECT(migraphx::verify::verify_range_with_tolerance(result_vec,
migraphx::verify::expected{rand_samples},
migraphx::verify::tolerance{0.00001}));
} }
TEST_CASE(random_uniform_int_test) TEST_CASE(random_uniform_int_test)
...@@ -102,9 +104,9 @@ TEST_CASE(random_uniform_int_test) ...@@ -102,9 +104,9 @@ TEST_CASE(random_uniform_int_test)
// Compare result with the STL's mt19937 generator // Compare result with the STL's mt19937 generator
std::mt19937 gen(seed); std::mt19937 gen(seed);
std::uniform_int_distribution<uint16_t> dis; std::uniform_int_distribution<uint16_t> dis;
std::vector<uint16_t> rand_samples(sample_size); std::vector<uint16_t> gold_rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(gold_rand_samples.begin(), gold_rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold_rand_samples));
} }
TEST_CASE(random_uniform_dyn_test) TEST_CASE(random_uniform_dyn_test)
...@@ -141,9 +143,9 @@ TEST_CASE(random_uniform_dyn_test) ...@@ -141,9 +143,9 @@ TEST_CASE(random_uniform_dyn_test)
// Compare result with the STL's mt19937 generator // Compare result with the STL's mt19937 generator
std::mt19937 gen(seed); std::mt19937 gen(seed);
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
std::vector<float> rand_samples(sample_size); std::vector<float> gold_rand_samples(sample_size);
std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); std::generate(gold_rand_samples.begin(), gold_rand_samples.end(), [&]() { return dis(gen); });
EXPECT(migraphx::verify::verify_range(result_vec, rand_samples)); EXPECT(migraphx::verify::verify_rms_range(result_vec, gold_rand_samples));
} }
TEST_CASE(random_uniform_and_seed_test) TEST_CASE(random_uniform_and_seed_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment