Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
6c3243f7
Commit
6c3243f7
authored
Aug 01, 2020
by
Davis King
Browse files
Cleanup cuDNN conv algorithm selection code slightly by moving it into its own function.
parent
4d18e0d0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
135 additions
and
115 deletions
+135
-115
dlib/cuda/cudnn_dlibapi.cpp
dlib/cuda/cudnn_dlibapi.cpp
+133
-115
dlib/cuda/cudnn_dlibapi.h
dlib/cuda/cudnn_dlibapi.h
+2
-0
No files found.
dlib/cuda/cudnn_dlibapi.cpp
View file @
6c3243f7
...
@@ -777,92 +777,11 @@ namespace dlib
...
@@ -777,92 +777,11 @@ namespace dlib
}
}
void
tensor_conv
::
void
tensor_conv
::
se
tup
(
se
lect_best_algorithms
(
const
tensor
&
data
,
const
tensor
&
data
,
const
tensor
&
filters
,
const
tensor_descriptor
&
dest_desc
int
stride_y_
,
int
stride_x_
,
int
padding_y_
,
int
padding_x_
)
)
{
{
DLIB_CASSERT
(
data
.
k
()
==
filters
.
k
());
// if the last call to setup gave the same exact settings then don't do
// anything.
if
(
stride_y_
==
stride_y
&&
stride_x_
==
stride_x
&&
padding_y_
==
padding_y
&&
padding_x_
==
padding_x
&&
data_num_samples
==
data
.
num_samples
()
&&
data_k
==
data
.
k
()
&&
data_nr
==
data
.
nr
()
&&
data_nc
==
data
.
nc
()
&&
filters_num_samples
==
filters
.
num_samples
()
&&
filters_k
==
filters
.
k
()
&&
filters_nr
==
filters
.
nr
()
&&
filters_nc
==
filters
.
nc
())
{
return
;
}
clear
();
try
{
stride_y
=
stride_y_
;
stride_x
=
stride_x_
;
padding_y
=
padding_y_
;
padding_x
=
padding_x_
;
data_num_samples
=
data
.
num_samples
();
data_k
=
data
.
k
();
data_nr
=
data
.
nr
();
data_nc
=
data
.
nc
();
filters_num_samples
=
filters
.
num_samples
();
filters_k
=
filters
.
k
();
filters_nr
=
filters
.
nr
();
filters_nc
=
filters
.
nc
();
CHECK_CUDNN
(
cudnnCreateFilterDescriptor
((
cudnnFilterDescriptor_t
*
)
&
filter_handle
));
CHECK_CUDNN
(
cudnnSetFilter4dDescriptor
((
cudnnFilterDescriptor_t
)
filter_handle
,
CUDNN_DATA_FLOAT
,
CUDNN_TENSOR_NCHW
,
filters
.
num_samples
(),
filters
.
k
(),
filters
.
nr
(),
filters
.
nc
()));
CHECK_CUDNN
(
cudnnCreateConvolutionDescriptor
((
cudnnConvolutionDescriptor_t
*
)
&
conv_handle
));
#if CUDNN_MAJOR >= 6
CHECK_CUDNN
(
cudnnSetConvolution2dDescriptor
((
cudnnConvolutionDescriptor_t
)
conv_handle
,
padding_y
,
// vertical padding
padding_x
,
// horizontal padding
stride_y
,
stride_x
,
1
,
1
,
// must be 1,1
CUDNN_CROSS_CORRELATION
,
CUDNN_DATA_FLOAT
));
// could also be CUDNN_CONVOLUTION
#else
CHECK_CUDNN
(
cudnnSetConvolution2dDescriptor
((
cudnnConvolutionDescriptor_t
)
conv_handle
,
padding_y
,
// vertical padding
padding_x
,
// horizontal padding
stride_y
,
stride_x
,
1
,
1
,
// must be 1,1
CUDNN_CROSS_CORRELATION
));
// could also be CUDNN_CONVOLUTION
#endif
CHECK_CUDNN
(
cudnnGetConvolution2dForwardOutputDim
(
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
descriptor
(
data
),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
&
out_num_samples
,
&
out_k
,
&
out_nr
,
&
out_nc
));
tensor_descriptor
dest_desc
;
dest_desc
.
set_size
(
out_num_samples
,
out_k
,
out_nr
,
out_nc
);
// Pick which forward algorithm we will use and allocate the necessary
// Pick which forward algorithm we will use and allocate the necessary
// workspace buffer.
// workspace buffer.
cudnnConvolutionFwdAlgo_t
forward_best_algo
;
cudnnConvolutionFwdAlgo_t
forward_best_algo
;
...
@@ -896,14 +815,8 @@ namespace dlib
...
@@ -896,14 +815,8 @@ namespace dlib
&
forward_best_algo
));
&
forward_best_algo
));
#endif
#endif
forward_algo
=
forward_best_algo
;
forward_algo
=
forward_best_algo
;
CHECK_CUDNN
(
cudnnGetConvolutionForwardWorkspaceSize
(
context
(),
descriptor
(
data
),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
descriptor
(
dest_desc
),
forward_best_algo
,
&
forward_workspace_size_in_bytes
));
// Pick which backward data algorithm we will use and allocate the
// Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer.
// necessary workspace buffer.
...
@@ -939,14 +852,8 @@ namespace dlib
...
@@ -939,14 +852,8 @@ namespace dlib
#endif
#endif
backward_data_algo
=
backward_data_best_algo
;
backward_data_algo
=
backward_data_best_algo
;
CHECK_CUDNN
(
cudnnGetConvolutionBackwardDataWorkspaceSize
(
context
(),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
descriptor
(
dest_desc
),
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
descriptor
(
data
),
backward_data_best_algo
,
&
backward_data_workspace_size_in_bytes
));
// Pick which backward filters algorithm we will use and allocate the
// Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer.
// necessary workspace buffer.
...
@@ -980,6 +887,7 @@ namespace dlib
...
@@ -980,6 +887,7 @@ namespace dlib
std
::
numeric_limits
<
size_t
>::
max
(),
std
::
numeric_limits
<
size_t
>::
max
(),
&
backward_filters_best_algo
));
&
backward_filters_best_algo
));
#endif
#endif
// cuDNN 5.1 has a bug that causes
// cuDNN 5.1 has a bug that causes
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
// cudnnGetConvolutionBackwardFilterAlgorithm() to pick the winograd
// algorithm even for cases where cuDNN doesn't support it, leading to
// algorithm even for cases where cuDNN doesn't support it, leading to
...
@@ -994,6 +902,116 @@ namespace dlib
...
@@ -994,6 +902,116 @@ namespace dlib
backward_filters_best_algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
;
backward_filters_best_algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
;
}
}
backward_filters_algo
=
backward_filters_best_algo
;
backward_filters_algo
=
backward_filters_best_algo
;
}
void
tensor_conv
::
setup
(
const
tensor
&
data
,
const
tensor
&
filters
,
int
stride_y_
,
int
stride_x_
,
int
padding_y_
,
int
padding_x_
)
{
DLIB_CASSERT
(
data
.
k
()
==
filters
.
k
());
// if the last call to setup gave the same exact settings then don't do
// anything.
if
(
stride_y_
==
stride_y
&&
stride_x_
==
stride_x
&&
padding_y_
==
padding_y
&&
padding_x_
==
padding_x
&&
data_num_samples
==
data
.
num_samples
()
&&
data_k
==
data
.
k
()
&&
data_nr
==
data
.
nr
()
&&
data_nc
==
data
.
nc
()
&&
filters_num_samples
==
filters
.
num_samples
()
&&
filters_k
==
filters
.
k
()
&&
filters_nr
==
filters
.
nr
()
&&
filters_nc
==
filters
.
nc
())
{
return
;
}
clear
();
try
{
stride_y
=
stride_y_
;
stride_x
=
stride_x_
;
padding_y
=
padding_y_
;
padding_x
=
padding_x_
;
data_num_samples
=
data
.
num_samples
();
data_k
=
data
.
k
();
data_nr
=
data
.
nr
();
data_nc
=
data
.
nc
();
filters_num_samples
=
filters
.
num_samples
();
filters_k
=
filters
.
k
();
filters_nr
=
filters
.
nr
();
filters_nc
=
filters
.
nc
();
CHECK_CUDNN
(
cudnnCreateFilterDescriptor
((
cudnnFilterDescriptor_t
*
)
&
filter_handle
));
CHECK_CUDNN
(
cudnnSetFilter4dDescriptor
((
cudnnFilterDescriptor_t
)
filter_handle
,
CUDNN_DATA_FLOAT
,
CUDNN_TENSOR_NCHW
,
filters
.
num_samples
(),
filters
.
k
(),
filters
.
nr
(),
filters
.
nc
()));
CHECK_CUDNN
(
cudnnCreateConvolutionDescriptor
((
cudnnConvolutionDescriptor_t
*
)
&
conv_handle
));
#if CUDNN_MAJOR >= 6
CHECK_CUDNN
(
cudnnSetConvolution2dDescriptor
((
cudnnConvolutionDescriptor_t
)
conv_handle
,
padding_y
,
// vertical padding
padding_x
,
// horizontal padding
stride_y
,
stride_x
,
1
,
1
,
// must be 1,1
CUDNN_CROSS_CORRELATION
,
CUDNN_DATA_FLOAT
));
// could also be CUDNN_CONVOLUTION
#else
CHECK_CUDNN
(
cudnnSetConvolution2dDescriptor
((
cudnnConvolutionDescriptor_t
)
conv_handle
,
padding_y
,
// vertical padding
padding_x
,
// horizontal padding
stride_y
,
stride_x
,
1
,
1
,
// must be 1,1
CUDNN_CROSS_CORRELATION
));
// could also be CUDNN_CONVOLUTION
#endif
CHECK_CUDNN
(
cudnnGetConvolution2dForwardOutputDim
(
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
descriptor
(
data
),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
&
out_num_samples
,
&
out_k
,
&
out_nr
,
&
out_nc
));
tensor_descriptor
dest_desc
;
dest_desc
.
set_size
(
out_num_samples
,
out_k
,
out_nr
,
out_nc
);
select_best_algorithms
(
data
,
dest_desc
);
CHECK_CUDNN
(
cudnnGetConvolutionForwardWorkspaceSize
(
context
(),
descriptor
(
data
),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
descriptor
(
dest_desc
),
(
cudnnConvolutionFwdAlgo_t
)
forward_algo
,
&
forward_workspace_size_in_bytes
));
CHECK_CUDNN
(
cudnnGetConvolutionBackwardDataWorkspaceSize
(
context
(),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
descriptor
(
dest_desc
),
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
descriptor
(
data
),
(
cudnnConvolutionBwdDataAlgo_t
)
backward_data_algo
,
&
backward_data_workspace_size_in_bytes
));
CHECK_CUDNN
(
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
CHECK_CUDNN
(
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
context
(),
context
(),
...
@@ -1001,7 +1019,7 @@ namespace dlib
...
@@ -1001,7 +1019,7 @@ namespace dlib
descriptor
(
dest_desc
),
descriptor
(
dest_desc
),
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
backward_filters_
best_
algo
,
(
cudnnConvolutionBwdFilterAlgo_t
)
backward_filters_algo
,
&
backward_filters_workspace_size_in_bytes
));
&
backward_filters_workspace_size_in_bytes
));
}
}
catch
(...)
catch
(...)
...
...
dlib/cuda/cudnn_dlibapi.h
View file @
6c3243f7
...
@@ -228,6 +228,8 @@ namespace dlib
...
@@ -228,6 +228,8 @@ namespace dlib
int
out_nr
;
int
out_nr
;
int
out_nc
;
int
out_nc
;
// sets the three _algo fields.
void
select_best_algorithms
(
const
tensor
&
data
,
const
tensor_descriptor
&
dest_desc
);
int
forward_algo
;
int
forward_algo
;
int
backward_data_algo
;
int
backward_data_algo
;
int
backward_filters_algo
;
int
backward_filters_algo
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment