Commit ce87bcc7 authored by Bartlomiej Kocot
Browse files

tmp

parent c8a8385f
......@@ -250,7 +250,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
int num_kernel = 0;
// profile device operation instances
......@@ -316,7 +316,7 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
{
h_device_buf.FromDevice(h_m_n.mData.data());
pass = pass && ck::utils::check_err(
validator.check_err(
h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
}
}
......@@ -327,19 +327,14 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
}
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
pass = false;
}
else
if(num_kernel != 0)
{
if(time_kernel)
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
}
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -281,6 +281,8 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
......@@ -343,9 +345,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
validator.check_err(c_m_n_device_result, c_m_n_host_result);
validator.check_err(reduce0_m_device_result, reduce0_m_host_result);
validator.check_err(reduce1_m_device_result, reduce1_m_host_result);
if(do_log)
{
......@@ -376,6 +378,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
}
}
validator.is_success();
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
}
......
......@@ -158,7 +158,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
......@@ -215,7 +215,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
validator.check_err(e_m_n_device_result, e_m_n_host_result);
}
}
else
......@@ -227,7 +227,7 @@ bool profile_gemm_bilinear_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -147,7 +147,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
......@@ -203,7 +203,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
validator.check_err(e_m_n_device_result, e_m_n_host_result);
}
}
else
......@@ -215,7 +215,7 @@ bool profile_gemm_fastgelu_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -42,7 +42,7 @@ int profile_gemm_impl(int do_verification,
int StrideB,
int StrideC)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -188,7 +188,7 @@ int profile_gemm_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
validator.check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log)
{
......@@ -247,7 +247,7 @@ int profile_gemm_impl(int do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
return pass ? 0 : 1;
return !validator.is_success();
}
} // namespace profiler
......
......@@ -250,6 +250,8 @@ bool profile_gemm_reduce_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::utils::CorrectnessValidator validator;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
......@@ -310,9 +312,10 @@ bool profile_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
validator.check_err(c_m_n_device_result, c_m_n_host_result);
validator.check_err(reduce0_m_device_result, reduce0_m_host_result);
validator.check_err(reduce1_m_device_result, reduce1_m_host_result);
validator.is_success();
if(do_log)
{
......@@ -346,7 +349,7 @@ bool profile_gemm_reduce_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -43,7 +43,7 @@ bool profile_gemm_splitk_impl(int do_verification,
int StrideC,
int KBatch)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -181,7 +181,7 @@ bool profile_gemm_splitk_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
validator.check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log)
{
......@@ -221,12 +221,12 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
validator.check_err(
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
}
else
{
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
validator.check_err(c_m_n_device_result, c_m_n_host_result);
}
if(tflops > best_tflops)
......@@ -286,7 +286,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -43,7 +43,7 @@ bool profile_gemm_streamk_impl(int do_verification,
int StrideC,
uint32_t NumSKBlocks = 0xffffffff)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -176,7 +176,7 @@ bool profile_gemm_streamk_impl(int do_verification,
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
validator.check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log)
{
......@@ -260,7 +260,7 @@ bool profile_gemm_streamk_impl(int do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -122,7 +122,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
float best_gb_per_sec = 0;
// profile device op instances
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
......@@ -159,7 +159,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
{
in_device_buf.FromDevice(in_device.mData.data());
pass = pass & ck::utils::check_err(in_device, in_host);
validator.check_err(in_device, in_host);
if(do_log)
{
......@@ -250,7 +250,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -160,6 +160,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
ck::utils::CorrectnessValidator validator;
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr =
......@@ -214,15 +215,13 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
{
wei_device_buf.FromDevice(weight_device_result.mData.data());
bool pass = ck::utils::check_err(weight_device_result, weight_host_result);
validator.check_err(weight_device_result, weight_host_result);
if(!pass)
if(!validator.is_success())
{
std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
}
all_pass &= pass;
if(do_log)
{
LogRangeAsType<float>(std::cout << "output : ", output.mData, ",") << std::endl;
......@@ -250,7 +249,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return all_pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -142,7 +142,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
float best_gb_per_sec = 0;
// profile device op instances
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
......@@ -179,7 +179,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
{
out_device_buf.FromDevice(device_output.mData.data());
pass = pass & ck::utils::check_err(device_output, host_output);
validator.check_err(device_output, host_output);
if(do_log)
{
......@@ -246,7 +246,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -39,7 +39,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
const std::vector<int>& StrideCs)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -238,8 +238,7 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
ref_invoker.Run(ref_argument);
bool group_pass =
ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result);
pass = pass && group_pass;
validator.check_err(c_m_n_device_results[i], c_m_n_host_result);
std::cout << "group: " << i << " verification result: " << std::boolalpha
<< group_pass << std::endl;
......@@ -267,13 +266,13 @@ bool profile_grouped_gemm_fastgelu_impl(int do_verification,
if(do_verification)
{
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
std::cout << "Verification: " << (validator.is_success() ? "SUCCESS" : "FAILURE") << std::endl;
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -44,7 +44,7 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideCs,
int kbatch = 1)
{
bool pass = true;
ck::utils::CorrectnessValidator validator;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -274,7 +274,7 @@ bool profile_grouped_gemm_impl(int do_verification,
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
instance_pass && validator.check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
......@@ -282,7 +282,7 @@ bool profile_grouped_gemm_impl(int do_verification,
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
instance_pass && validator.check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
......@@ -303,8 +303,6 @@ bool profile_grouped_gemm_impl(int do_verification,
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
float ave_time =
......@@ -354,7 +352,7 @@ bool profile_grouped_gemm_impl(int do_verification,
<< std::endl;
}
return pass;
return validator.is_success();
}
} // namespace profiler
......
......@@ -110,7 +110,7 @@ bool profile_groupnorm_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -169,7 +169,7 @@ bool profile_groupnorm_impl(int do_verification,
{
y_dev.FromDevice(y.mData.data());
bool pass = ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
bool pass = validator.check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log)
{
......@@ -182,7 +182,6 @@ bool profile_groupnorm_impl(int do_verification,
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
......@@ -198,14 +197,7 @@ bool profile_groupnorm_impl(int do_verification,
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
return validator.is_success();
}
} // namespace profiler
......
......@@ -121,7 +121,7 @@ bool profile_layernorm_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -186,7 +186,7 @@ bool profile_layernorm_impl(int do_verification,
y_dev.FromDevice(y.mData.data());
bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
validator.check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
if(do_log)
{
......@@ -199,7 +199,6 @@ bool profile_layernorm_impl(int do_verification,
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
......@@ -218,13 +217,7 @@ bool profile_layernorm_impl(int do_verification,
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
return validator.is_success();
}
} // namespace profiler
......
......@@ -150,7 +150,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
int num_kernel = 0;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instance_ptrs)
{
......@@ -213,7 +213,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
bool pass = validator.check_err(out_n_c_do_ho_wo_device.mData,
out_n_c_do_ho_wo_host.mData,
"Error: Incorrect results",
1e-3,
......@@ -223,7 +223,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data());
pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device,
pass = pass && validator.check_err(out_indices_n_c_do_ho_wo_device,
out_indices_n_c_do_ho_wo_host);
}
......@@ -250,7 +250,6 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl;
return false;
}
else
{
......@@ -267,13 +266,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
return validator.is_success();
}
} // namespace profiler
......
......@@ -196,7 +196,7 @@ bool profile_reduce_impl_impl(bool do_verification,
invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6);
int num_kernel = 0;
bool pass = true;
ck::utils::CorrectnessValidator validator;
if constexpr(!invalid_reduce)
{
......@@ -403,12 +403,12 @@ bool profile_reduce_impl_impl(bool do_verification,
bool single_pass;
out_dev.FromDevice(out.mData.data());
single_pass = ck::utils::check_err(out, out_ref);
single_pass = validator.check_err(out, out_ref);
if(OutputIndex)
{
out_indices_dev.FromDevice(out_indices.mData.data());
single_pass = single_pass && ck::utils::check_err(out_indices, out_indices_ref);
single_pass = single_pass && validator.check_err(out_indices, out_indices_ref);
};
if(!single_pass)
......@@ -416,7 +416,6 @@ bool profile_reduce_impl_impl(bool do_verification,
std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl;
}
pass = pass && single_pass;
};
if(do_dumpout)
......@@ -447,13 +446,7 @@ bool profile_reduce_impl_impl(bool do_verification,
"The requested reduction operation is not supported, please check!");
};
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
};
return pass;
return validator.is_success();
};
template <typename InDataType, typename AccDataType, typename OutDataType>
......
......@@ -123,7 +123,7 @@ bool profile_softmax_impl(int do_verification,
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
std::vector<bool> instance_pass;
ck::utils::CorrectnessValidator validator;
for(auto& inst_ptr : instances)
{
......@@ -176,7 +176,7 @@ bool profile_softmax_impl(int do_verification,
bool pass = true;
if(std::is_same<InDataType, int8_t>::value)
{
pass = pass && ck::utils::check_err(
pass = pass && validator.check_err(
out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1);
if(do_log)
{
......@@ -188,7 +188,7 @@ bool profile_softmax_impl(int do_verification,
}
else
{
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
pass = pass && validator.check_err(out.mData, out_ref.mData);
if(do_log)
{
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
......@@ -219,8 +219,7 @@ bool profile_softmax_impl(int do_verification,
<< "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec
<< " GB/s, " << best_instance_name << std::endl;
}
return std::all_of(
std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; });
return validator.is_success();
}
} // namespace profiler
......
......@@ -46,110 +46,113 @@ class TestConvUtil : public ::testing::Test
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
{
ck::utils::CorrectnessValidator validator;
// stride 2, dilation 1, pad 1
SetNDParams(1, 2, 1, 1);
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D."));
// stride 1, dilation 1, pad 1
SetNDParams(1, 1, 1, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}."));
// stride 2, dilation 1, pad 2
SetNDParams(1, 2, 1, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{37},
"Error: ConvParams 1D padding left/right {2}."));
// stride 2, dilation 2, pad 2
SetNDParams(1, 2, 2, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}."));
// stride 3, dilation 2, pad 1
SetNDParams(1, 3, 2, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(
ck::utils::check_err(out_spatial_len,
validator.check_err(out_spatial_len,
std::vector<ck::index_t>{23},
"Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
}
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
{
ck::utils::CorrectnessValidator validator;
// stride 2, dilation 1, pad 1
SetNDParams(2, 2, 1, 1);
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{36, 36},
"Error: ConvParams 2D default constructor."));
// stride 1, dilation 1, pad 1
SetNDParams(2, 1, 1, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}."));
// stride 2, dilation 1, pad 2
SetNDParams(2, 2, 1, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{37, 37},
"Error: ConvParams 2D padding left/right {2,2}."));
// stride 2, dilation 2, pad 2
SetNDParams(2, 2, 2, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}."));
// stride 3, dilation 2, pad 1
SetNDParams(2, 3, 2, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(
ck::utils::check_err(out_spatial_len,
validator.check_err(out_spatial_len,
std::vector<ck::index_t>{23, 23},
"Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
}
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
{
ck::utils::CorrectnessValidator validator;
// stride 2, dilation 1, pad 1
SetNDParams(3, 2, 1, 1);
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D."));
// stride 1, dilation 1, pad 1
SetNDParams(3, 1, 1, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{71, 71, 71},
"Error: ConvParams 3D stride {1, 1, 1}."));
// stride 2, dilation 1, pad 2
SetNDParams(3, 2, 1, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{37, 37, 37},
"Error: ConvParams 3D padding left/right {2, 2, 2}."));
// stride 2, dilation 2, pad 2
SetNDParams(3, 2, 2, 2);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
EXPECT_TRUE(validator.check_err(out_spatial_len,
std::vector<ck::index_t>{36, 36, 36},
"Error: ConvParams 3D dilation {2, 2, 2}."));
// stride 3, dilation 2, pad 1
SetNDParams(3, 3, 2, 1);
out_spatial_len = conv_params.GetOutputSpatialLengths();
EXPECT_TRUE(ck::utils::check_err(
EXPECT_TRUE(validator.check_err(
out_spatial_len,
std::vector<ck::index_t>{23, 23, 23},
"Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
......
......@@ -227,30 +227,30 @@ struct TestGemm
if(is_supported && do_verification)
{
// Assert
bool res = false;
ck::utils::CorrectnessValidator validator;
if(std::is_same<CDataType, float>::value)
{
res = ck::utils::check_err(c_device, c_host);
res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::half_t>::value)
{
res = ck::utils::check_err(c_device, c_host);
res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::bhalf_t>::value)
{
res = ck::utils::check_err(c_device, c_host);
res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, int8_t>::value)
{
res = ck::utils::check_err(c_device, c_host);
res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, double>::value)
{
res = ck::utils::check_err(c_device, c_host);
res = validator.check_err(c_device, c_host);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment