"...resnet50_tensorflow.git" did not exist on "f2eb1701b83b28741a97e4598e0e5faca43495f0"
Unverified commit dcafb1de, authored by aledudek, committed by GitHub
Browse files

Generic threshold calculation after merge fixes (#1618)



* Generic threshold calculation add passing num of accums

* Generic threshold - after merge fixes

* Fix cmakelists

---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
parent 365f39ae
...@@ -24,7 +24,7 @@ namespace ck { ...@@ -24,7 +24,7 @@ namespace ck {
namespace utils { namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int numberOfAccumulations = 1) double get_relative_threshold(const int number_of_accumulations = 1)
{ {
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1) ...@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
} }
else else
{ {
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations; acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
} }
return std::max(acc_error, midway_error); return std::max(acc_error, midway_error);
} }
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType> template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1) double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{ {
using F8 = ck::f8_t; using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA ...@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
else else
{ {
acc_error = acc_error =
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations; std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
} }
return std::max(acc_error, midway_error); return std::max(acc_error, midway_error);
} }
......
...@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& ...@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
{ {
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
auto number_of_accumulations = 1;
static_assert(
ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX,
"Warning: Unhandled ReduceOpId for setting up the number of accumulations!");
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i)
{
number_of_accumulations *= kernel_params.window_spatial_lengths.at(i);
}
}
auto absolute_error_threshold = 1.0; auto absolute_error_threshold = 1.0;
switch(in_params.init_method) switch(in_params.init_method)
{ {
...@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& ...@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
absolute_error_threshold = absolute_error_threshold =
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>( ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
absolute_error_threshold); absolute_error_threshold, number_of_accumulations);
auto relative_error_threshold = auto relative_error_threshold =
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(); ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(
number_of_accumulations);
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
out_n_c_do_ho_wo_host.mData, out_n_c_do_ho_wo_host.mData,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment