Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
7550681b
Commit
7550681b
authored
Dec 07, 2015
by
Davis King
Browse files
Implemented the bn layer.
parent
363b6b2f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
3 deletions
+44
-3
dlib/dnn/layers.h
dlib/dnn/layers.h
+44
-3
No files found.
dlib/dnn/layers.h
View file @
7550681b
...
@@ -187,27 +187,68 @@ namespace dlib
...
@@ -187,27 +187,68 @@ namespace dlib
template
<
typename
SUBNET
>
template
<
typename
SUBNET
>
void
setup
(
const
SUBNET
&
sub
)
void
setup
(
const
SUBNET
&
sub
)
{
{
// TODO
gamma
=
alias_tensor
(
1
,
sub
.
get_output
().
k
(),
sub
.
get_output
().
nr
(),
sub
.
get_output
().
nc
());
beta
=
gamma
;
params
.
set_size
(
gamma
.
size
()
+
beta
.
size
());
gamma
(
params
,
0
)
=
1
;
beta
(
params
,
gamma
.
size
())
=
0
;
}
}
template
<
typename
SUBNET
>
template
<
typename
SUBNET
>
void
forward
(
const
SUBNET
&
sub
,
resizable_tensor
&
output
)
void
forward
(
const
SUBNET
&
sub
,
resizable_tensor
&
output
)
{
{
// TODO
auto
g
=
gamma
(
params
,
0
);
auto
b
=
beta
(
params
,
gamma
.
size
());
tt
::
batch_normalize
(
output
,
means
,
invstds
,
sub
.
get_output
(),
g
,
b
);
}
}
template
<
typename
SUBNET
>
template
<
typename
SUBNET
>
void
backward
(
const
tensor
&
gradient_input
,
SUBNET
&
sub
,
tensor
&
params_grad
)
void
backward
(
const
tensor
&
gradient_input
,
SUBNET
&
sub
,
tensor
&
params_grad
)
{
{
// TODO
auto
g
=
gamma
(
params
,
0
);
auto
g_grad
=
gamma
(
params_grad
,
0
);
auto
b_grad
=
beta
(
params_grad
,
gamma
.
size
());
bng
(
gradient_input
,
means
,
invstds
,
sub
.
get_output
(),
g
,
sub
.
get_gradient_input
(),
g_grad
,
b_grad
);
}
}
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
friend
void
serialize
(
const
bn_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"bn_"
,
out
);
serialize
(
item
.
params
,
out
);
serialize
(
item
.
gamma
,
out
);
serialize
(
item
.
beta
,
out
);
serialize
(
item
.
means
,
out
);
serialize
(
item
.
invstds
,
out
);
}
friend
void
deserialize
(
bn_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"bn_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::bn_."
);
deserialize
(
item
.
params
,
in
);
deserialize
(
item
.
gamma
,
in
);
deserialize
(
item
.
beta
,
in
);
deserialize
(
item
.
means
,
in
);
deserialize
(
item
.
invstds
,
in
);
}
private:
private:
tt
::
batch_normalize_gradient
bng
;
resizable_tensor
params
;
resizable_tensor
params
;
alias_tensor
gamma
,
beta
;
resizable_tensor
means
;
resizable_tensor
invstds
;
};
};
template
<
typename
SUBNET
>
template
<
typename
SUBNET
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment