Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
10480cb9
Commit
10480cb9
authored
Oct 19, 2015
by
Davis King
Browse files
Hid the cuDNN context in a thread local variable so the user doesn't
need to deal with it.
parent
d63c4682
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
66 deletions
+28
-66
dlib/dnn/cudnn_dlibapi.cpp
dlib/dnn/cudnn_dlibapi.cpp
+24
-24
dlib/dnn/cudnn_dlibapi.h
dlib/dnn/cudnn_dlibapi.h
+4
-42
No files found.
dlib/dnn/cudnn_dlibapi.cpp
View file @
10480cb9
...
@@ -36,23 +36,36 @@ namespace dlib
...
@@ -36,23 +36,36 @@ namespace dlib
}
}
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
cudnn_context
::
cudnn_context
()
:
handle
(
nullptr
)
c
lass
c
udnn_context
{
{
cudnnHandle_t
h
;
public:
check
(
cudnnCreate
(
&
h
));
// not copyable
handle
=
h
;
cudnn_context
(
const
cudnn_context
&
)
=
delete
;
}
cudnn_context
&
operator
=
(
const
cudnn_context
&
)
=
delete
;
cudnn_context
::~
cudnn_context
()
cudnn_context
()
{
if
(
handle
)
{
{
cudnnDestroy
((
cudnnHandle_t
)
handle
);
check
(
cudnnCreate
(
&
handle
));
handle
=
nullptr
;
}
}
~
cudnn_context
()
{
cudnnDestroy
(
handle
);
}
cudnnHandle_t
get_handle
(
)
const
{
return
handle
;
}
private:
cudnnHandle_t
handle
;
};
static
cudnnHandle_t
context
()
{
thread_local
cudnn_context
c
;
return
c
.
get_handle
();
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
...
@@ -136,7 +149,6 @@ namespace dlib
...
@@ -136,7 +149,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
add
(
void
add
(
cudnn_context
&
context
,
float
beta
,
float
beta
,
tensor
&
dest
,
tensor
&
dest
,
float
alpha
,
float
alpha
,
...
@@ -146,7 +158,6 @@ namespace dlib
...
@@ -146,7 +158,6 @@ namespace dlib
}
}
void
set_tensor
(
void
set_tensor
(
cudnn_context
&
context
,
tensor
&
t
,
tensor
&
t
,
float
value
float
value
)
)
...
@@ -154,7 +165,6 @@ namespace dlib
...
@@ -154,7 +165,6 @@ namespace dlib
}
}
void
scale_tensor
(
void
scale_tensor
(
cudnn_context
&
context
,
tensor
&
t
,
tensor
&
t
,
float
value
float
value
)
)
...
@@ -194,7 +204,6 @@ namespace dlib
...
@@ -194,7 +204,6 @@ namespace dlib
void
conv
::
void
conv
::
setup
(
setup
(
cudnn_context
&
context
,
const
tensor
&
data
,
const
tensor
&
data
,
const
tensor
&
filters
,
const
tensor
&
filters
,
int
stride_y
,
int
stride_y
,
...
@@ -272,7 +281,6 @@ namespace dlib
...
@@ -272,7 +281,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
soft_max
(
void
soft_max
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
)
)
...
@@ -280,7 +288,6 @@ namespace dlib
...
@@ -280,7 +288,6 @@ namespace dlib
}
}
void
soft_max_gradient
(
void
soft_max_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
@@ -292,7 +299,6 @@ namespace dlib
...
@@ -292,7 +299,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
max_pool
::
max_pool
(
max_pool
::
max_pool
(
cudnn_context
&
context
,
int
window_height
,
int
window_height
,
int
window_width
,
int
window_width
,
int
stride_y
,
int
stride_y
,
...
@@ -326,7 +332,6 @@ namespace dlib
...
@@ -326,7 +332,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
sigmoid
(
void
sigmoid
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
)
)
...
@@ -334,7 +339,6 @@ namespace dlib
...
@@ -334,7 +339,6 @@ namespace dlib
}
}
void
sigmoid_gradient
(
void
sigmoid_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
@@ -345,7 +349,6 @@ namespace dlib
...
@@ -345,7 +349,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
relu
(
void
relu
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
)
)
...
@@ -353,7 +356,6 @@ namespace dlib
...
@@ -353,7 +356,6 @@ namespace dlib
}
}
void
relu_gradient
(
void
relu_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
@@ -364,7 +366,6 @@ namespace dlib
...
@@ -364,7 +366,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
tanh
(
void
tanh
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
)
)
...
@@ -372,7 +373,6 @@ namespace dlib
...
@@ -372,7 +373,6 @@ namespace dlib
}
}
void
tanh_gradient
(
void
tanh_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
...
dlib/dnn/cudnn_dlibapi.h
View file @
10480cb9
...
@@ -17,31 +17,6 @@ namespace dlib
...
@@ -17,31 +17,6 @@ namespace dlib
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
class
cudnn_context
{
public:
// not copyable
cudnn_context
(
const
cudnn_context
&
)
=
delete
;
cudnn_context
&
operator
=
(
const
cudnn_context
&
)
=
delete
;
// but is movable
cudnn_context
(
cudnn_context
&&
item
)
:
cudnn_context
()
{
swap
(
item
);
}
cudnn_context
&
operator
=
(
cudnn_context
&&
item
)
{
swap
(
item
);
return
*
this
;
}
cudnn_context
();
~
cudnn_context
();
const
void
*
get_handle
(
)
const
{
return
handle
;
}
private:
void
swap
(
cudnn_context
&
item
)
{
std
::
swap
(
handle
,
item
.
handle
);
}
void
*
handle
;
};
// ------------------------------------------------------------------------------------
class
tensor_descriptor
class
tensor_descriptor
{
{
/*!
/*!
...
@@ -91,7 +66,6 @@ namespace dlib
...
@@ -91,7 +66,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
add
(
void
add
(
cudnn_context
&
context
,
float
beta
,
float
beta
,
tensor
&
dest
,
tensor
&
dest
,
float
alpha
,
float
alpha
,
...
@@ -117,7 +91,6 @@ namespace dlib
...
@@ -117,7 +91,6 @@ namespace dlib
!*/
!*/
void
set_tensor
(
void
set_tensor
(
cudnn_context
&
context
,
tensor
&
t
,
tensor
&
t
,
float
value
float
value
);
);
...
@@ -128,7 +101,6 @@ namespace dlib
...
@@ -128,7 +101,6 @@ namespace dlib
!*/
!*/
void
scale_tensor
(
void
scale_tensor
(
cudnn_context
&
context
,
tensor
&
t
,
tensor
&
t
,
float
value
float
value
);
);
...
@@ -155,7 +127,6 @@ namespace dlib
...
@@ -155,7 +127,6 @@ namespace dlib
);
);
void
setup
(
void
setup
(
cudnn_context
&
context
,
const
tensor
&
data
,
const
tensor
&
data
,
const
tensor
&
filters
,
const
tensor
&
filters
,
int
stride_y
,
int
stride_y
,
...
@@ -236,7 +207,6 @@ namespace dlib
...
@@ -236,7 +207,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void
soft_max
(
void
soft_max
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
);
);
...
@@ -245,13 +215,12 @@ namespace dlib
...
@@ -245,13 +215,12 @@ namespace dlib
!*/
!*/
void
soft_max_gradient
(
void
soft_max_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
);
);
/*!
/*!
- let OUT be the output of soft_max(
context,
OUT,src)
- let OUT be the output of soft_max(OUT,src)
- let f(src) == dot(gradient_input,OUT)
- let f(src) == dot(gradient_input,OUT)
- Then this function computes the gradient of f() with respect to src
- Then this function computes the gradient of f() with respect to src
and adds it to grad.
and adds it to grad.
...
@@ -271,7 +240,6 @@ namespace dlib
...
@@ -271,7 +240,6 @@ namespace dlib
// cudnnCreatePoolingDescriptor(), cudnnSetPooling2dDescriptor()
// cudnnCreatePoolingDescriptor(), cudnnSetPooling2dDescriptor()
max_pool
(
max_pool
(
cudnn_context
&
context
,
int
window_height
,
int
window_height
,
int
window_width
,
int
window_width
,
int
stride_y
,
int
stride_y
,
...
@@ -310,7 +278,6 @@ namespace dlib
...
@@ -310,7 +278,6 @@ namespace dlib
// cudnnActivationForward(), CUDNN_ACTIVATION_SIGMOID
// cudnnActivationForward(), CUDNN_ACTIVATION_SIGMOID
void
sigmoid
(
void
sigmoid
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
);
);
...
@@ -323,7 +290,6 @@ namespace dlib
...
@@ -323,7 +290,6 @@ namespace dlib
// cudnnActivationBackward()
// cudnnActivationBackward()
void
sigmoid_gradient
(
void
sigmoid_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
@@ -333,7 +299,7 @@ namespace dlib
...
@@ -333,7 +299,7 @@ namespace dlib
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
- have_same_dimensions(src,grad) == true
ensures
ensures
- let OUT be the output of sigmoid(
context,
OUT,src)
- let OUT be the output of sigmoid(OUT,src)
- let f(src) == dot(gradient_input,OUT)
- let f(src) == dot(gradient_input,OUT)
- Then this function computes the gradient of f() with respect to src and
- Then this function computes the gradient of f() with respect to src and
adds it to grad.
adds it to grad.
...
@@ -343,7 +309,6 @@ namespace dlib
...
@@ -343,7 +309,6 @@ namespace dlib
// cudnnActivationForward(), CUDNN_ACTIVATION_RELU
// cudnnActivationForward(), CUDNN_ACTIVATION_RELU
void
relu
(
void
relu
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
);
);
...
@@ -356,7 +321,6 @@ namespace dlib
...
@@ -356,7 +321,6 @@ namespace dlib
// cudnnActivationBackward()
// cudnnActivationBackward()
void
relu_gradient
(
void
relu_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
@@ -366,7 +330,7 @@ namespace dlib
...
@@ -366,7 +330,7 @@ namespace dlib
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
- have_same_dimensions(src,grad) == true
ensures
ensures
- let OUT be the output of relu(
context,
OUT,src)
- let OUT be the output of relu(OUT,src)
- let f(src) == dot(gradient_input,OUT)
- let f(src) == dot(gradient_input,OUT)
- Then this function computes the gradient of f() with respect to src and
- Then this function computes the gradient of f() with respect to src and
adds it to grad.
adds it to grad.
...
@@ -376,7 +340,6 @@ namespace dlib
...
@@ -376,7 +340,6 @@ namespace dlib
// cudnnActivationForward(), CUDNN_ACTIVATION_TANH
// cudnnActivationForward(), CUDNN_ACTIVATION_TANH
void
tanh
(
void
tanh
(
cudnn_context
&
context
,
resizable_tensor
&
dest
,
resizable_tensor
&
dest
,
const
tensor
&
src
const
tensor
&
src
);
);
...
@@ -389,7 +352,6 @@ namespace dlib
...
@@ -389,7 +352,6 @@ namespace dlib
// cudnnActivationBackward()
// cudnnActivationBackward()
void
tanh_gradient
(
void
tanh_gradient
(
cudnn_context
&
context
,
tensor
&
grad
,
tensor
&
grad
,
const
tensor
&
src
,
const
tensor
&
src
,
const
tensor
&
gradient_input
const
tensor
&
gradient_input
...
@@ -399,7 +361,7 @@ namespace dlib
...
@@ -399,7 +361,7 @@ namespace dlib
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
- have_same_dimensions(src,grad) == true
ensures
ensures
- let OUT be the output of tanh(
context,
OUT,src)
- let OUT be the output of tanh(OUT,src)
- let f(src) == dot(gradient_input,OUT)
- let f(src) == dot(gradient_input,OUT)
- Then this function computes the gradient of f() with respect to src and
- Then this function computes the gradient of f() with respect to src and
adds it to grad.
adds it to grad.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment