#ifndef MIGRAPH_GUARD_VERIFY_HPP #define MIGRAPH_GUARD_VERIFY_HPP #include #include #include #include #include #include namespace migraph { // Compute the value of a range template using range_value = std::decay_t().begin())>; struct sum_fn { template auto operator()(T x, U y) const { return x + y; } }; static constexpr sum_fn sum{}; struct max_fn { template static T id(T x) { return x; } template auto operator()(T x, U y) const { return x > y ? x : y; } }; static constexpr max_fn max{}; namespace abs_diff_detail { using std::fabs; struct fn { template auto operator()(T x, U y) const { return fabs(x - y); } }; } // namespace abs_diff_detail static constexpr abs_diff_detail::fn abs_diff{}; struct not_finite_fn { template bool operator()(T x) const { using std::isfinite; return not isfinite(x); } }; static constexpr not_finite_fn not_finite{}; struct compare_mag_fn { template bool operator()(T x, U y) const { using std::fabs; return fabs(x) < fabs(y); } }; static constexpr compare_mag_fn compare_mag{}; struct square_diff_fn { template double operator()(T x, U y) const { return (x - y) * (x - y); } }; static constexpr square_diff_fn square_diff{}; template bool range_empty(R1&& r1) { return r1.begin() == r1.end(); } template auto range_distance(R1&& r1) { return std::distance(r1.begin(), r1.end()); } template bool range_zero(R1&& r1) { return std::all_of(r1.begin(), r1.end(), [](auto x) { return float_equal(x, 0); }); } template T range_product(R1&& r1, R2&& r2, T state, Reducer r, Product p) { return std::inner_product(r1.begin(), r1.end(), r2.begin(), state, r, p); } template std::size_t mismatch_idx(R1&& r1, R2&& r2, Compare compare) { auto p = std::mismatch(r1.begin(), r1.end(), r2.begin(), compare); return std::distance(r1.begin(), p.first); } template long find_idx(R1&& r1, Predicate p) { auto it = std::find_if(r1.begin(), r1.end(), p); if(it == r1.end()) return -1; else return std::distance(r1.begin(), it); } template double max_diff(R1&& r1, R2&& r2) { return range_product(r1, r2, 0.0, max, abs_diff); } template std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff) { return mismatch_idx(r1, r2, [&](auto x, auto y) { auto d = abs_diff(x, y); return float_equal(d, diff); }); } template double rms_range(R1&& r1, R2&& r2) { std::size_t n = range_distance(r1); if(n == range_distance(r2)) { double square_difference = range_product(r1, r2, 0.0, sum_fn{}, square_diff); double mag1 = *std::max_element(r1.begin(), r1.end(), compare_mag); double mag2 = *std::max_element(r2.begin(), r2.end(), compare_mag); double mag = std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits::min()}); return std::sqrt(square_difference) / (std::sqrt(n) * mag); } else return std::numeric_limits>::max(); } template bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = nullptr) { double threshold = std::numeric_limits>::epsilon() * tolerance; auto error = rms_range(r1, r2); if(out_error != nullptr) *out_error = error; return error <= threshold; } } // namespace migraph #endif