Commit c88bf676 authored by Umang Yadav's avatar Umang Yadav
Browse files

add comments and use expected as the argument for the verify_args

parent 20919ff7
......@@ -225,6 +225,10 @@ struct tolerance
double rtol = 0.001;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, tolerance tols)
{
......
......@@ -33,12 +33,12 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT bool verify_args_with_threshold(const std::string& name,
const argument& target_arg,
const argument& ref_arg,
const verify::expected<argument>& ref_arg,
verify::tolerance);
MIGRAPHX_EXPORT bool verify_args(const std::string& name,
const argument& target_arg,
const argument& ref_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -29,11 +29,11 @@ inline namespace MIGRAPHX_INLINE_NS {
bool verify_args_with_threshold(const std::string& name,
const argument& target_arg,
const argument& ref_arg,
const verify::expected<argument>& ref_arg,
verify::tolerance tols)
{
bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
double rms_error;
passed =
verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
......@@ -96,7 +96,7 @@ bool verify_args_with_threshold(const std::string& name,
bool verify_args(const std::string& name,
const argument& target_arg,
const argument& ref_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance)
{
double rms_tol = 0.001;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment