Unverified commit dcafb1de, authored by aledudek, committed by GitHub
Browse files

Generic threshold calculation after merge fixes (#1618)



* Generic threshold calculation: add passing of the number of accumulations

* Generic threshold - after merge fixes

* Fix CMakeLists

---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
parent 365f39ae
......@@ -24,7 +24,7 @@ namespace ck {
namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int numberOfAccumulations = 1)
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck::f8_t;
using F16 = ck::half_t;
......@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
}
else
{
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1)
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck::f8_t;
using F16 = ck::half_t;
......@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
else
{
acc_error =
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
......
......@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
{
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
auto number_of_accumulations = 1;
static_assert(
ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX,
"Warning: Unhandled ReduceOpId for setting up the number of accumulations!");
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i)
{
number_of_accumulations *= kernel_params.window_spatial_lengths.at(i);
}
}
auto absolute_error_threshold = 1.0;
switch(in_params.init_method)
{
......@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
absolute_error_threshold =
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
absolute_error_threshold);
absolute_error_threshold, number_of_accumulations);
auto relative_error_threshold =
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>();
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(
number_of_accumulations);
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
out_n_c_do_ho_wo_host.mData,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment