Commit c88bf676 authored by Umang Yadav's avatar Umang Yadav
Browse files

add comments and use expected as the argument for the verify_args

parent 20919ff7
...@@ -225,6 +225,10 @@ struct tolerance ...@@ -225,6 +225,10 @@ struct tolerance
double rtol = 0.001; double rtol = 0.001;
}; };
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2> template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, tolerance tols) bool allclose(const R1& r1, const R2& r2, tolerance tols)
{ {
......
...@@ -33,12 +33,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,12 +33,12 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT bool verify_args_with_threshold(const std::string& name, MIGRAPHX_EXPORT bool verify_args_with_threshold(const std::string& name,
const argument& target_arg, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
verify::tolerance); verify::tolerance);
MIGRAPHX_EXPORT bool verify_args(const std::string& name, MIGRAPHX_EXPORT bool verify_args(const std::string& name,
const argument& target_arg, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80); std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -29,11 +29,11 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -29,11 +29,11 @@ inline namespace MIGRAPHX_INLINE_NS {
bool verify_args_with_threshold(const std::string& name, bool verify_args_with_threshold(const std::string& name,
const argument& target_arg, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
verify::tolerance tols) verify::tolerance tols)
{ {
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
double rms_error; double rms_error;
passed = passed =
verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error); verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
...@@ -96,7 +96,7 @@ bool verify_args_with_threshold(const std::string& name, ...@@ -96,7 +96,7 @@ bool verify_args_with_threshold(const std::string& name,
bool verify_args(const std::string& name, bool verify_args(const std::string& name,
const argument& target_arg, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
std::size_t tolerance) std::size_t tolerance)
{ {
double rms_tol = 0.001; double rms_tol = 0.001;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment