Commit 6c3243f7 authored by Davis King
Browse files

Clean up the cuDNN convolution algorithm selection code slightly by moving it into its own function.

parent 4d18e0d0
...@@ -777,92 +777,11 @@ namespace dlib ...@@ -777,92 +777,11 @@ namespace dlib
} }
void tensor_conv:: void tensor_conv::
setup( select_best_algorithms (
const tensor& data, const tensor& data,
const tensor& filters, const tensor_descriptor& dest_desc
int stride_y_,
int stride_x_,
int padding_y_,
int padding_x_
) )
{ {
DLIB_CASSERT(data.k() == filters.k());
// if the last call to setup gave the same exact settings then don't do
// anything.
if (stride_y_ == stride_y &&
stride_x_ == stride_x &&
padding_y_ == padding_y &&
padding_x_ == padding_x &&
data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc() &&
filters_num_samples == filters.num_samples() &&
filters_k == filters.k() &&
filters_nr == filters.nr() &&
filters_nc == filters.nc())
{
return;
}
clear();
try
{
stride_y = stride_y_;
stride_x = stride_x_;
padding_y = padding_y_;
padding_x = padding_x_;
data_num_samples = data.num_samples();
data_k = data.k();
data_nr = data.nr();
data_nc = data.nc();
filters_num_samples = filters.num_samples();
filters_k = filters.k();
filters_nr = filters.nr();
filters_nc = filters.nc();
CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
filters.num_samples(),
filters.k(),
filters.nr(),
filters.nc()));
CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
#if CUDNN_MAJOR >= 6
CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
padding_y, // vertical padding
padding_x, // horizontal padding
stride_y,
stride_x,
1, 1, // must be 1,1
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT)); // could also be CUDNN_CONVOLUTION
#else
CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
padding_y, // vertical padding
padding_x, // horizontal padding
stride_y,
stride_x,
1, 1, // must be 1,1
CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION
#endif
CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
&out_num_samples,
&out_k,
&out_nr,
&out_nc));
tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
// Pick which forward algorithm we will use and allocate the necessary // Pick which forward algorithm we will use and allocate the necessary
// workspace buffer. // workspace buffer.
cudnnConvolutionFwdAlgo_t forward_best_algo; cudnnConvolutionFwdAlgo_t forward_best_algo;
...@@ -896,14 +815,8 @@ namespace dlib ...@@ -896,14 +815,8 @@ namespace dlib
&forward_best_algo)); &forward_best_algo));
#endif #endif
forward_algo = forward_best_algo; forward_algo = forward_best_algo;
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
forward_best_algo,
&forward_workspace_size_in_bytes));
// Pick which backward data algorithm we will use and allocate the // Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer. // necessary workspace buffer.
...@@ -939,14 +852,8 @@ namespace dlib ...@@ -939,14 +852,8 @@ namespace dlib
#endif #endif
backward_data_algo = backward_data_best_algo; backward_data_algo = backward_data_best_algo;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
backward_data_best_algo,
&backward_data_workspace_size_in_bytes));
// Pick which backward filters algorithm we will use and allocate the // Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer. // necessary workspace buffer.
...@@ -980,6 +887,7 @@ namespace dlib ...@@ -980,6 +887,7 @@ namespace dlib
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
&backward_filters_best_algo)); &backward_filters_best_algo));
#endif #endif
// cuDNN 5.1 has a bug that causes // cuDNN 5.1 has a bug that causes
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd // cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
// algorithm even for cases where cuDNN doesn't support it, leading to // algorithm even for cases where cuDNN doesn't support it, leading to
...@@ -994,6 +902,116 @@ namespace dlib ...@@ -994,6 +902,116 @@ namespace dlib
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
} }
backward_filters_algo = backward_filters_best_algo; backward_filters_algo = backward_filters_best_algo;
}
void tensor_conv::
setup(
const tensor& data,
const tensor& filters,
int stride_y_,
int stride_x_,
int padding_y_,
int padding_x_
)
{
DLIB_CASSERT(data.k() == filters.k());
// if the last call to setup gave the same exact settings then don't do
// anything.
if (stride_y_ == stride_y &&
stride_x_ == stride_x &&
padding_y_ == padding_y &&
padding_x_ == padding_x &&
data_num_samples == data.num_samples() &&
data_k == data.k() &&
data_nr == data.nr() &&
data_nc == data.nc() &&
filters_num_samples == filters.num_samples() &&
filters_k == filters.k() &&
filters_nr == filters.nr() &&
filters_nc == filters.nc())
{
return;
}
clear();
try
{
stride_y = stride_y_;
stride_x = stride_x_;
padding_y = padding_y_;
padding_x = padding_x_;
data_num_samples = data.num_samples();
data_k = data.k();
data_nr = data.nr();
data_nc = data.nc();
filters_num_samples = filters.num_samples();
filters_k = filters.k();
filters_nr = filters.nr();
filters_nc = filters.nc();
CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW,
filters.num_samples(),
filters.k(),
filters.nr(),
filters.nc()));
CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
#if CUDNN_MAJOR >= 6
CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
padding_y, // vertical padding
padding_x, // horizontal padding
stride_y,
stride_x,
1, 1, // must be 1,1
CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT)); // could also be CUDNN_CONVOLUTION
#else
CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
padding_y, // vertical padding
padding_x, // horizontal padding
stride_y,
stride_x,
1, 1, // must be 1,1
CUDNN_CROSS_CORRELATION)); // could also be CUDNN_CONVOLUTION
#endif
CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
&out_num_samples,
&out_k,
&out_nr,
&out_nc));
tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
select_best_algorithms(data, dest_desc);
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
(cudnnConvolutionFwdAlgo_t)forward_algo,
&forward_workspace_size_in_bytes));
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
&backward_data_workspace_size_in_bytes));
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize( CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
context(), context(),
...@@ -1001,7 +1019,7 @@ namespace dlib ...@@ -1001,7 +1019,7 @@ namespace dlib
descriptor(dest_desc), descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle, (const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle, (const cudnnFilterDescriptor_t)filter_handle,
backward_filters_best_algo, (cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
&backward_filters_workspace_size_in_bytes)); &backward_filters_workspace_size_in_bytes));
} }
catch(...) catch(...)
......
...@@ -228,6 +228,8 @@ namespace dlib ...@@ -228,6 +228,8 @@ namespace dlib
int out_nr; int out_nr;
int out_nc; int out_nc;
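// Selects the forward, backward data, and backward filters algorithms and stores them in forward_algo, backward_data_algo, and backward_filters_algo.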
void select_best_algorithms(const tensor& data, const tensor_descriptor& dest_desc);
int forward_algo; int forward_algo;
int backward_data_algo; int backward_data_algo;
int backward_filters_algo; int backward_filters_algo;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment