Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
b377f752
"src/runtime/vscode:/vscode.git/clone" did not exist on "5cf48fc69cbc1fe5325bd697a7d77215d2239403"
Commit
b377f752
authored
Jul 01, 2017
by
Davis King
Browse files
Made it so you can set the number of output filters for con_ layers at runtime.
parent
8eb9e295
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
20 deletions
+32
-20
dlib/dnn/layers.h
dlib/dnn/layers.h
+32
-20
No files found.
dlib/dnn/layers.h
View file @
b377f752
...
...
@@ -21,6 +21,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
struct
num_con_outputs
{
num_con_outputs
(
unsigned
long
n
)
:
num_outputs
(
n
)
{}
unsigned
long
num_outputs
;
};
template
<
long
_num_filters
,
long
_nr
,
...
...
@@ -43,16 +49,22 @@ namespace dlib
static_assert
(
0
<=
_padding_x
&&
_padding_x
<
_nc
,
"The padding must be smaller than the filter size."
);
con_
(
num_con_outputs
o
)
:
learning_rate_multiplier
(
1
),
weight_decay_multiplier
(
1
),
bias_learning_rate_multiplier
(
1
),
bias_weight_decay_multiplier
(
0
),
padding_y_
(
_padding_y
),
padding_x_
(
_padding_x
)
{}
padding_x_
(
_padding_x
),
num_filters_
(
o
.
num_outputs
)
{
DLIB_CASSERT
(
num_filters_
>
0
);
}
long
num_filters
()
const
{
return
_num_filters
;
}
con_
()
:
con_
(
num_con_outputs
(
_num_filters
))
{}
long
num_filters
()
const
{
return
num_filters_
;
}
long
nr
()
const
{
return
_nr
;
}
long
nc
()
const
{
return
_nc
;
}
long
stride_y
()
const
{
return
_stride_y
;
}
...
...
@@ -60,6 +72,14 @@ namespace dlib
long
padding_y
()
const
{
return
padding_y_
;
}
long
padding_x
()
const
{
return
padding_x_
;
}
void
set_num_filters
(
long
num
)
{
DLIB_CASSERT
(
num
>
0
);
DLIB_CASSERT
(
get_layer_params
().
size
()
==
0
,
"You can't change the number of filters in con_ if the parameter tensor has already been allocated."
);
num_filters_
=
num
;
}
double
get_learning_rate_multiplier
()
const
{
return
learning_rate_multiplier
;
}
double
get_weight_decay_multiplier
()
const
{
return
weight_decay_multiplier
;
}
void
set_learning_rate_multiplier
(
double
val
)
{
learning_rate_multiplier
=
val
;
}
...
...
@@ -130,15 +150,15 @@ namespace dlib
void
setup
(
const
SUBNET
&
sub
)
{
long
num_inputs
=
_nr
*
_nc
*
sub
.
get_output
().
k
();
long
num_outputs
=
_
num_filters
;
long
num_outputs
=
num_filters
_
;
// allocate params for the filters and also for the filter bias values.
params
.
set_size
(
num_inputs
*
_
num_filters
+
_
num_filters
);
params
.
set_size
(
num_inputs
*
num_filters
_
+
num_filters
_
);
dlib
::
rand
rnd
(
std
::
rand
());
randomize_parameters
(
params
,
num_inputs
+
num_outputs
,
rnd
);
filters
=
alias_tensor
(
_
num_filters
,
sub
.
get_output
().
k
(),
_nr
,
_nc
);
biases
=
alias_tensor
(
1
,
_
num_filters
);
filters
=
alias_tensor
(
num_filters
_
,
sub
.
get_output
().
k
(),
_nr
,
_nc
);
biases
=
alias_tensor
(
1
,
num_filters
_
);
// set the initial bias values to zero
biases
(
params
,
filters
.
size
())
=
0
;
...
...
@@ -182,7 +202,7 @@ namespace dlib
{
serialize
(
"con_4"
,
out
);
serialize
(
item
.
params
,
out
);
serialize
(
_
num_filters
,
out
);
serialize
(
item
.
num_filters
_
,
out
);
serialize
(
_nr
,
out
);
serialize
(
_nc
,
out
);
serialize
(
_stride_y
,
out
);
...
...
@@ -201,7 +221,6 @@ namespace dlib
{
std
::
string
version
;
deserialize
(
version
,
in
);
long
num_filters
;
long
nr
;
long
nc
;
int
stride_y
;
...
...
@@ -209,7 +228,7 @@ namespace dlib
if
(
version
==
"con_4"
)
{
deserialize
(
item
.
params
,
in
);
deserialize
(
num_filters
,
in
);
deserialize
(
item
.
num_filters
_
,
in
);
deserialize
(
nr
,
in
);
deserialize
(
nc
,
in
);
deserialize
(
stride_y
,
in
);
...
...
@@ -224,14 +243,6 @@ namespace dlib
deserialize
(
item
.
bias_weight_decay_multiplier
,
in
);
if
(
item
.
padding_y_
!=
_padding_y
)
throw
serialization_error
(
"Wrong padding_y found while deserializing dlib::con_"
);
if
(
item
.
padding_x_
!=
_padding_x
)
throw
serialization_error
(
"Wrong padding_x found while deserializing dlib::con_"
);
if
(
num_filters
!=
_num_filters
)
{
std
::
ostringstream
sout
;
sout
<<
"Wrong num_filters found while deserializing dlib::con_"
<<
std
::
endl
;
sout
<<
"expected "
<<
_num_filters
<<
" but found "
<<
num_filters
<<
std
::
endl
;
throw
serialization_error
(
sout
.
str
());
}
if
(
nr
!=
_nr
)
throw
serialization_error
(
"Wrong nr found while deserializing dlib::con_"
);
if
(
nc
!=
_nc
)
throw
serialization_error
(
"Wrong nc found while deserializing dlib::con_"
);
if
(
stride_y
!=
_stride_y
)
throw
serialization_error
(
"Wrong stride_y found while deserializing dlib::con_"
);
...
...
@@ -247,7 +258,7 @@ namespace dlib
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
con_
&
item
)
{
out
<<
"con
\t
("
<<
"num_filters="
<<
_
num_filters
<<
"num_filters="
<<
item
.
num_filters
_
<<
", nr="
<<
_nr
<<
", nc="
<<
_nc
<<
", stride_y="
<<
_stride_y
...
...
@@ -265,7 +276,7 @@ namespace dlib
friend
void
to_xml
(
const
con_
&
item
,
std
::
ostream
&
out
)
{
out
<<
"<con"
<<
" num_filters='"
<<
_
num_filters
<<
"'"
<<
" num_filters='"
<<
item
.
num_filters
_
<<
"'"
<<
" nr='"
<<
_nr
<<
"'"
<<
" nc='"
<<
_nc
<<
"'"
<<
" stride_y='"
<<
_stride_y
<<
"'"
...
...
@@ -290,6 +301,7 @@ namespace dlib
double
weight_decay_multiplier
;
double
bias_learning_rate_multiplier
;
double
bias_weight_decay_multiplier
;
long
num_filters_
;
// These are here only because older versions of con (which you might encounter
// serialized to disk) used different padding settings.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment