"docs/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "70d9faf749648435a52560dbf3770c628fd61fb8"
Commit 2b1f1aad authored by Paul's avatar Paul
Browse files

Add verify method

parent 0005506c
#ifndef RTG_GUARD_VERIFY_HPP
#define RTG_GUARD_VERIFY_HPP
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <numeric>
namespace test {
// Compute the value of a range
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
struct sum_fn
{
template <class T, class U>
auto operator()(T x, U y) const { return x + y; }
};
static constexpr sum_fn sum{};
struct max_fn
{
template <class T>
static T id(T x)
{
return x;
}
template <class T, class U>
auto operator()(T x, U y) const { return x > y ? x : y; }
};
static constexpr max_fn max{};
namespace abs_diff_detail {
using std::fabs;
struct fn
{
template <class T, class U>
auto operator()(T x, U y) const { return fabs(x - y); }
};
} // namespace abs_diff_detail
static constexpr abs_diff_detail::fn abs_diff{};
struct not_finite_fn
{
template <class T>
bool operator()(T x) const
{
using std::isfinite;
return not isfinite(x);
}
};
static constexpr not_finite_fn not_finite{};
template <class T, class U>
T as(T, U x)
{
return x;
}
struct compare_mag_fn
{
template <class T, class U>
bool operator()(T x, U y) const
{
using std::fabs;
return fabs(x) < fabs(y);
}
};
static constexpr compare_mag_fn compare_mag{};
struct square_diff_fn
{
template <class T, class U>
double operator()(T x, U y) const
{
return (x - y) * (x - y);
}
};
static constexpr square_diff_fn square_diff{};
template <class R1>
bool range_empty(R1&& r1)
{
return r1.begin() == r1.end();
}
template <class R1>
auto range_distance(R1&& r1)
{ return std::distance(r1.begin(), r1.end()); }
template <class R1>
bool range_zero(R1&& r1)
{
return std::all_of(r1.begin(), r1.end(), [](auto x) { return x == 0; });
}
template <class R1, class R2, class T, class Reducer, class Product>
T range_product(R1&& r1, R2&& r2, T state, Reducer r, Product p)
{
return std::inner_product(r1.begin(), r1.end(), r2.begin(), state, r, p);
}
template <class R1, class R2, class Compare>
std::size_t mismatch_idx(R1&& r1, R2&& r2, Compare compare)
{
auto p = std::mismatch(r1.begin(), r1.end(), r2.begin(), compare);
return std::distance(r1.begin(), p.first);
}
template <class R1, class Predicate>
long find_idx(R1&& r1, Predicate p)
{
auto it = std::find_if(r1.begin(), r1.end(), p);
if(it == r1.end())
return -1;
else
return std::distance(r1.begin(), it);
}
template <class R1, class R2>
double max_diff(R1&& r1, R2&& r2)
{
return range_product(r1, r2, 0.0, max, abs_diff);
}
template <class R1, class R2, class T>
std::size_t mismatch_diff(R1&& r1, R2&& r2, T diff)
{
return mismatch_idx(
r1,
r2,
[&](auto x, auto y) {
auto d = abs_diff(x, y);
return !(d > diff && d < diff);
});
}
template <class R1, class R2>
double rms_range(R1&& r1, R2&& r2)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
{
double square_difference = range_product(r1, r2, 0.0, sum_fn{}, square_diff);
double mag1 = *std::max_element(r1.begin(), r1.end(), compare_mag);
double mag2 = *std::max_element(r2.begin(), r2.end(), compare_mag);
double mag =
std::max({std::fabs(mag1), std::fabs(mag2), std::numeric_limits<double>::min()});
return std::sqrt(square_difference) / (std::sqrt(n) * mag);
}
else
return std::numeric_limits<range_value<R1>>::max();
}
template<class R1, class R2>
bool verify_range(R1&& r1, R2&& r2, double tolerance=80)
{
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
auto error = rms_range(r1, r2);
return error <= threshold;
}
} // namespace test
#endif
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <random> #include <random>
#include "test.hpp" #include "test.hpp"
#include "verify.hpp"
using hip_ptr = RTG_MANAGE_PTR(void, hipFree); using hip_ptr = RTG_MANAGE_PTR(void, hipFree);
using miopen_handle = RTG_MANAGE_PTR(miopenHandle_t, miopenDestroy); using miopen_handle = RTG_MANAGE_PTR(miopenHandle_t, miopenDestroy);
...@@ -116,9 +117,7 @@ void test1() ...@@ -116,9 +117,7 @@ void test1()
{ {
auto x = cpu(); auto x = cpu();
auto y = gpu(); auto y = gpu();
// TODO: Use expect EXPECT(test::verify_range(x, y));
if(x == y)
std::cout << "FAILED" << std::endl;
} }
int main() { test1(); } int main() { test1(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment