Commit 2c70aad1 authored by Davis King's avatar Davis King
Browse files

Use a cache to avoid calls to the cuDNN algorithm selection routines.

parent 8910445a
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include "cudnn_dlibapi.h" #include "cudnn_dlibapi.h"
#include "tensor.h" #include "tensor.h"
#include <cudnn.h> #include <cudnn.h>
#include <tuple>
#include <map>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -782,6 +784,22 @@ namespace dlib ...@@ -782,6 +784,22 @@ namespace dlib
const tensor_descriptor& dest_desc const tensor_descriptor& dest_desc
) )
{ {
// Calling the cuDNN "find the best algorithm" functions are really slow. So we keep a
// cache that tells us what method was best for a particular configuration.
thread_local std::map<std::tuple<int,int,int,int,long,long>,
std::tuple<int,int,int>> config_to_algo_cache;
// If we have already found good algorithms for this setting then just pull them from
// the cache.
const auto cache_key = std::make_tuple(stride_y, stride_x, padding_y, padding_x, filters_nr, filters_nc);
const auto iter = config_to_algo_cache.find(cache_key);
if (iter != config_to_algo_cache.end())
{
std::tie(forward_algo, backward_data_algo, backward_filters_algo) = iter->second;
return;
}
// Pick which forward algorithm we will use and allocate the necessary // Pick which forward algorithm we will use and allocate the necessary
// workspace buffer. // workspace buffer.
cudnnConvolutionFwdAlgo_t forward_best_algo; cudnnConvolutionFwdAlgo_t forward_best_algo;
...@@ -902,6 +920,10 @@ namespace dlib ...@@ -902,6 +920,10 @@ namespace dlib
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
} }
backward_filters_algo = backward_filters_best_algo; backward_filters_algo = backward_filters_best_algo;
// Save this algorithm selection in the cache
config_to_algo_cache[cache_key] = std::make_tuple(forward_algo, backward_data_algo, backward_filters_algo);
} }
void tensor_conv:: void tensor_conv::
...@@ -916,7 +938,12 @@ namespace dlib ...@@ -916,7 +938,12 @@ namespace dlib
{ {
DLIB_CASSERT(data.k() == filters.k()); DLIB_CASSERT(data.k() == filters.k());
const bool non_data_params_unchanged = // if the last call to setup gave the same exact settings then don't do
// anything.
if (data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc() &&
stride_y_ == stride_y && stride_y_ == stride_y &&
stride_x_ == stride_x && stride_x_ == stride_x &&
padding_y_ == padding_y && padding_y_ == padding_y &&
...@@ -924,15 +951,7 @@ namespace dlib ...@@ -924,15 +951,7 @@ namespace dlib
filters_num_samples == filters.num_samples() && filters_num_samples == filters.num_samples() &&
filters_k == filters.k() && filters_k == filters.k() &&
filters_nr == filters.nr() && filters_nr == filters.nr() &&
filters_nc == filters.nc(); filters_nc == filters.nc()
// if the last call to setup gave the same exact settings then don't do
// anything.
if (non_data_params_unchanged &&
data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc()
) )
{ {
return; return;
...@@ -995,16 +1014,7 @@ namespace dlib ...@@ -995,16 +1014,7 @@ namespace dlib
tensor_descriptor dest_desc; tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc); dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
// Ask cuDNN what algorithms are best to use. We always do this on the first call
// to setup(). Then if something other than the size of the input tensor changes we
// also ask cuDNN what to use. Note that in newer versions of cuDNN, asking for the
// best algorithm is a relatively slow thing. So it's important we don't do it
// unnecessarily.
if (!selected_algos || !non_data_params_unchanged)
{
selected_algos = true;
select_best_algorithms(data, dest_desc); select_best_algorithms(data, dest_desc);
}
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize( CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
context(), context(),
......
...@@ -233,8 +233,6 @@ namespace dlib ...@@ -233,8 +233,6 @@ namespace dlib
int forward_algo; int forward_algo;
int backward_data_algo; int backward_data_algo;
int backward_filters_algo; int backward_filters_algo;
// true if select_best_algorithms has been called at least once.
bool selected_algos = false;
size_t forward_workspace_size_in_bytes; size_t forward_workspace_size_in_bytes;
size_t backward_data_workspace_size_in_bytes; size_t backward_data_workspace_size_in_bytes;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment