// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include #include #include #include #include #include #include #include #include "ck/ck.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type.hpp" #include "ck/host_utility/io.hpp" #include "ck/library/utility/ranges.hpp" namespace ck { namespace utils { template double get_relative_threshold(const int number_of_accumulations = 1) { using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; using I8 = int8_t; using I32 = int32_t; static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_same_v || is_same_v || is_same_v) { return 0; } else { compute_error = std::pow(2, -NumericUtils::mant) * 0.5; } static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || is_same_v) { return 0; } else { output_error = std::pow(2, -NumericUtils::mant) * 0.5; } double midway_error = std::max(compute_error, output_error); static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || is_same_v) { return 0; } else { acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * number_of_accumulations; } return std::max(acc_error, midway_error); } template double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { using F8 = ck::f8_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; using I8 = int8_t; using I32 = int32_t; static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; if constexpr(is_same_v || is_same_v || is_same_v) { return 0; } else { compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; } static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || is_same_v) { return 0; } else { output_error = std::pow(2, expo - NumericUtils::mant) * 0.5; } double midway_error = std::max(compute_error, output_error); static_assert(is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || is_same_v) { return 0; } else { acc_error = std::pow(2, expo - NumericUtils::mant) * 0.5 * number_of_accumulations; } return std::max(acc_error, midway_error); } template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_floating_point_v> && !std::is_same_v, half_t>, bool>::type check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-5, double atol = 3e-6) { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return false; } bool res{true}; int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = *std::next(std::begin(out), i); const double r = *std::next(std::begin(ref), i); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { const float error_percent = static_cast(err_count) / static_cast(out.size()) * 100.f; std::cerr << "max err: " << max_err; std::cerr << ", number of errors: " << err_count; std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_same_v, bhalf_t>, bool>::type check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-1, double atol = 1e-3) { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return false; } bool res{true}; int err_count = 0; double err = 0; // TODO: This is a hack. We should have proper specialization for bhalf_t data type. double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { const float error_percent = static_cast(err_count) / static_cast(out.size()) * 100.f; std::cerr << "max err: " << max_err; std::cerr << ", number of errors: " << err_count; std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_same_v, half_t>, bool>::type check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-3, double atol = 1e-3) { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return false; } bool res{true}; int err_count = 0; double err = 0; double max_err = NumericLimits>::Min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { const float error_percent = static_cast(err_count) / static_cast(out.size()) * 100.f; std::cerr << "max err: " << max_err; std::cerr << ", number of errors: " << err_count; std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_integral_v> && !std::is_same_v, bhalf_t> && !std::is_same_v, f8_t> && !std::is_same_v, bf8_t>) #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 || std::is_same_v, int4_t> #endif , bool> check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double = 0, double atol = 0) { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return false; } bool res{true}; int err_count = 0; int64_t err = 0; int64_t max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const int64_t o = *std::next(std::begin(out), i); const int64_t r = *std::next(std::begin(ref), i); err = std::abs(o - r); if(err > atol) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { const float error_percent = static_cast(err_count) / static_cast(out.size()) * 100.f; std::cerr << "max err: " << max_err; std::cerr << ", number of errors: " << err_count; std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, f8_t>), bool> check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-3, double atol = 1e-3) { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return false; } bool res{true}; int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << " number of errors: " << err_count << std::endl; } return res; } template std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, bf8_t>), bool> check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-3, double atol = 1e-3) { if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() << std::endl; return false; } bool res{true}; int err_count = 0; double err = 0; double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { const double o = type_convert(*std::next(std::begin(out), i)); const double r = type_convert(*std::next(std::begin(ref), i)); err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; } } if(!res) { std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; } return res; } } // namespace utils } // namespace ck