Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
...@@ -23,195 +23,220 @@ ...@@ -23,195 +23,220 @@
namespace ck { namespace ck {
namespace utils { namespace utils {
template <typename Range, typename RefRange> struct CorrectnessValidator {
typename std::enable_if< public:
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && CorrectnessValidator(bool pass_if_no_instance = false) : pass_if_no_instance_{pass_if_no_instance},
std::is_floating_point_v<ranges::range_value_t<Range>> && found_supporting_instance_{false},
!std::is_same_v<ranges::range_value_t<Range>, half_t>, correct_results_{true} {
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-6)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
} }
bool res{true}; bool is_success() {
int err_count = 0; return (pass_if_no_instance_ || found_supporting_instance_) && correct_results_;
double err = 0; }
double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < ref.size(); ++i) template <typename Range, typename RefRange>
typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_floating_point_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-6)
{ {
const double o = *std::next(std::begin(out), i); found_supporting_instance_ = true;
const double r = *std::next(std::begin(ref), i); if(out.size() != ref.size())
err = std::abs(o - r); {
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true};
int err_count = 0;
double err = 0;
double max_err = std::numeric_limits<double>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{ {
max_err = err > max_err ? err : max_err; const double o = *std::next(std::begin(out), i);
err_count++; const double r = *std::next(std::begin(ref), i);
if(err_count < 5) err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i max_err = err > max_err ? err : max_err;
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
} }
res = false;
} }
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
} }
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
typename std::enable_if< typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, bhalf_t>, std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
bool>::type bool>::type
check_err(const Range& out, check_err(const Range& out,
const RefRange& ref, const RefRange& ref,
const std::string& msg = "Error: Incorrect results!", const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3, double rtol = 1e-3,
double atol = 1e-3) double atol = 1e-3)
{
if(out.size() != ref.size())
{ {
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() found_supporting_instance_ = true;
<< std::endl; if(out.size() != ref.size())
return false; {
} std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
// TODO: This is a hack. We should have proper specialization for bhalf_t data type. // TODO: This is a hack. We should have proper specialization for bhalf_t data type.
double max_err = std::numeric_limits<float>::min(); double max_err = std::numeric_limits<float>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{ {
max_err = err > max_err ? err : max_err; const double o = type_convert<float>(*std::next(std::begin(out), i));
err_count++; const double r = type_convert<float>(*std::next(std::begin(ref), i));
if(err_count < 5) err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i max_err = err > max_err ? err : max_err;
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
} }
res = false;
} }
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
} }
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
typename std::enable_if< typename std::enable_if<
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, half_t>, std::is_same_v<ranges::range_value_t<Range>, half_t>,
bool>::type bool>::type
check_err(const Range& out, check_err(const Range& out,
const RefRange& ref, const RefRange& ref,
const std::string& msg = "Error: Incorrect results!", const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3, double rtol = 1e-3,
double atol = 1e-3) double atol = 1e-3)
{
if(out.size() != ref.size())
{ {
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() found_supporting_instance_ = true;
<< std::endl; if(out.size() != ref.size())
return false; {
} std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min(); double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{
const double o = type_convert<float>(*std::next(std::begin(out), i));
const double r = type_convert<float>(*std::next(std::begin(ref), i));
err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{ {
max_err = err > max_err ? err : max_err; const double o = type_convert<float>(*std::next(std::begin(out), i));
err_count++; const double r = type_convert<float>(*std::next(std::begin(ref), i));
if(err_count < 5) err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i max_err = err > max_err ? err : max_err;
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; err_count++;
if(err_count < 5)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
} }
res = false;
} }
} if(!res)
if(!res) {
{ std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; }
} correct_results_ = correct_results_ && res;
return res; return res;
}
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
,
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
{
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
} }
bool res{true}; template <typename Range, typename RefRange>
int err_count = 0; std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
int64_t err = 0; std::is_integral_v<ranges::range_value_t<Range>> &&
int64_t max_err = std::numeric_limits<int64_t>::min(); !std::is_same_v<ranges::range_value_t<Range>, bhalf_t>)
for(std::size_t i = 0; i < ref.size(); ++i) #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
,
bool>
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double atol = 0)
{ {
const int64_t o = *std::next(std::begin(out), i); found_supporting_instance_ = true;
const int64_t r = *std::next(std::begin(ref), i); if(out.size() != ref.size())
err = std::abs(o - r); {
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
if(err > atol) bool res{true};
int err_count = 0;
int64_t err = 0;
int64_t max_err = std::numeric_limits<int64_t>::min();
for(std::size_t i = 0; i < ref.size(); ++i)
{ {
max_err = err > max_err ? err : max_err; const int64_t o = *std::next(std::begin(out), i);
err_count++; const int64_t r = *std::next(std::begin(ref), i);
if(err_count < 5) err = std::abs(o - r);
if(err > atol)
{ {
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r max_err = err > max_err ? err : max_err;
<< std::endl; err_count++;
if(err_count < 5)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
}
res = false;
} }
res = false;
} }
if(!res)
{
std::cerr << "max err: " << max_err << std::endl;
}
correct_results_ = correct_results_ && res;
return res;
} }
if(!res) private:
{ bool pass_if_no_instance_;
std::cerr << "max err: " << max_err << std::endl; bool found_supporting_instance_;
} bool correct_results_;
return res;
} }
} // namespace utils } // namespace utils
......
...@@ -76,7 +76,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification, ...@@ -76,7 +76,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
using RefAcc0DataType = float; using RefAcc0DataType = float;
using RefAcc1DataType = float; using RefAcc1DataType = float;
bool pass = true; ck::utils::CorrectnessValidator validator;
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M; const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K; const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
...@@ -331,7 +331,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification, ...@@ -331,7 +331,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
{ {
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data()); e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result); validator.check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
if(do_log) if(do_log)
{ {
...@@ -353,7 +353,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification, ...@@ -353,7 +353,7 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -83,7 +83,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, ...@@ -83,7 +83,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
B1ElementOp, B1ElementOp,
CElementOp>; CElementOp>;
bool pass = true; ck::utils::CorrectnessValidator validator;
// A layout [G0, M, G1, K] // A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
...@@ -355,7 +355,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, ...@@ -355,7 +355,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
atol = 1e-2; atol = 1e-2;
} }
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result, validator.check_err(c_gs_ms_os_device_result,
c_gs_ms_os_host_result, c_gs_ms_os_host_result,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol, rtol,
...@@ -388,7 +388,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, ...@@ -388,7 +388,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -78,7 +78,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, ...@@ -78,7 +78,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
B1ElementOp, B1ElementOp,
CElementOp>; CElementOp>;
bool pass = true; ck::utils::CorrectnessValidator validator;
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K; const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
...@@ -284,7 +284,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, ...@@ -284,7 +284,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
{ {
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); validator.check_err(c_g_m_o_device_result, c_g_m_o_host_result);
if(do_log) if(do_log)
{ {
...@@ -312,7 +312,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, ...@@ -312,7 +312,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -49,7 +49,7 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -49,7 +49,7 @@ bool profile_batched_gemm_impl(int do_verification,
int StrideC, int StrideC,
int BatchCount) int BatchCount)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
...@@ -234,7 +234,7 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -234,7 +234,7 @@ bool profile_batched_gemm_impl(int do_verification,
{ {
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); validator.check_err(c_g_m_n_device_result, c_g_m_n_host_result);
if(do_log) if(do_log)
{ {
...@@ -257,7 +257,7 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -257,7 +257,7 @@ bool profile_batched_gemm_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -72,7 +72,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -72,7 +72,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
int StrideC, int StrideC,
int BatchCount) int BatchCount)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
...@@ -316,13 +316,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -316,13 +316,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); validator.check_err(c_g_m_n_device_result, c_g_m_n_host_result);
bool d0_error = ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result); validator.check_err(d0_g_m_device_result, d0_g_m_host_result);
bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result); validator.check_err(d1_g_m_device_result, d1_g_m_host_result);
pass = pass && (c_error == true);
pass = pass && (d0_error == true);
pass = pass && (d1_error == true);
if(do_log) if(do_log)
{ {
...@@ -355,7 +351,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -355,7 +351,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -86,7 +86,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -86,7 +86,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
B1ElementOp, B1ElementOp,
CElementOp>; CElementOp>;
bool pass = true; ck::utils::CorrectnessValidator validator;
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K; const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
...@@ -312,7 +312,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -312,7 +312,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
{ {
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); validator.check_err(c_g_m_o_device_result, c_g_m_o_host_result);
if(do_log) if(do_log)
{ {
...@@ -340,7 +340,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -340,7 +340,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -81,7 +81,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -81,7 +81,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
B1ElementOp, B1ElementOp,
CElementOp>; CElementOp>;
bool pass = true; ck::utils::CorrectnessValidator validator;
// A layout [G0, M, G1, K] // A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
...@@ -327,7 +327,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -327,7 +327,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
atol = 1e-2; atol = 1e-2;
} }
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result, validator.check_err(c_gs_ms_os_device_result,
c_gs_ms_os_host_result, c_gs_ms_os_host_result,
"Error: Incorrect results!", "Error: Incorrect results!",
rtol, rtol,
...@@ -360,7 +360,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -360,7 +360,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -265,7 +265,7 @@ bool profile_batchnorm_backward_impl(bool do_verification, ...@@ -265,7 +265,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
} }
int num_kernel = 0; int num_kernel = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -340,20 +340,17 @@ bool profile_batchnorm_backward_impl(bool do_verification, ...@@ -340,20 +340,17 @@ bool profile_batchnorm_backward_impl(bool do_verification,
if(do_verification) if(do_verification)
{ {
using ck::utils::check_err; using ck::utils;
bool single_pass = true;
dx_dev.FromDevice(dx.mData.data()); dx_dev.FromDevice(dx.mData.data());
dscale_dev.FromDevice(dscale.data()); dscale_dev.FromDevice(dscale.data());
dbias_dev.FromDevice(dbias.data()); dbias_dev.FromDevice(dbias.data());
// clang-format off // clang-format off
single_pass = single_pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4); validator.check_err(dx.mData, dx_ref.mData, "dx result:", 5e-4, 5e-4);
single_pass = single_pass && ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3); validator.check_err(dscale.mData, dscale_ref.mData, "dScale result:", 3e-3, 3e-3);
single_pass = single_pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3); validator.check_err(dbias.mData, dbias_ref.mData, "dBias result:", 3e-3, 3e-3);
// clang-format on // clang-format on
pass = pass && single_pass;
}; };
if(do_dumpout) if(do_dumpout)
...@@ -383,7 +380,7 @@ bool profile_batchnorm_backward_impl(bool do_verification, ...@@ -383,7 +380,7 @@ bool profile_batchnorm_backward_impl(bool do_verification,
return false; return false;
} }
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -259,7 +259,7 @@ bool profile_batchnorm_forward_impl(int do_verification, ...@@ -259,7 +259,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
} }
int num_kernel = 0; int num_kernel = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -336,15 +336,15 @@ bool profile_batchnorm_forward_impl(int do_verification, ...@@ -336,15 +336,15 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
using ck::utils::check_err; using ck::utils;
bool single_pass; bool single_pass;
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
if constexpr(ck::is_same_v<YDataType, ck::bhalf_t>) if constexpr(ck::is_same_v<YDataType, ck::bhalf_t>)
single_pass = check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2); check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2);
else else
single_pass = check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3); check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3);
if(updateMovingAverage) if(updateMovingAverage)
{ {
...@@ -352,8 +352,8 @@ bool profile_batchnorm_forward_impl(int do_verification, ...@@ -352,8 +352,8 @@ bool profile_batchnorm_forward_impl(int do_verification,
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
// clang-format off // clang-format off
single_pass = single_pass && check_err(resultRunningMean.mData, resultRunningMean_ref.mData, "average mean results", 1.5e-5, 1.5e-5); check_err(resultRunningMean.mData, resultRunningMean_ref.mData, "average mean results", 1.5e-5, 1.5e-5);
single_pass = single_pass && check_err(resultRunningVariance.mData, resultRunningVariance_ref.mData, "average variance results", 1e-5, 1e-5); check_err(resultRunningVariance.mData, resultRunningVariance_ref.mData, "average variance results", 1e-5, 1e-5);
// clang-format on // clang-format on
}; };
...@@ -363,12 +363,11 @@ bool profile_batchnorm_forward_impl(int do_verification, ...@@ -363,12 +363,11 @@ bool profile_batchnorm_forward_impl(int do_verification,
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
// clang-format off // clang-format off
single_pass = single_pass && check_err(resultSaveMean.mData, resultSaveMean_ref.mData, "mean results", 3e-5, 3e-5); check_err(resultSaveMean.mData, resultSaveMean_ref.mData, "mean results", 3e-5, 3e-5);
single_pass = single_pass && check_err(resultSaveInvVariance.mData, resultSaveInvVariance_ref.mData, "inv-variance results", 7e-5, 7e-5); check_err(resultSaveInvVariance.mData, resultSaveInvVariance_ref.mData, "inv-variance results", 7e-5, 7e-5);
// clang-format on // clang-format on
}; };
pass = pass && single_pass;
}; };
if(do_dumpout) if(do_dumpout)
...@@ -405,7 +404,7 @@ bool profile_batchnorm_forward_impl(int do_verification, ...@@ -405,7 +404,7 @@ bool profile_batchnorm_forward_impl(int do_verification,
return false; return false;
} }
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -231,7 +231,7 @@ bool profile_batchnorm_infer_impl(int do_verification, ...@@ -231,7 +231,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
} }
int num_kernel = 0; int num_kernel = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -291,17 +291,15 @@ bool profile_batchnorm_infer_impl(int do_verification, ...@@ -291,17 +291,15 @@ bool profile_batchnorm_infer_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
using ck::utils::check_err; using ck::utils;
bool single_pass;
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
if constexpr(ck::is_same_v<YDataType, ck::bhalf_t>) if constexpr(ck::is_same_v<YDataType, ck::bhalf_t>)
single_pass = check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2); check_err(y.mData, y_ref.mData, "y results", 1e-2, 1e-2);
else else
single_pass = check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3); check_err(y.mData, y_ref.mData, "y results", 4e-3, 4e-3);
pass = pass && single_pass;
}; };
if(do_dumpout) if(do_dumpout)
...@@ -328,7 +326,7 @@ bool profile_batchnorm_infer_impl(int do_verification, ...@@ -328,7 +326,7 @@ bool profile_batchnorm_infer_impl(int do_verification,
return false; return false;
} }
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -50,7 +50,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -50,7 +50,7 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& StridesE, const std::vector<ck::index_t>& StridesE,
const std::vector<ck::index_t>& StridesD) const std::vector<ck::index_t>& StridesD)
{ {
bool pass = true; ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor = [](const std::vector<ck::index_t>& dims01, auto f_host_tensor_descriptor = [](const std::vector<ck::index_t>& dims01,
const std::vector<ck::index_t>& dims23, const std::vector<ck::index_t>& dims23,
...@@ -274,7 +274,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -274,7 +274,7 @@ int profile_contraction_impl(ck::index_t do_verification,
float threshold = float threshold =
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon(); static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
pass = pass & ck::utils::check_err(e_m_n_device_result, validator.check_err(e_m_n_device_result,
e_m_n_host_result, e_m_n_host_result,
"Error: incorrect results!", "Error: incorrect results!",
threshold, threshold,
...@@ -338,7 +338,7 @@ int profile_contraction_impl(ck::index_t do_verification, ...@@ -338,7 +338,7 @@ int profile_contraction_impl(ck::index_t do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl; << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -153,7 +153,7 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -153,7 +153,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device Conv instances // profile device Conv instances
bool pass = true; ck::utils::CorrectnessValidator validator;
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -209,7 +209,7 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -209,7 +209,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
{ {
in_device_buf.FromDevice(input_device_result.mData.data()); in_device_buf.FromDevice(input_device_result.mData.data());
pass = pass & ck::utils::check_err(input_device_result, input_host_result); validator.check_err(input_device_result, input_host_result);
if(do_log) if(do_log)
{ {
...@@ -241,7 +241,7 @@ bool profile_conv_bwd_data_impl(int do_verification, ...@@ -241,7 +241,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -193,6 +193,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ...@@ -193,6 +193,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device Conv instances // profile device Conv instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -251,7 +253,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ...@@ -251,7 +253,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
{ {
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
ck::utils::check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result); validator.check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result);
validator.is_success();
if(do_log) if(do_log)
{ {
......
...@@ -183,6 +183,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -183,6 +183,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device Conv instances // profile device Conv instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -239,7 +241,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -239,7 +241,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
{ {
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
ck::utils::check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result); validator.check_err(out_n_k_ho_wo_device_result, out_n_k_ho_wo_host_result);
validator.is_success();
if(do_log) if(do_log)
{ {
......
...@@ -135,7 +135,7 @@ bool profile_conv_fwd_impl(int do_verification, ...@@ -135,7 +135,7 @@ bool profile_conv_fwd_impl(int do_verification,
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
// profile device op instances // profile device op instances
bool pass = true; ck::utils::CorrectnessValidator validator;
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
...@@ -191,7 +191,7 @@ bool profile_conv_fwd_impl(int do_verification, ...@@ -191,7 +191,7 @@ bool profile_conv_fwd_impl(int do_verification,
{ {
out_device_buf.FromDevice(device_output.mData.data()); out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output, host_output); validator.check_err(device_output, host_output);
if(do_log) if(do_log)
{ {
...@@ -214,7 +214,7 @@ bool profile_conv_fwd_impl(int do_verification, ...@@ -214,7 +214,7 @@ bool profile_conv_fwd_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -164,6 +164,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -164,6 +164,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
} }
int num_kernel = 0; int num_kernel = 0;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs) for(auto& inst_ptr : instance_ptrs)
{ {
...@@ -221,8 +222,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -221,8 +222,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
{ {
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = validator.check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log) if(do_log)
{ {
...@@ -232,7 +232,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -232,7 +232,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
LogRangeAsType<float>(std::cout << "y : ", y.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "y : ", y.mData, ",") << std::endl;
} }
if(!pass) if(!validator.is_success())
{ {
std::cout << inst_ptr->GetTypeString() << " failed verification: "; std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
...@@ -253,13 +253,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -253,13 +253,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
<< best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
} }
if(num_kernel == 0) return validator.is_success();
{
std::cout << "Error: No kernel is tested" << std::endl;
return false;
}
return true;
} }
} // namespace profiler } // namespace profiler
......
...@@ -165,7 +165,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -165,7 +165,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
// profile device operation instances // profile device operation instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
...@@ -223,7 +223,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -223,7 +223,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); validator.check_err(e_m_n_device_result, e_m_n_host_result);
} }
} }
else else
...@@ -235,7 +235,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -235,7 +235,7 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -156,7 +156,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification, ...@@ -156,7 +156,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
// profile device operation instances // profile device operation instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
...@@ -213,7 +213,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification, ...@@ -213,7 +213,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); validator.check_err(e_m_n_device_result, e_m_n_host_result);
} }
} }
else else
...@@ -225,7 +225,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification, ...@@ -225,7 +225,7 @@ bool profile_gemm_add_fastgelu_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
...@@ -165,7 +165,7 @@ bool profile_gemm_add_multiply_impl(int do_verification, ...@@ -165,7 +165,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; ck::utils::CorrectnessValidator validator;
// profile device operation instances // profile device operation instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
...@@ -222,8 +222,7 @@ bool profile_gemm_add_multiply_impl(int do_verification, ...@@ -222,8 +222,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
validator.check_err(e_m_n_device_result, e_m_n_host_result);
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
} }
} }
else else
...@@ -235,7 +234,7 @@ bool profile_gemm_add_multiply_impl(int do_verification, ...@@ -235,7 +234,7 @@ bool profile_gemm_add_multiply_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass; return validator.is_success();
} }
} // namespace profiler } // namespace profiler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment