"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "b2ba34bde9229613559af6c8d5789dbdc6124f77"
Commit 6c3243f7 authored by Davis King's avatar Davis King
Browse files

Cleanup cuDNN conv algorithm selection code slightly by moving it into its own function.

parent 4d18e0d0
......@@ -776,6 +776,134 @@ namespace dlib
return best_alg;
}
void tensor_conv::
select_best_algorithms (
const tensor& data,
const tensor_descriptor& dest_desc
)
{
// Pick which forward algorithm we will use and allocate the necessary
// workspace buffer.
cudnnConvolutionFwdAlgo_t forward_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionFwdAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
forward_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&forward_best_algo));
#endif
forward_algo = forward_best_algo;
// Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
backward_data_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&backward_data_best_algo));
#endif
backward_data_algo = backward_data_best_algo;
// Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
backward_filters_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&backward_filters_best_algo));
#endif
// cuDNN 5.1 has a bug that causes
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
// algorithm even for cases where cuDNN doesn't support it, leading to
// incorrect outputs. So here we check if we are in a case where winograd
// isn't supported and manually overrule
// cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
// algorithm.
if (dnn_prefer_fastest_algorithms() &&
!(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
)
{
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
}
backward_filters_algo = backward_filters_best_algo;
}
void tensor_conv::
setup(
const tensor& data,
......@@ -863,81 +991,17 @@ namespace dlib
tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
// Pick which forward algorithm we will use and allocate the necessary
// workspace buffer.
cudnnConvolutionFwdAlgo_t forward_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionFwdAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionForwardAlgorithm(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
forward_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_FWD_PREFER_FASTEST:CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&forward_best_algo));
#endif
forward_algo = forward_best_algo;
select_best_algorithms(data, dest_desc);
CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
context(),
descriptor(data),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(dest_desc),
forward_best_algo,
(cudnnConvolutionFwdAlgo_t)forward_algo,
&forward_workspace_size_in_bytes));
// Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionBwdDataAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
backward_data_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&backward_data_best_algo));
#endif
backward_data_algo = backward_data_best_algo;
CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
context(),
......@@ -945,55 +1009,9 @@ namespace dlib
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
backward_data_best_algo,
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
&backward_data_workspace_size_in_bytes));
// Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
#if CUDNN_MAJOR >= 8
{
int num_possilbe_algorithms = 0;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(context(), &num_possilbe_algorithms));
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> perf_results(num_possilbe_algorithms);
int num_algorithms = 0;
CHECK_CUDNN(cudnnFindConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
num_possilbe_algorithms,
&num_algorithms,
perf_results.data()));
perf_results.resize(num_algorithms);
backward_filters_best_algo = pick_best_algorithm(perf_results);
}
#else
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
dnn_prefer_fastest_algorithms()?CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST:CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&backward_filters_best_algo));
#endif
// cuDNN 5.1 has a bug that causes
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
// algorithm even for cases where cuDNN doesn't support it, leading to
// incorrect outputs. So here we check if we are in a case where winograd
// isn't supported and manually overrule
// cudnnGetConvolutionBackwardFilterAlgorithm() by picking a safe
// algorithm.
if (dnn_prefer_fastest_algorithms() &&
!(stride_x == 1 && stride_y == 1 && ((filters_nr==3&&filters_nc==3) || (filters_nr==5&&filters_nc==5)))
)
{
backward_filters_best_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
}
backward_filters_algo = backward_filters_best_algo;
CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
context(),
......@@ -1001,7 +1019,7 @@ namespace dlib
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
backward_filters_best_algo,
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
&backward_filters_workspace_size_in_bytes));
}
catch(...)
......
......@@ -228,6 +228,8 @@ namespace dlib
int out_nr;
int out_nc;
// sets the three _algo fields.
void select_best_algorithms(const tensor& data, const tensor_descriptor& dest_desc);
int forward_algo;
int backward_data_algo;
int backward_filters_algo;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment