Commit ce87bcc7 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
...@@ -146,7 +146,8 @@ int run_conv_bwd_data(bool do_verification, ...@@ -146,7 +146,8 @@ int run_conv_bwd_data(bool do_verification,
in_device_buf.FromDevice(in_device.mData.data()); in_device_buf.FromDevice(in_device.mData.data());
return ck::utils::check_err(in_device, in_host) ? 0 : 1; validator.check_err(in_device, in_host);
return !validator.is_success()
} }
return 0; return 0;
......
...@@ -293,19 +293,19 @@ int main(int argc, char* argv[]) ...@@ -293,19 +293,19 @@ int main(int argc, char* argv[])
} }
} }
pass = ck::utils::check_err( validator.check_err(
c_g_m_n_host_result, c_g_m_n_device_result, "Error: Incorrect results c") && c_g_m_n_host_result, c_g_m_n_device_result, "Error: Incorrect results c");
ck::utils::check_err(d0_g_m_device_result, validator.check_err(d0_g_m_device_result,
d0_g_m_host_result, d0_g_m_host_result,
"Error: Incorrect results! D0", "Error: Incorrect results! D0",
1e-4, 1e-4,
1e-5) && 1e-5);
ck::utils::check_err(d1_g_m_device_result, validator.check_err(d1_g_m_device_result,
d1_g_m_host_result, d1_g_m_host_result,
"Error: Incorrect results! D1", "Error: Incorrect results! D1",
1e-3, 1e-3,
1e-5); 1e-5);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -129,8 +129,8 @@ int main() ...@@ -129,8 +129,8 @@ int main()
host_broadcast2D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add, 0>( host_broadcast2D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add, 0>(
host_c_m_n, a_m_n, b_n, M, N, Add{}); host_c_m_n, a_m_n, b_n, M, N, Add{});
pass &= ck::utils::check_err(c_m_n, host_c_m_n, "Error: Incorrect results c", 1e-3, 1e-3); validator.check_err(c_m_n, host_c_m_n, "Error: Incorrect results c", 1e-3, 1e-3);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -112,9 +112,8 @@ int main() ...@@ -112,9 +112,8 @@ int main()
host_broadcast3D_am_bmnk<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>( host_broadcast3D_am_bmnk<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c_m_n_k, a_m, b_m_n_k, mnk, Add{}); host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
pass &= validator.check_err(c_m_n_k, host_c_m_n_k, "Error: Incorrect results c", 1e-3, 1e-3);
ck::utils::check_err(c_m_n_k, host_c_m_n_k, "Error: Incorrect results c", 1e-3, 1e-3);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -95,7 +95,6 @@ int main() ...@@ -95,7 +95,6 @@ int main()
std::cout << "Perf: " << ave_time << " ms" << std::endl; std::cout << "Perf: " << ave_time << " ms" << std::endl;
bool pass = true;
if(do_verification) if(do_verification)
{ {
c_m_device_buf.FromDevice(c_m.mData.data()); c_m_device_buf.FromDevice(c_m.mData.data());
...@@ -104,8 +103,8 @@ int main() ...@@ -104,8 +103,8 @@ int main()
host_elementwise1D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>( host_elementwise1D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c_m, a_m, b_m, M, Add{}); host_c_m, a_m, b_m, M, Add{});
pass &= ck::utils::check_err(c_m, host_c_m, "Error: Incorrect results c", 1e-3, 1e-3); validator.check_err(c_m, host_c_m, "Error: Incorrect results c", 1e-3, 1e-3);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -104,7 +104,6 @@ int main() ...@@ -104,7 +104,6 @@ int main()
std::cout << "Perf: " << ave_time << " ms" << std::endl; std::cout << "Perf: " << ave_time << " ms" << std::endl;
bool pass = true;
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c.mData.data()); c_device_buf.FromDevice(c.mData.data());
...@@ -113,8 +112,8 @@ int main() ...@@ -113,8 +112,8 @@ int main()
host_elementwise4D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>( host_elementwise4D<Tensor<ABDataType>, Tensor<ABDataType>, Tensor<CDataType>, Add>(
host_c, a, b, nchw, Add{}); host_c, a, b, nchw, Add{});
pass &= ck::utils::check_err(c, host_c, "Error: Incorrect results c", 1e-3, 1e-3); validator.check_err(c, host_c, "Error: Incorrect results c", 1e-3, 1e-3);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -157,7 +157,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -157,7 +157,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
wei_device_buf.FromDevice(wei_device_result.mData.data()); wei_device_buf.FromDevice(wei_device_result.mData.data());
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); return validator.check_err(wei_device_result.mData, wei_host_result.mData);
} }
return true; return true;
......
...@@ -371,7 +371,7 @@ int main() ...@@ -371,7 +371,7 @@ int main()
N); N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n, validator.check_err(layerNorm_m_n,
host_layerNorm_m_n, host_layerNorm_m_n,
"Error: Incorrect results layerNorm_m_n", "Error: Incorrect results layerNorm_m_n",
1e-2, 1e-2,
...@@ -401,5 +401,5 @@ int main() ...@@ -401,5 +401,5 @@ int main()
gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -255,9 +255,8 @@ int main() ...@@ -255,9 +255,8 @@ int main()
epsilon); epsilon);
h_device_buf.FromDevice(h_m_n.mData.data()); h_device_buf.FromDevice(h_m_n.mData.data());
pass &= validator.check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -345,7 +345,7 @@ int main() ...@@ -345,7 +345,7 @@ int main()
N); N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err( validator.check_err(
layerNorm_m_n, host_layerNorm_m_n, "Error: Incorrect results d1", 1e-3, 1e-3); layerNorm_m_n, host_layerNorm_m_n, "Error: Incorrect results d1", 1e-3, 1e-3);
} }
...@@ -370,5 +370,5 @@ int main() ...@@ -370,5 +370,5 @@ int main()
gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -274,14 +274,14 @@ int main(int argc, char* argv[]) ...@@ -274,14 +274,14 @@ int main(int argc, char* argv[])
if constexpr(std::is_same<CShuffleDataType, F32>::value) if constexpr(std::is_same<CShuffleDataType, F32>::value)
{ {
pass &= ck::utils::check_err( validator.check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results c"); c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results c");
} }
else if constexpr(std::is_same<CShuffleDataType, F16>::value) else if constexpr(std::is_same<CShuffleDataType, F16>::value)
{ {
pass &= ck::utils::check_err( validator.check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results c", 1e-2, 1e-2); c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results c", 1e-2, 1e-2);
} }
} }
return pass ? 0 : 1; return !validator.is_success();
} }
...@@ -220,12 +220,12 @@ bool run_cgemm_xdl(ck::index_t M, ...@@ -220,12 +220,12 @@ bool run_cgemm_xdl(ck::index_t M,
const Tensor<CDataType> c_m_n_real_device_result_converted(c_m_n_real_device_result); const Tensor<CDataType> c_m_n_real_device_result_converted(c_m_n_real_device_result);
const Tensor<CDataType> c_m_n_imag_device_result_converted(c_m_n_imag_device_result); const Tensor<CDataType> c_m_n_imag_device_result_converted(c_m_n_imag_device_result);
result = ck::utils::check_err(c_m_n_real_device_result_converted, validator.check_err(c_m_n_real_device_result_converted,
c_m_n_real_host_result, c_m_n_real_host_result,
"Verification error: incorrect results in real part!", "Verification error: incorrect results in real part!",
1e-2f, 1e-2f,
1e-1f); 1e-1f);
result = result && ck::utils::check_err( validator.check_err(
c_m_n_imag_device_result_converted, c_m_n_imag_device_result_converted,
c_m_n_imag_host_result, c_m_n_imag_host_result,
"Verification error: incorrect results in imaginary part!", "Verification error: incorrect results in imaginary part!",
...@@ -235,12 +235,12 @@ bool run_cgemm_xdl(ck::index_t M, ...@@ -235,12 +235,12 @@ bool run_cgemm_xdl(ck::index_t M,
else else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{ {
result = ck::utils::check_err(c_m_n_real_device_result, validator.check_err(c_m_n_real_device_result,
c_m_n_real_host_result, c_m_n_real_host_result,
"Verification error: incorrect results in real part!", "Verification error: incorrect results in real part!",
1e-2f, 1e-2f,
1e-1f); 1e-1f);
result = result && ck::utils::check_err( validator.check_err(
c_m_n_imag_device_result, c_m_n_imag_device_result,
c_m_n_imag_host_result, c_m_n_imag_host_result,
"Verification error: incorrect results in imaginary part!", "Verification error: incorrect results in imaginary part!",
...@@ -248,7 +248,7 @@ bool run_cgemm_xdl(ck::index_t M, ...@@ -248,7 +248,7 @@ bool run_cgemm_xdl(ck::index_t M,
1e-1f); 1e-1f);
} }
return result; return validator.is_success();
} }
return true; return true;
} }
...@@ -240,13 +240,12 @@ int main(int argc, char* argv[]) ...@@ -240,13 +240,12 @@ int main(int argc, char* argv[])
auto invoker_ptr = device_instance.MakeInvokerPointer(); auto invoker_ptr = device_instance.MakeInvokerPointer();
bool pass = true;
if(args.do_verification) if(args.do_verification)
{ {
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
out_dev.FromDevice(out.mData.data()); out_dev.FromDevice(out.mData.data());
// LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl;
pass = pass && ck::utils::check_err(out, out_ref); validator.check_err(out, out_ref);
}; };
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});
...@@ -260,5 +259,5 @@ int main(int argc, char* argv[]) ...@@ -260,5 +259,5 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name
<< std::endl; << std::endl;
return (pass ? 0 : 1); return !validator.is_success();
} }
...@@ -146,7 +146,6 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -146,7 +146,6 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
} }
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(config.do_verification) if(config.do_verification)
{ {
...@@ -174,10 +173,10 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -174,10 +173,10 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#ifdef BUILD_INT4_EXAMPLE #ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result); const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result);
pass &= ck::utils::check_err(e_device_result_converted, e_g_m_n_host_result); validator.check_err(e_device_result_converted, e_g_m_n_host_result);
#else #else
pass = ck::utils::check_err( validator.check_err(
e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c"); e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c");
#endif #endif
} }
...@@ -197,7 +196,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -197,7 +196,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
<< " GB/s, " << gemm.GetTypeString() << std::endl; << " GB/s, " << gemm.GetTypeString() << std::endl;
} }
return pass ? 0 : 1; return !validator.is_success();
} }
bool run_batched_gemm_example(int argc, char* argv[]) bool run_batched_gemm_example(int argc, char* argv[])
......
...@@ -390,7 +390,8 @@ int main(int argc, char* argv[]) ...@@ -390,7 +390,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; validator.check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -391,7 +391,8 @@ int main(int argc, char* argv[]) ...@@ -391,7 +391,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; validator.check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -286,7 +286,8 @@ int main(int argc, char* argv[]) ...@@ -286,7 +286,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; validator.check_err(e_ms_ns_device_result, e_ms_ns_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -286,7 +286,8 @@ int main(int argc, char* argv[]) ...@@ -286,7 +286,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; validator.check_err(e_ms_ns_device_result, e_ms_ns_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -269,7 +269,8 @@ int main(int argc, char* argv[]) ...@@ -269,7 +269,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; validator.check_err(e_ms_ns_device_result, e_ms_ns_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
...@@ -269,7 +269,8 @@ int main(int argc, char* argv[]) ...@@ -269,7 +269,8 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; validator.check_err(e_ms_ns_device_result, e_ms_ns_host_result);
return !validator.is_success();
} }
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment