Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
...@@ -242,7 +242,7 @@ int main(int argc, char* argv[]) ...@@ -242,7 +242,7 @@ int main(int argc, char* argv[])
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
} }
#endif #endif
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
} }
return 0; return 0;
......
...@@ -236,15 +236,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -236,15 +236,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>(); c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); validator.check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else #else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); validator.check_err(c_m_n_device_result, c_m_n_host_result);
#endif #endif
} }
return true; return validator.is_success();
} }
bool run_gemm_example(int argc, char* argv[]) bool run_gemm_example(int argc, char* argv[])
......
...@@ -297,7 +297,8 @@ int main(int argc, char* argv[]) ...@@ -297,7 +297,8 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; validator.check_err(e_m_n_device_result, e_m_n_host_result);
return validator.is_success();
} }
return 0; return 0;
......
...@@ -299,7 +299,8 @@ int main(int argc, char* argv[]) ...@@ -299,7 +299,8 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; validator.check_err(e_m_n_device_result, e_m_n_host_result);
return validator.is_success();
} }
return 0; return 0;
......
...@@ -276,7 +276,8 @@ int main(int argc, char* argv[]) ...@@ -276,7 +276,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; validator.check_err(e_m_n_device_result, e_m_n_host_result);
return validator.is_success();
} }
return 0; return 0;
......
...@@ -147,13 +147,13 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC ...@@ -147,13 +147,13 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> e_m_n_device_result_converted(e_m_n_device_result); const Tensor<EDataType> e_m_n_device_result_converted(e_m_n_device_result);
return ck::utils::check_err(e_m_n_device_result_converted, e_m_n_host_result); validator.check_err(e_m_n_device_result_converted, e_m_n_host_result);
#else #else
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); validator.check_err(e_m_n_device_result, e_m_n_host_result);
#endif #endif
} }
return true; return validator.is_success();
} }
bool run_gemm_add_add_fastgelu_example(int argc, char* argv[]) bool run_gemm_add_add_fastgelu_example(int argc, char* argv[])
......
...@@ -164,8 +164,9 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -164,8 +164,9 @@ bool run_grouped_conv_fwd(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err( validator.check_err(
out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f);
return validator.is_success();
} }
return true; return true;
......
...@@ -188,8 +188,9 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -188,8 +188,9 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err( validator.check_err(
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
return validator.is_success();
} }
return true; return true;
......
...@@ -273,13 +273,14 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, ...@@ -273,13 +273,14 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
conv_output_device_buf.FromDevice(conv_output_device.mData.data()); conv_output_device_buf.FromDevice(conv_output_device.mData.data());
r0_device_buf.FromDevice(r0_device.mData.data()); r0_device_buf.FromDevice(r0_device.mData.data());
return ck::utils::check_err(conv_output_device, validator.check_err(conv_output_device,
conv_output_host, conv_output_host,
"Error: incorrect results! (Matrix E)", "Error: incorrect results! (Matrix E)",
1e-5f, 1e-5f,
1e-4f) && 1e-4f);
ck::utils::check_err( validator.check_err(
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f); r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f);
return validator.is_success();
} }
return true; return true;
......
...@@ -326,8 +326,6 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -326,8 +326,6 @@ int reduce_blockwise_impl(bool do_verification,
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
<< std::endl; << std::endl;
bool pass = true;
if(do_verification) if(do_verification)
{ {
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
...@@ -343,14 +341,14 @@ int reduce_blockwise_impl(bool do_verification, ...@@ -343,14 +341,14 @@ int reduce_blockwise_impl(bool do_verification,
#endif #endif
out_dev.FromDevice(out.mData.data()); out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out, out_ref); validator.check_err(out, out_ref);
if(OutputIndex) if(OutputIndex)
{ {
out_index_dev.FromDevice(out_indices.mData.data()); out_index_dev.FromDevice(out_indices.mData.data());
pass = pass && ck::utils::check_err(out_indices, out_indices_ref); validator.check_err(out_indices, out_indices_ref);
}; };
}; };
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -307,13 +307,11 @@ int main(int argc, char* argv[]) ...@@ -307,13 +307,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, "
<< reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl; << reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl;
bool pass = true;
if(do_verify) if(do_verify)
{ {
out_dev.FromDevice(out.mData.data()); out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out, out_ref); validator.check_err(out, out_ref);
}; };
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -244,8 +244,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification, ...@@ -244,8 +244,8 @@ int reduce_multiblock_atomic_add_impl(bool do_verification,
if(do_verification) if(do_verification)
{ {
out_dev.FromDevice(out.mData.data()); out_dev.FromDevice(out.mData.data());
pass = pass && ck::utils::check_err(out, out_ref); validator.check_err(out, out_ref);
}; };
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -182,16 +182,14 @@ bool pool_test(bool do_verification, ...@@ -182,16 +182,14 @@ bool pool_test(bool do_verification,
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host); validator.check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host);
if constexpr(OutputIndex) if constexpr(OutputIndex)
{ {
out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
validator.check_err(out_indices_n_c_ho_wo_device, out_indices_n_c_ho_wo_host);
pass = pass &&
ck::utils::check_err(out_indices_n_c_ho_wo_device, out_indices_n_c_ho_wo_host);
}; };
} }
return (pass); return validator.is_success();
}; };
...@@ -197,7 +197,8 @@ int main() ...@@ -197,7 +197,8 @@ int main()
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; validator.check_err(e_m_n_device_result, e_m_n_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -228,7 +228,8 @@ int main() ...@@ -228,7 +228,8 @@ int main()
} }
} }
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; validator.check_err(e_m_n_device_result, e_m_n_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -200,7 +200,8 @@ int main() ...@@ -200,7 +200,8 @@ int main()
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; validator.check_err(e_m_n_device_result, e_m_n_host_result);
return validator.is_success();
} }
return 0; return 0;
......
...@@ -209,10 +209,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -209,10 +209,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> c_device_result_converted(c_device_tensors[i]); const Tensor<EDataType> c_device_result_converted(c_device_tensors[i]);
pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); validator.check_err(c_device_result_converted, c_host_tensors[i]);
#else #else
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); validator.check_err(c_device_tensors[i], c_host_tensors[i]);
#endif #endif
} }
} }
...@@ -227,7 +227,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -227,7 +227,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
<< " GB/s, " << gemm.GetTypeString() << std::endl; << " GB/s, " << gemm.GetTypeString() << std::endl;
} }
return pass; return validator.is_success();
} }
bool run_grouped_gemm_example(int argc, char* argv[]) bool run_grouped_gemm_example(int argc, char* argv[])
......
...@@ -259,9 +259,9 @@ int main() ...@@ -259,9 +259,9 @@ int main()
r0_device_buf.FromDevice(r0_m.mData.data()); r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data()); r1_device_buf.FromDevice(r1_m.mData.data());
pass = ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results c", 1e-2, 1e-2); validator.check_err(e_m_n, e_m_n_host, "Error: Incorrect results c", 1e-2, 1e-2);
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); validator.check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); validator.check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
} }
bool time_kernel = true; bool time_kernel = true;
...@@ -272,5 +272,5 @@ int main() ...@@ -272,5 +272,5 @@ int main()
ave_time, M, N, K); ave_time, M, N, K);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -261,16 +261,16 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M, ...@@ -261,16 +261,16 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M,
Tensor<EDataType> e_m_n_host_converted(e_m_n_host); Tensor<EDataType> e_m_n_host_converted(e_m_n_host);
pass = ck::utils::check_err( validator.check_err(
e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2); e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2);
r0_device_buf.FromDevice(r0_m.mData.data()); r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data()); r1_device_buf.FromDevice(r1_m.mData.data());
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); validator.check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); validator.check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
if(pass) if(validator.is_success())
{ {
std::cout << "Success!" << std::endl; std::cout << "Success!" << std::endl;
} }
...@@ -291,7 +291,7 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M, ...@@ -291,7 +291,7 @@ bool run_gemm_reduce_add_addsquare_xdl(ck::index_t M,
<< " GB/s, " << std::endl; << " GB/s, " << std::endl;
} }
return pass; return validator.is_success();
} }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
......
...@@ -241,7 +241,7 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, ...@@ -241,7 +241,7 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
if constexpr(std::is_same_v<ADataType, ck::int4_t>) if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{ {
Tensor<EDataType> e_m_n_device_converted(e_m_n); Tensor<EDataType> e_m_n_device_converted(e_m_n);
pass = ck::utils::check_err(e_m_n_device_converted, validator.check_err(e_m_n_device_converted,
e_m_n_host_converted, e_m_n_host_converted,
"Error: Incorrect results c", "Error: Incorrect results c",
1e-2, 1e-2,
...@@ -250,14 +250,14 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, ...@@ -250,14 +250,14 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
else else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{ {
pass = ck::utils::check_err( validator.check_err(
e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2); e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2);
} }
r0_device_buf.FromDevice(r0_m.mData.data()); r0_device_buf.FromDevice(r0_m.mData.data());
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); validator.check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
if(pass) if(validator.is_success())
{ {
std::cout << "Success!" << std::endl; std::cout << "Success!" << std::endl;
} }
...@@ -269,7 +269,7 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, ...@@ -269,7 +269,7 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
DumpGemmReduceMaxPerf<ADataType, BDataType, EDataType, R0DataType>(ave_time, M, N, K); DumpGemmReduceMaxPerf<ADataType, BDataType, EDataType, R0DataType>(ave_time, M, N, K);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
template <typename ADataType, template <typename ADataType,
...@@ -455,7 +455,7 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M, ...@@ -455,7 +455,7 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
if constexpr(std::is_same_v<ADataType, ck::int4_t>) if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{ {
Tensor<EDataType> e_m_n_device_converted(e_m_n); Tensor<EDataType> e_m_n_device_converted(e_m_n);
pass = ck::utils::check_err(e_m_n_device_converted, validator.check_err(e_m_n_device_converted,
e_m_n_host_converted, e_m_n_host_converted,
"Error: Incorrect results c", "Error: Incorrect results c",
1e-2, 1e-2,
...@@ -464,17 +464,17 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M, ...@@ -464,17 +464,17 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
else else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{ {
pass = ck::utils::check_err( validator.check_err(
e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2); e_m_n, e_m_n_host_converted, "Error: Incorrect results c", 1e-2, 1e-2);
} }
r0_device_buf.FromDevice(r0_m.mData.data()); r0_device_buf.FromDevice(r0_m.mData.data());
r1_device_buf.FromDevice(r1_m.mData.data()); r1_device_buf.FromDevice(r1_m.mData.data());
pass &= ck::utils::check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2); validator.check_err(r0_m, r0_m_host, "Error: Incorrect results d0", 1e-2, 1e-2);
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2); validator.check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
if(pass) if(validator.is_success())
{ {
std::cout << "Success!" << std::endl; std::cout << "Success!" << std::endl;
} }
...@@ -487,5 +487,5 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M, ...@@ -487,5 +487,5 @@ bool run_gemm_reduce_mean_meansquare_xdl(ck::index_t M,
ave_time, M, N, K); ave_time, M, N, K);
} }
return pass; return validator.is_success();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment