Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
d63c4682
"vscode:/vscode.git/clone" did not exist on "2004d60f8bbbb62050db623bbddf030ec82aa26b"
Commit
d63c4682
authored
Oct 19, 2015
by
Davis King
Browse files
Added some of the cuDNN conv calls.
parent
1022b5b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
95 additions
and
3 deletions
+95
-3
dlib/dnn/cudnn_dlibapi.cpp
dlib/dnn/cudnn_dlibapi.cpp
+73
-2
dlib/dnn/cudnn_dlibapi.h
dlib/dnn/cudnn_dlibapi.h
+22
-1
No files found.
dlib/dnn/cudnn_dlibapi.cpp
View file @
d63c4682
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include "tensor.h"
#include "tensor.h"
#include <cudnn.h>
#include <cudnn.h>
#include <iostream>
#include <iostream>
#include <string>
#include "cuda_utils.h"
#include "cuda_utils.h"
...
@@ -31,10 +32,11 @@ namespace dlib
...
@@ -31,10 +32,11 @@ namespace dlib
case
CUDNN_STATUS_BAD_PARAM
:
case
CUDNN_STATUS_BAD_PARAM
:
throw
cudnn_error
(
"CUDNN_STATUS_BAD_PARAM"
);
throw
cudnn_error
(
"CUDNN_STATUS_BAD_PARAM"
);
default:
default:
throw
cudnn_error
(
"A call to cuDNN failed
."
);
throw
cudnn_error
(
"A call to cuDNN failed
: "
+
std
::
string
(
cudnnGetErrorString
(
s
))
);
}
}
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
cudnn_context
::
cudnn_context
()
:
handle
(
nullptr
)
cudnn_context
::
cudnn_context
()
:
handle
(
nullptr
)
...
@@ -162,14 +164,83 @@ namespace dlib
...
@@ -162,14 +164,83 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
conv
::
conv
(
conv
::
conv
(
)
:
filter_handle
(
nullptr
),
conv_handle
(
nullptr
),
out_num_samples
(
0
),
out_k
(
0
),
out_nr
(
0
),
out_nc
(
0
)
{
}
void
conv
::
clear
(
)
{
if
(
filter_handle
)
cudnnDestroyFilterDescriptor
((
cudnnFilterDescriptor_t
)
filter_handle
);
if
(
conv_handle
)
cudnnDestroyConvolutionDescriptor
((
cudnnConvolutionDescriptor_t
)
conv_handle
);
filter_handle
=
nullptr
;
conv_handle
=
nullptr
;
out_num_samples
=
0
;
out_k
=
0
;
out_nr
=
0
;
out_nc
=
0
;
}
void
conv
::
setup
(
cudnn_context
&
context
,
cudnn_context
&
context
,
const
tensor
&
data
,
const
tensor
&
data
,
const
tensor
&
filters
,
const
tensor
&
filters
,
int
stride_y
,
int
stride_y
,
int
stride_x
int
stride_x
)
{
clear
();
try
{
check
(
cudnnCreateFilterDescriptor
((
cudnnFilterDescriptor_t
*
)
&
filter_handle
));
check
(
cudnnSetFilter4dDescriptor
((
cudnnFilterDescriptor_t
)
filter_handle
,
CUDNN_DATA_FLOAT
,
filters
.
num_samples
(),
filters
.
k
(),
filters
.
nr
(),
filters
.
nc
()));
check
(
cudnnCreateConvolutionDescriptor
((
cudnnConvolutionDescriptor_t
*
)
&
conv_handle
));
check
(
cudnnSetConvolution2dDescriptor
((
cudnnConvolutionDescriptor_t
)
conv_handle
,
filters
.
nr
()
/
2
,
// vertical padding
filters
.
nc
()
/
2
,
// horizontal padding
stride_y
,
stride_x
,
1
,
1
,
// must be 1,1
CUDNN_CONVOLUTION
));
// could also be CUDNN_CROSS_CORRELATION
check
(
cudnnGetConvolution2dForwardOutputDim
(
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
(
const
cudnnTensorDescriptor_t
)
data
.
get_cudnn_tensor_descriptor
().
get_handle
(),
(
const
cudnnFilterDescriptor_t
)
filter_handle
,
&
out_num_samples
,
&
out_k
,
&
out_nr
,
&
out_nc
));
}
catch
(...)
{
clear
();
}
}
conv
::
~
conv
(
)
)
{
{
clear
();
}
}
void
conv
::
operator
()
(
void
conv
::
operator
()
(
...
...
dlib/dnn/cudnn_dlibapi.h
View file @
d63c4682
...
@@ -66,6 +66,10 @@ namespace dlib
...
@@ -66,6 +66,10 @@ namespace dlib
int
nc
,
int
nc
,
int
k
int
k
);
);
/*!
ensures
- if any of the arguments are 0 then they are all set to 0 in the tensor.
!*/
void
get_size
(
void
get_size
(
int
&
n
,
int
&
n
,
...
@@ -145,7 +149,12 @@ namespace dlib
...
@@ -145,7 +149,12 @@ namespace dlib
conv
(
const
conv
&
)
=
delete
;
conv
(
const
conv
&
)
=
delete
;
conv
&
operator
=
(
const
conv
&
)
=
delete
;
conv
&
operator
=
(
const
conv
&
)
=
delete
;
conv
(
conv
();
void
clear
(
);
void
setup
(
cudnn_context
&
context
,
cudnn_context
&
context
,
const
tensor
&
data
,
const
tensor
&
data
,
const
tensor
&
filters
,
const
tensor
&
filters
,
...
@@ -153,6 +162,9 @@ namespace dlib
...
@@ -153,6 +162,9 @@ namespace dlib
int
stride_x
int
stride_x
);
);
~
conv
(
);
void
operator
()
(
void
operator
()
(
resizable_tensor
&
output
,
resizable_tensor
&
output
,
const
tensor
&
data
,
const
tensor
&
data
,
...
@@ -210,6 +222,15 @@ namespace dlib
...
@@ -210,6 +222,15 @@ namespace dlib
and adds this gradient to filters_gradient.
and adds this gradient to filters_gradient.
!*/
!*/
private:
void
*
filter_handle
;
void
*
conv_handle
;
// dimensions of the output tensor from operator()
int
out_num_samples
;
int
out_nr
;
int
out_nc
;
int
out_k
;
};
};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment