Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
00b2c22c
Commit
00b2c22c
authored
Oct 25, 2015
by
Davis King
Browse files
Implemented cuDNN based max_pool
parent
a07b31da
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
118 additions
and
17 deletions
+118
-17
dlib/dnn/cudnn_dlibapi.cpp
dlib/dnn/cudnn_dlibapi.cpp
+83
-5
dlib/dnn/cudnn_dlibapi.h
dlib/dnn/cudnn_dlibapi.h
+35
-12
No files found.
dlib/dnn/cudnn_dlibapi.cpp
View file @
00b2c22c
...
...
@@ -503,17 +503,48 @@ namespace dlib
// ------------------------------------------------------------------------------------
// Default-constructs an empty max_pool.  No cuDNN pooling descriptor is
// allocated here; handle stays null and the strides are zeroed so that
// clear() is always safe to call.  Call setup() before using the object.
max_pool::max_pool(
) :
    handle(nullptr),
    stride_y(0),
    stride_x(0)
{
}
// Destructor: releases the cuDNN pooling descriptor (if any) via clear().
max_pool::~max_pool(
)
{
    clear();
}
// Frees the cuDNN pooling descriptor, if one was created by setup(), and
// returns the object to its default-constructed state.
void max_pool::clear(
)
{
    if (handle)
        cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
    handle = nullptr;
    stride_y = 0;
    stride_x = 0;
}
void
max_pool
::
setup
(
int
window_height
,
int
window_width
,
int
stride_y_
,
int
stride_x_
)
{
stride_x
=
stride_x_
;
stride_y
=
stride_y_
;
cudnnPoolingDescriptor_t
poolingDesc
;
check
(
cudnnCreatePoolingDescriptor
(
&
poolingDesc
));
handle
=
poolingDesc
;
check
(
cudnnSetPooling2dDescriptor
(
poolingDesc
,
CUDNN_POOLING_MAX
,
window_height
,
window_width
,
0
,
0
,
// no padding
stride_y
,
stride_x
));
}
void
max_pool
::
...
...
@@ -522,14 +553,61 @@ namespace dlib
const
tensor
&
src
)
{
const
float
alpha
=
1
;
const
float
beta
=
0
;
int
outN
;
int
outC
;
int
outH
;
int
outW
;
check
(
cudnnGetPooling2dForwardOutputDim
((
const
cudnnPoolingDescriptor_t
)
handle
,
descriptor
(
src
),
&
outN
,
&
outC
,
&
outH
,
&
outW
));
dest
.
set_size
(
outN
,
outC
,
outH
,
outW
);
DLIB_CASSERT
(
dest
.
num_samples
()
==
src
.
num_samples
(),
""
);
DLIB_CASSERT
(
dest
.
k
()
==
src
.
k
(),
""
);
DLIB_CASSERT
(
dest
.
nr
()
==
src
.
nr
()
/
stride_y
,
""
);
DLIB_CASSERT
(
dest
.
nc
()
==
src
.
nc
()
/
stride_x
,
""
);
check
(
cudnnPoolingForward
(
context
(),
(
const
cudnnPoolingDescriptor_t
)
handle
,
&
alpha
,
descriptor
(
src
),
src
.
device
(),
&
beta
,
descriptor
(
dest
),
dest
.
device
()));
}
// Backpropagates through the pooling layer via cudnnPoolingBackward().
// requires
//     - have_same_dimensions(gradient_input,dest) == true
//     - have_same_dimensions(src,grad) == true
//     - dest contains the result of calling (*this)(dest,src)
//
// NOTE(review): beta == 0 makes cudnnPoolingBackward OVERWRITE grad rather
// than accumulate into it, while the header documentation says the gradient
// is "added" to grad.  Accumulation would need beta == 1 -- confirm which
// contract callers rely on.
void max_pool::get_gradient(
    const tensor& gradient_input,
    const tensor& dest,
    const tensor& src,
    tensor& grad
)
{
    DLIB_CASSERT(have_same_dimensions(gradient_input,dest),"");
    DLIB_CASSERT(have_same_dimensions(src,grad),"");

    const float alpha = 1;
    const float beta = 0;
    check(cudnnPoolingBackward(context(),
                               (const cudnnPoolingDescriptor_t)handle,
                               &alpha,
                               descriptor(dest),
                               dest.device(),
                               descriptor(gradient_input),
                               gradient_input.device(),
                               descriptor(src),
                               src.device(),
                               &beta,
                               descriptor(grad),
                               grad.device()));
}
// ------------------------------------------------------------------------------------
...
...
dlib/dnn/cudnn_dlibapi.h
View file @
00b2c22c
...
...
@@ -244,45 +244,68 @@ namespace dlib
class max_pool
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A wrapper around the cuDNN max pooling routines
            (CUDNN_POOLING_MAX).  Non-copyable; owns a cuDNN pooling
            descriptor via the opaque handle member.
    !*/
public:
    max_pool(const max_pool&) = delete;
    max_pool& operator=(const max_pool&) = delete;

    max_pool();
    // Fix: the destructor was declared twice in this block (a redeclaration
    // error); it is declared exactly once here.
    // calls cudnnDestroyPoolingDescriptor()
    ~max_pool();

    void clear();

    // calls cudnnCreatePoolingDescriptor(), cudnnSetPooling2dDescriptor()
    void setup(
        int window_height,
        int window_width,
        int stride_y,
        int stride_x
    );

    // calls cudnnGetPooling2dForwardOutputDim(), cudnnPoolingForward()
    void operator() (
        resizable_tensor& dest,
        const tensor& src
    );
    /*!
        ensures
            - #dest.num_samples() == src.num_samples()
            - #dest.k() == src.k()
            - #dest.nr() == src.nr()/stride_y
            - #dest.nc() == src.nc()/stride_x
            - for all valid s, k, r, and c:
                - image_plane(#dest,s,k)(r,c) == max(subm_clipped(image_plane(src,s,k),
                                                                  r*stride_y,
                                                                  c*stride_x,
                                                                  window_height,
                                                                  window_width))
    !*/

    // calls cudnnPoolingBackward()
    void get_gradient(
        const tensor& gradient_input,
        const tensor& dest,
        const tensor& src,
        tensor& grad
    );
    /*!
        requires
            - have_same_dimensions(gradient_input,dest) == true
            - have_same_dimensions(src,grad) == true
            - dest contains the result of calling (*this)(dest,src)
        ensures
            - Recalling that dest is the output of (*this)(dest,src),
              let f(src) == dot(gradient_input,dest)
            - Then this function computes the gradient of f() with respect to
              src and adds it to grad.
    !*/

private:
    void* handle;   // opaque cudnnPoolingDescriptor_t, null until setup()
    int stride_y;
    int stride_x;
};
// TODO, make the order of parameters of all these functions consistent.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment