Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
......@@ -23,195 +23,220 @@
namespace ck {
namespace utils {
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_floating_point_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-6)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
struct CorrectnessValidator {
public:
CorrectnessValidator(bool pass_if_no_instance = false) : pass_if_no_instance_{pass_if_no_instance},
found_supporting_instance_{false},
correct_results_{true} {
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
bool is_success() {
return (pass_if_no_instance_ || found_supporting_instance_) && correct_results_;
}
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_floating_point_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-6)
{
const double o = *std::next(std::begin(out), i);
const double r = *std::next(std::begin(ref), i);
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
found_supporting_instance_ = true;
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
const double o = *std::next(std::begin(out), i);
const double r = *std::next(std::begin(ref), i);
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
res = false;
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
if(out.size() != ref.size())
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
found_supporting_instance_ = true;
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
// TODO: This is a hack. We should have proper specialization for bhalf_t data type.
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
bool res{true};
int err_count = 0;
double err = 0;
// TODO: This is a hack. We should have proper specialization for bhalf_t data type.
double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
res = false;
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
if(out.size() != ref.size())
template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
found_supporting_instance_ = true;
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
}
res = false;
}
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
,
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
}
bool res{true};
int err_count = 0;
int64_t err = 0;
int64_t max_err = std::numeric_limits<int64_t>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
,
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
{
const int64_t o = *std::next(std::begin(out), i);
const int64_t r = *std::next(std::begin(ref), i);
err = std::abs(o - r);
found_supporting_instance_ = true;
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
if(err > atol)
bool res{true};
int err_count = 0;
int64_t err = 0;
int64_t max_err = std::numeric_limits<int64_t>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
const int64_t o = *std::next(std::begin(out), i);
const int64_t r = *std::next(std::begin(ref), i);
err = std::abs(o - r);
if(err > atol)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
}
res = false;
}
res = false;
}
if(!res)
{
std::cerr << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
}
if(!res)
{
std::cerr << "max err: " << max_err << std::endl;
}
return res;
private:
bool pass_if_no_instance_;
bool found_supporting_instance_;
bool correct_results_;
}
} // namespace utils
......
......@@ -76,7 +76,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
using RefAcc0DataType = float;
using RefAcc1DataType = float;
bool pass = true;
ck::utils::CorrectnessValidator validator;
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
......@@ -331,7 +331,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
{
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
validator.check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
if(do_log)
{
......@@ -353,7 +353,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -83,7 +83,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
B1ElementOp,
CElementOp>;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
......@@ -355,7 +355,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
atol = 1e-2;
}
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result,
validator.check_err(c_gs_ms_os_device_result,
c_gs_ms_os_host_result,
"Error: Incorrect results!",
rtol,
......@@ -388,7 +388,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -78,7 +78,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
B1ElementOp,
CElementOp>;
bool pass = true;
ck::utils::CorrectnessValidator validator;
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
......@@ -284,7 +284,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
{
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);
validator.check_err(c_g_m_o_device_result, c_g_m_o_host_result);
if(do_log)
{
......@@ -312,7 +312,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -49,7 +49,7 @@ bool profile_batched_gemm_impl(int do_verification,
int StrideC,
int BatchCount)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
......@@ -234,7 +234,7 @@ bool profile_batched_gemm_impl(int do_verification,
{
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
validator.check_err(c_g_m_n_device_result, c_g_m_n_host_result);
if(do_log)
{
......@@ -257,7 +257,7 @@ bool profile_batched_gemm_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -72,7 +72,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
int StrideC,
int BatchCount)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
......@@ -316,13 +316,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result);
bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result);
bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result);
pass = pass && (c_error == true);
pass = pass && (d0_error == true);
pass = pass && (d1_error == true);
validator.check_err(c_g_m_n_device_result, c_g_m_n_host_result);
validator.check_err(d0_g_m_device_result, d0_g_m_host_result);
validator.check_err(d1_g_m_device_result, d1_g_m_host_result);
if(do_log)
{
......@@ -355,7 +351,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -86,7 +86,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
B1ElementOp,
CElementOp>;
bool pass = true;
ck::utils::CorrectnessValidator validator;
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
......@@ -312,7 +312,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
{
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result);
validator.check_err(c_g_m_o_device_result, c_g_m_o_host_result);
if(do_log)
{
......@@ -340,7 +340,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -81,7 +81,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
B1ElementOp,
CElementOp>;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
......@@ -327,7 +327,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
atol = 1e-2;
}
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result,
validator.check_err(c_gs_ms_os_device_result,
c_gs_ms_os_host_result,
"Error: Incorrect results!",
rtol,
......@@ -360,7 +360,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -265,7 +265,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
}
int num_kernel = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -340,20 +340,17 @@ bool profile_batchnorm_backward_impl(bool do_verification,
if(do_verification)
{
using ck::utils::check_err;
bool single_pass = true;
using ck::utils;
dx_dev.FromDevice(dx.mData.data());
dscale_dev.FromDevice(dscale.data());
dbias_dev.FromDevice(dbias.data());
// clang-format off
single_pass = single_pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4);
single_pass = single_pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3);
single_pass = single_pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3);
validator.check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4);
validator.check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3);
validator.check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3);
// clang-format on
pass = pass && single_pass;
};
if(do_dumpout)
......@@ -383,7 +380,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
return false;
}
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -259,7 +259,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
}
int num_kernel = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -336,15 +336,15 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(do_verification)
{
using ck::utils::check_err;
using ck::utils;
bool single_pass;
y_dev.FromDevice(y.mData.data());
if constexpr(ck::is_same_v<YDataType, ck::bhalf_t>)
single_pass = check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2);
check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2);
else
single_pass = check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3);
check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3);
if(updateMovingAverage)
{
......@@ -352,8 +352,8 @@ bool profile_batchnorm_forward_impl(int do_verification,
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
// clang-format off
single_pass = single_pass && check_err(resultRunningMean.mData, resultRunningMean_ref.mData, "average mean results", 1.5e-5, 1.5e-5);
single_pass = single_pass && check_err(resultRunningVariance.mData, resultRunningVariance_ref.mData, "average variance results", 1e-5, 1e-5);
check_err(resultRunningMean.mData, resultRunningMean_ref.mData, "average mean results", 1.5e-5, 1.5e-5);
check_err(resultRunningVariance.mData, resultRunningVariance_ref.mData, "average variance results", 1e-5, 1e-5);
// clang-format on
};
......@@ -363,12 +363,11 @@ bool profile_batchnorm_forward_impl(int do_verification,
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
// clang-format off
single_pass = single_pass && check_err(resultSaveMean.mData, resultSaveMean_ref.mData, "mean results", 3e-5, 3e-5);
single_pass = single_pass && check_err(resultSaveInvVariance.mData, resultSaveInvVariance_ref.mData, "inv-variance results", 7e-5, 7e-5);
check_err(resultSaveMean.mData, resultSaveMean_ref.mData, "mean results", 3e-5, 3e-5);
check_err(resultSaveInvVariance.mData, resultSaveInvVariance_ref.mData, "inv-variance results", 7e-5, 7e-5);
// clang-format on
};
pass = pass && single_pass;
};
if(do_dumpout)
......@@ -405,7 +404,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
return false;
}
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -231,7 +231,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
}
int num_kernel = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -291,17 +291,15 @@ bool profile_batchnorm_infer_impl(int do_verification,
if(do_verification)
{
using ck::utils::check_err;
bool single_pass;
using ck::utils;
y_dev.FromDevice(y.mData.data());
if constexpr(ck::is_same_v<YDataType, ck::bhalf_t>)
single_pass = check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2);
check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2);
else
single_pass = check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3);
check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3);
pass = pass && single_pass;
};
if(do_dumpout)
......@@ -328,7 +326,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
return false;
}
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -50,7 +50,7 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& StridesE,
const std::vector<ck::index_t>& StridesD)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23,
......@@ -274,7 +274,7 @@ int profile_contraction_impl(ck::index_t do_verification,
float threshold =
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
pass = pass & ck::utils::check_err(e_m_n_device_result,
validator.check_err(e_m_n_device_result,
e_m_n_host_result,
"Error: incorrect results!",
threshold,
......@@ -338,7 +338,7 @@ int profile_contraction_impl(ck::index_t do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -153,7 +153,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
float best_gb_per_sec = 0;
// profile device Conv instances
bool pass = true;
ck::utils::CorrectnessValidator validator;
for(auto& op_ptr : op_ptrs)
{
......@@ -209,7 +209,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
{
in_device_buf.FromDevice(input_device_result.mData.data());
pass = pass & ck::utils::check_err(input_device_result, input_host_result);
validator.check_err(input_device_result, input_host_result);
if(do_log)
{
......@@ -241,7 +241,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -193,6 +193,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device Conv instances
for(auto& op_ptr : op_ptrs)
{
......@@ -251,7 +253,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
ck::utils::check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result);
validator.check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result);
validator.is_success();
if(do_log)
{
......
......@@ -183,6 +183,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device Conv instances
for(auto& op_ptr : op_ptrs)
{
......@@ -239,7 +241,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
ck::utils::check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result);
validator.check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result);
validator.is_success();
if(do_log)
{
......
......@@ -135,7 +135,7 @@ bool profile_conv_fwd_impl(int do_verification,
float best_gb_per_sec = 0;
// profile device op instances
bool pass = true;
ck::utils::CorrectnessValidator validator;
for(auto& op_ptr : op_ptrs)
{
......@@ -191,7 +191,7 @@ bool profile_conv_fwd_impl(int do_verification,
{
out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output, host_output);
validator.check_err(device_output, host_output);
if(do_log)
{
......@@ -214,7 +214,7 @@ bool profile_conv_fwd_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -164,6 +164,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
}
int num_kernel = 0;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -221,8 +222,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
{
y_dev.FromDevice(y.mData.data());
bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
validator.check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log)
{
......@@ -232,7 +232,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
LogRangeAsType<float>(std::cout << "y : ", y.mData, ",") << std::endl;
}
if(!pass)
if(!validator.is_success())
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
......@@ -253,13 +253,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
<< best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is tested" << std::endl;
return false;
}
return true;
return validator.is_success();
}
} // namespace profiler
......
......@@ -165,7 +165,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
......@@ -223,7 +223,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
validator.check_err(e_m_n_device_result, e_m_n_host_result);
}
}
else
......@@ -235,7 +235,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -156,7 +156,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
......@@ -213,7 +213,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
validator.check_err(e_m_n_device_result, e_m_n_host_result);
}
}
else
......@@ -225,7 +225,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -165,7 +165,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
......@@ -222,8 +222,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
if(do_verification)
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
validator.check_err(e_m_n_device_result, e_m_n_host_result);
}
}
else
......@@ -235,7 +234,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment