Commit 9db34134 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Fail when no kernel is applicable

parent 8f84a012
......@@ -123,7 +123,9 @@ bool profile_softmax_impl(int do_verification,
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
std::vector<bool> instance_pass;
int num_kernel = 0;
bool pass = true;
for(auto& inst_ptr : instances)
{
......@@ -144,7 +146,6 @@ bool profile_softmax_impl(int do_verification,
<< "], "
<< "scaler = [" << alpha << ", " << beta << "]";
LogRange(std::cout << ", reduce dims = [", reduce_dims, ", ") << "]." << std::endl;
instance_pass.push_back(true);
continue;
}
......@@ -173,10 +174,11 @@ bool profile_softmax_impl(int do_verification,
if(do_verification)
{
out_dev.FromDevice(out.data());
bool pass = true;
bool correct = true;
if(std::is_same<InDataType, int8_t>::value)
{
pass = pass && ck::utils::check_err(
correct =
correct && ck::utils::check_err(
out.mData, out_ref.mData, "Error: Incorrect results!", 0, 1);
if(do_log)
{
......@@ -188,7 +190,7 @@ bool profile_softmax_impl(int do_verification,
}
else
{
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
correct = correct && ck::utils::check_err(out.mData, out_ref.mData);
if(do_log)
{
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
......@@ -198,16 +200,24 @@ bool profile_softmax_impl(int do_verification,
}
}
if(!pass)
if(!correct)
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "input lengths = [", in_length, ", ")
<< "], "
<< "scaler = [" << alpha << ", " << beta << "]." << std::endl;
}
instance_pass.push_back(pass);
num_kernel++;
pass &= correct;
}
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(time_kernel)
{
std::cout << "Best Perf for datatype = " << type_to_string<InDataType>() << "_"
......@@ -219,8 +229,7 @@ bool profile_softmax_impl(int do_verification,
<< "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec
<< " GB/s, " << best_instance_name << std::endl;
}
return std::all_of(
std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; });
return pass;
}
} // namespace profiler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment