Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
57bb5eb5
Unverified
Commit
57bb5eb5
authored
Mar 31, 2020
by
Adrià Arrufat
Committed by
GitHub
Mar 30, 2020
Browse files
use running stats to track losses (#2041)
parent
0057461a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
examples/dnn_dcgan_train_ex.cpp
examples/dnn_dcgan_train_ex.cpp
+8
-5
No files found.
examples/dnn_dcgan_train_ex.cpp
View file @
57bb5eb5
...
...
@@ -181,6 +181,7 @@ int main(int argc, char** argv) try
const
std
::
vector
<
float
>
fake_labels
(
minibatch_size
,
-
1
);
dlib
::
image_window
win
;
resizable_tensor
real_samples_tensor
,
fake_samples_tensor
,
noises_tensor
;
running_stats
<
double
>
g_loss
,
d_loss
;
while
(
iteration
<
50000
)
{
// Train the discriminator with real images
...
...
@@ -192,7 +193,7 @@ int main(int argc, char** argv) try
}
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
discriminator
.
to_tensor
(
real_samples
.
begin
(),
real_samples
.
end
(),
real_samples_tensor
);
d
ouble
d_loss
=
discriminator
.
compute_loss
(
real_samples_tensor
,
real_labels
.
begin
());
d
_loss
.
add
(
discriminator
.
compute_loss
(
real_samples_tensor
,
real_labels
.
begin
())
)
;
discriminator
.
back_propagate_error
(
real_samples_tensor
);
discriminator
.
update_parameters
(
d_solvers
,
learning_rate
);
...
...
@@ -210,7 +211,7 @@ int main(int argc, char** argv) try
// 4. finally train the discriminator and wait for the threading to stop. The following
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
discriminator
.
to_tensor
(
fake_samples
.
begin
(),
fake_samples
.
end
(),
fake_samples_tensor
);
d_loss
+=
discriminator
.
compute_loss
(
fake_samples_tensor
,
fake_labels
.
begin
());
d_loss
.
add
(
discriminator
.
compute_loss
(
fake_samples_tensor
,
fake_labels
.
begin
())
)
;
discriminator
.
back_propagate_error
(
fake_samples_tensor
);
discriminator
.
update_parameters
(
d_solvers
,
learning_rate
);
...
...
@@ -223,7 +224,7 @@ int main(int argc, char** argv) try
// seen as test_one_step() plus the error back propagation.
// Forward the fake samples and compute the loss with real labels
const
auto
g_loss
=
discriminator
.
compute_loss
(
fake_samples_tensor
,
real_labels
.
begin
());
g_loss
.
add
(
discriminator
.
compute_loss
(
fake_samples_tensor
,
real_labels
.
begin
())
)
;
// Back propagate the error to fill the final data gradient
discriminator
.
back_propagate_error
(
fake_samples_tensor
);
// Get the gradient that will tell the generator how to update itself
...
...
@@ -238,10 +239,12 @@ int main(int argc, char** argv) try
serialize
(
"dcgan_sync"
)
<<
generator
<<
discriminator
<<
iteration
;
std
::
cout
<<
"step#: "
<<
iteration
<<
"
\t
discriminator loss: "
<<
d_loss
<<
"
\t
generator loss: "
<<
g_loss
<<
'\n'
;
"
\t
discriminator loss: "
<<
d_loss
.
mean
()
*
2
<<
"
\t
generator loss: "
<<
g_loss
.
mean
()
<<
'\n'
;
win
.
set_image
(
tile_images
(
fake_samples
));
win
.
set_title
(
"DCGAN step#: "
+
to_string
(
iteration
));
d_loss
.
clear
();
g_loss
.
clear
();
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment