"src/include/ConstantMatrixDescriptor.hip.hpp" did not exist on "84d9802d30de16795e63a8625098634527c80ae4"
Commit 5a61ffe1 authored by Rostyslav Geyyer's avatar Rostyslav Geyyer
Browse files

Add check_err fp8/bf8 support

parent 135ea647
...@@ -216,13 +216,14 @@ check_err(const Range& out, ...@@ -216,13 +216,14 @@ check_err(const Range& out,
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, f8_t>), (std::is_same_v<ranges::range_value_t<Range>, f8_t> ||
std::is_same_v<ranges::range_value_t<Range>, bf8_t>)),
bool> bool>
check_err(const Range& out, check_err(const Range& out,
const RefRange& ref, const RefRange& ref,
const std::string& msg = "Error: Incorrect results!", const std::string& msg = "Error: Incorrect results!",
double = 0, double rtol = 1e-3,
double atol = 0) double atol = 1e-3)
{ {
if(out.size() != ref.size()) if(out.size() != ref.size())
{ {
...@@ -232,30 +233,30 @@ check_err(const Range& out, ...@@ -232,30 +233,30 @@ check_err(const Range& out,
} }
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
int64_t err = 0; double err = 0;
int64_t max_err = std::numeric_limits<int64_t>::min(); // TODO: This is a hack. We should have proper specialization for bhalf_t data type.
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const int64_t o = *std::next(std::begin(out), i); const double o = type_convert<float>(*std::next(std::begin(out), i));
const int64_t r = *std::next(std::begin(ref), i); const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r); err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
if(err > atol)
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 5)
{ {
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
} }
res = false; res = false;
} }
} }
if(!res) if(!res)
{ {
std::cerr << "max err: " << max_err << std::endl; std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
} }
return res; return res;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment