Commit f8cfe639 authored by Davis King's avatar Davis King
Browse files

Avoid unnecessairly asking cuDNN which algorithms to use, since this is slow in cuDNN 8.0

parent 6c3243f7
...@@ -916,20 +916,24 @@ namespace dlib ...@@ -916,20 +916,24 @@ namespace dlib
{ {
DLIB_CASSERT(data.k() == filters.k()); DLIB_CASSERT(data.k() == filters.k());
// if the last call to setup gave the same exact settings then don't do const bool non_data_params_unchanged =
// anything. stride_y_ == stride_y &&
if (stride_y_ == stride_y &&
stride_x_ == stride_x && stride_x_ == stride_x &&
padding_y_ == padding_y && padding_y_ == padding_y &&
padding_x_ == padding_x && padding_x_ == padding_x &&
data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc() &&
filters_num_samples == filters.num_samples() && filters_num_samples == filters.num_samples() &&
filters_k == filters.k() && filters_k == filters.k() &&
filters_nr == filters.nr() && filters_nr == filters.nr() &&
filters_nc == filters.nc()) filters_nc == filters.nc();
// if the last call to setup gave the same exact settings then don't do
// anything.
if (non_data_params_unchanged &&
data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc()
)
{ {
return; return;
} }
...@@ -991,7 +995,16 @@ namespace dlib ...@@ -991,7 +995,16 @@ namespace dlib
tensor_descriptor dest_desc; tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc); dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
// Ask cuDNN what algorithms are best to use. We always do this on the first call
// to setup(). Then if something other than the size of the input tensor changes we
// also ask cuDNN what to use. Note that in newer versions of cuDNN, asking for the
// best algorithm is a relatively slow thing. So it's important we don't do it
// unnecessarily.
if (!selected_algos || !non_data_params_unchanged)
{
selected_algos = true;
select_best_algorithms(data, dest_desc); select_best_algorithms(data, dest_desc);
}
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize( CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
context(), context(),
......
...@@ -233,6 +233,8 @@ namespace dlib ...@@ -233,6 +233,8 @@ namespace dlib
int forward_algo; int forward_algo;
int backward_data_algo; int backward_data_algo;
int backward_filters_algo; int backward_filters_algo;
// true if select_best_algorithms has been called at least once.
bool selected_algos = false;
size_t forward_workspace_size_in_bytes; size_t forward_workspace_size_in_bytes;
size_t backward_data_workspace_size_in_bytes; size_t backward_data_workspace_size_in_bytes;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment