Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
8e6d8ae0
Commit
8e6d8ae0
authored
Jun 22, 2016
by
Davis King
Browse files
Changed conv layer to use cross-correlation rather than convolution.
parent
595f0128
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
10 deletions
+55
-10
dlib/dnn/cpu_dlib.cpp
dlib/dnn/cpu_dlib.cpp
+12
-8
dlib/dnn/cudnn_dlibapi.cpp
dlib/dnn/cudnn_dlibapi.cpp
+1
-1
dlib/dnn/layers.h
dlib/dnn/layers.h
+42
-1
No files found.
dlib/dnn/cpu_dlib.cpp
View file @
8e6d8ae0
...
@@ -1631,9 +1631,11 @@ namespace dlib
...
@@ -1631,9 +1631,11 @@ namespace dlib
// now fill in the Toeplitz output matrix for the n-th sample in data.
// now fill in the Toeplitz output matrix for the n-th sample in data.
size_t
cnt
=
0
;
size_t
cnt
=
0
;
for
(
long
r
=
filter_nr
-
1
-
padding_y
;
r
-
padding_y
<
data
.
nr
();
r
+=
stride_y
)
const
long
max_r
=
data
.
nr
()
+
padding_y
-
(
filter_nr
-
1
);
const
long
max_c
=
data
.
nc
()
+
padding_x
-
(
filter_nc
-
1
);
for
(
long
r
=
-
padding_y
;
r
<
max_r
;
r
+=
stride_y
)
{
{
for
(
long
c
=
filter_nc
-
1
-
padding_x
;
c
-
padding_x
<
data
.
nc
()
;
c
+=
stride_x
)
for
(
long
c
=
-
padding_x
;
c
<
max_c
;
c
+=
stride_x
)
{
{
for
(
long
k
=
0
;
k
<
data
.
k
();
++
k
)
for
(
long
k
=
0
;
k
<
data
.
k
();
++
k
)
{
{
...
@@ -1642,8 +1644,8 @@ namespace dlib
...
@@ -1642,8 +1644,8 @@ namespace dlib
for
(
long
x
=
0
;
x
<
filter_nc
;
++
x
)
for
(
long
x
=
0
;
x
<
filter_nc
;
++
x
)
{
{
DLIB_ASSERT
(
cnt
<
output
.
size
(),
""
);
DLIB_ASSERT
(
cnt
<
output
.
size
(),
""
);
long
xx
=
c
-
x
;
long
xx
=
c
+
x
;
long
yy
=
r
-
y
;
long
yy
=
r
+
y
;
if
(
boundary
.
contains
(
xx
,
yy
))
if
(
boundary
.
contains
(
xx
,
yy
))
*
t
=
d
[(
k
*
data
.
nr
()
+
yy
)
*
data
.
nc
()
+
xx
];
*
t
=
d
[(
k
*
data
.
nr
()
+
yy
)
*
data
.
nc
()
+
xx
];
else
else
...
@@ -1676,9 +1678,11 @@ namespace dlib
...
@@ -1676,9 +1678,11 @@ namespace dlib
const
float
*
t
=
&
output
(
0
,
0
);
const
float
*
t
=
&
output
(
0
,
0
);
// now fill in the Toeplitz output matrix for the n-th sample in data.
// now fill in the Toeplitz output matrix for the n-th sample in data.
for
(
long
r
=
filter_nr
-
1
-
padding_y
;
r
-
padding_y
<
data
.
nr
();
r
+=
stride_y
)
const
long
max_r
=
data
.
nr
()
+
padding_y
-
(
filter_nr
-
1
);
const
long
max_c
=
data
.
nc
()
+
padding_x
-
(
filter_nc
-
1
);
for
(
long
r
=
-
padding_y
;
r
<
max_r
;
r
+=
stride_y
)
{
{
for
(
long
c
=
filter_nc
-
1
-
padding_x
;
c
-
padding_x
<
data
.
nc
()
;
c
+=
stride_x
)
for
(
long
c
=
-
padding_x
;
c
<
max_c
;
c
+=
stride_x
)
{
{
for
(
long
k
=
0
;
k
<
data
.
k
();
++
k
)
for
(
long
k
=
0
;
k
<
data
.
k
();
++
k
)
{
{
...
@@ -1686,8 +1690,8 @@ namespace dlib
...
@@ -1686,8 +1690,8 @@ namespace dlib
{
{
for
(
long
x
=
0
;
x
<
filter_nc
;
++
x
)
for
(
long
x
=
0
;
x
<
filter_nc
;
++
x
)
{
{
long
xx
=
c
-
x
;
long
xx
=
c
+
x
;
long
yy
=
r
-
y
;
long
yy
=
r
+
y
;
if
(
boundary
.
contains
(
xx
,
yy
))
if
(
boundary
.
contains
(
xx
,
yy
))
d
[(
k
*
data
.
nr
()
+
yy
)
*
data
.
nc
()
+
xx
]
+=
*
t
;
d
[(
k
*
data
.
nr
()
+
yy
)
*
data
.
nc
()
+
xx
]
+=
*
t
;
++
t
;
++
t
;
...
...
dlib/dnn/cudnn_dlibapi.cpp
View file @
8e6d8ae0
...
@@ -827,7 +827,7 @@ namespace dlib
...
@@ -827,7 +827,7 @@ namespace dlib
stride_y
,
stride_y
,
stride_x
,
stride_x
,
1
,
1
,
// must be 1,1
1
,
1
,
// must be 1,1
CUDNN_C
ONVOLU
TION
));
// could also be CUDNN_C
ROSS_CORRELA
TION
CUDNN_C
ROSS_CORRELA
TION
));
// could also be CUDNN_C
ONVOLU
TION
CHECK_CUDNN
(
cudnnGetConvolution2dForwardOutputDim
(
CHECK_CUDNN
(
cudnnGetConvolution2dForwardOutputDim
(
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
(
const
cudnnConvolutionDescriptor_t
)
conv_handle
,
...
...
dlib/dnn/layers.h
View file @
8e6d8ae0
...
@@ -160,7 +160,7 @@ namespace dlib
...
@@ -160,7 +160,7 @@ namespace dlib
friend
void
serialize
(
const
con_
&
item
,
std
::
ostream
&
out
)
friend
void
serialize
(
const
con_
&
item
,
std
::
ostream
&
out
)
{
{
serialize
(
"con_
3
"
,
out
);
serialize
(
"con_
4
"
,
out
);
serialize
(
item
.
params
,
out
);
serialize
(
item
.
params
,
out
);
serialize
(
_num_filters
,
out
);
serialize
(
_num_filters
,
out
);
serialize
(
_nr
,
out
);
serialize
(
_nr
,
out
);
...
@@ -186,6 +186,33 @@ namespace dlib
...
@@ -186,6 +186,33 @@ namespace dlib
long
nc
;
long
nc
;
int
stride_y
;
int
stride_y
;
int
stride_x
;
int
stride_x
;
if
(
version
==
"con_4"
)
{
deserialize
(
item
.
params
,
in
);
deserialize
(
num_filters
,
in
);
deserialize
(
nr
,
in
);
deserialize
(
nc
,
in
);
deserialize
(
stride_y
,
in
);
deserialize
(
stride_x
,
in
);
deserialize
(
item
.
padding_y_
,
in
);
deserialize
(
item
.
padding_x_
,
in
);
deserialize
(
item
.
filters
,
in
);
deserialize
(
item
.
biases
,
in
);
deserialize
(
item
.
learning_rate_multiplier
,
in
);
deserialize
(
item
.
weight_decay_multiplier
,
in
);
deserialize
(
item
.
bias_learning_rate_multiplier
,
in
);
deserialize
(
item
.
bias_weight_decay_multiplier
,
in
);
if
(
item
.
padding_y_
!=
_padding_y
)
throw
serialization_error
(
"Wrong padding_y found while deserializing dlib::con_"
);
if
(
item
.
padding_x_
!=
_padding_x
)
throw
serialization_error
(
"Wrong padding_x found while deserializing dlib::con_"
);
if
(
num_filters
!=
_num_filters
)
throw
serialization_error
(
"Wrong num_filters found while deserializing dlib::con_"
);
if
(
nr
!=
_nr
)
throw
serialization_error
(
"Wrong nr found while deserializing dlib::con_"
);
if
(
nc
!=
_nc
)
throw
serialization_error
(
"Wrong nc found while deserializing dlib::con_"
);
if
(
stride_y
!=
_stride_y
)
throw
serialization_error
(
"Wrong stride_y found while deserializing dlib::con_"
);
if
(
stride_x
!=
_stride_x
)
throw
serialization_error
(
"Wrong stride_x found while deserializing dlib::con_"
);
return
;
}
if
(
version
==
"con_"
)
if
(
version
==
"con_"
)
{
{
deserialize
(
item
.
params
,
in
);
deserialize
(
item
.
params
,
in
);
...
@@ -237,6 +264,20 @@ namespace dlib
...
@@ -237,6 +264,20 @@ namespace dlib
throw
serialization_error
(
"Unexpected version '"
+
version
+
"' found while deserializing dlib::con_."
);
throw
serialization_error
(
"Unexpected version '"
+
version
+
"' found while deserializing dlib::con_."
);
}
}
// now flip all the filters
alias_tensor
at
(
_nr
,
_nc
);
size_t
off
=
0
;
for
(
int
i
=
0
;
i
<
item
.
filters
.
num_samples
();
++
i
)
{
for
(
int
j
=
0
;
j
<
item
.
filters
.
k
();
++
j
)
{
auto
temp
=
at
(
item
.
params
,
off
);
off
+=
_nr
*
_nc
;
temp
=
flipud
(
fliplr
(
mat
(
temp
)));
}
}
if
(
num_filters
!=
_num_filters
)
throw
serialization_error
(
"Wrong num_filters found while deserializing dlib::con_"
);
if
(
num_filters
!=
_num_filters
)
throw
serialization_error
(
"Wrong num_filters found while deserializing dlib::con_"
);
if
(
nr
!=
_nr
)
throw
serialization_error
(
"Wrong nr found while deserializing dlib::con_"
);
if
(
nr
!=
_nr
)
throw
serialization_error
(
"Wrong nr found while deserializing dlib::con_"
);
if
(
nc
!=
_nc
)
throw
serialization_error
(
"Wrong nc found while deserializing dlib::con_"
);
if
(
nc
!=
_nc
)
throw
serialization_error
(
"Wrong nc found while deserializing dlib::con_"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment