Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
57bb5eb5
"docs/vscode:/vscode.git/clone" did not exist on "c99f42371884347294972b127ba0a3de91563296"
Unverified
Commit
57bb5eb5
authored
Mar 31, 2020
by
Adrià Arrufat
Committed by
GitHub
Mar 30, 2020
Browse files
use running stats to track losses (#2041)
parent
0057461a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
5 deletions
+8
-5
examples/dnn_dcgan_train_ex.cpp
examples/dnn_dcgan_train_ex.cpp
+8
-5
No files found.
examples/dnn_dcgan_train_ex.cpp
View file @
57bb5eb5
...
@@ -181,6 +181,7 @@ int main(int argc, char** argv) try
...
@@ -181,6 +181,7 @@ int main(int argc, char** argv) try
const
std
::
vector
<
float
>
fake_labels
(
minibatch_size
,
-
1
);
const
std
::
vector
<
float
>
fake_labels
(
minibatch_size
,
-
1
);
dlib
::
image_window
win
;
dlib
::
image_window
win
;
resizable_tensor
real_samples_tensor
,
fake_samples_tensor
,
noises_tensor
;
resizable_tensor
real_samples_tensor
,
fake_samples_tensor
,
noises_tensor
;
running_stats
<
double
>
g_loss
,
d_loss
;
while
(
iteration
<
50000
)
while
(
iteration
<
50000
)
{
{
// Train the discriminator with real images
// Train the discriminator with real images
...
@@ -192,7 +193,7 @@ int main(int argc, char** argv) try
...
@@ -192,7 +193,7 @@ int main(int argc, char** argv) try
}
}
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
discriminator
.
to_tensor
(
real_samples
.
begin
(),
real_samples
.
end
(),
real_samples_tensor
);
discriminator
.
to_tensor
(
real_samples
.
begin
(),
real_samples
.
end
(),
real_samples_tensor
);
d
ouble
d_loss
=
discriminator
.
compute_loss
(
real_samples_tensor
,
real_labels
.
begin
());
d
_loss
.
add
(
discriminator
.
compute_loss
(
real_samples_tensor
,
real_labels
.
begin
())
)
;
discriminator
.
back_propagate_error
(
real_samples_tensor
);
discriminator
.
back_propagate_error
(
real_samples_tensor
);
discriminator
.
update_parameters
(
d_solvers
,
learning_rate
);
discriminator
.
update_parameters
(
d_solvers
,
learning_rate
);
...
@@ -210,7 +211,7 @@ int main(int argc, char** argv) try
...
@@ -210,7 +211,7 @@ int main(int argc, char** argv) try
// 4. finally train the discriminator and wait for the threading to stop. The following
// 4. finally train the discriminator and wait for the threading to stop. The following
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
discriminator
.
to_tensor
(
fake_samples
.
begin
(),
fake_samples
.
end
(),
fake_samples_tensor
);
discriminator
.
to_tensor
(
fake_samples
.
begin
(),
fake_samples
.
end
(),
fake_samples_tensor
);
d_loss
+=
discriminator
.
compute_loss
(
fake_samples_tensor
,
fake_labels
.
begin
());
d_loss
.
add
(
discriminator
.
compute_loss
(
fake_samples_tensor
,
fake_labels
.
begin
())
)
;
discriminator
.
back_propagate_error
(
fake_samples_tensor
);
discriminator
.
back_propagate_error
(
fake_samples_tensor
);
discriminator
.
update_parameters
(
d_solvers
,
learning_rate
);
discriminator
.
update_parameters
(
d_solvers
,
learning_rate
);
...
@@ -223,7 +224,7 @@ int main(int argc, char** argv) try
...
@@ -223,7 +224,7 @@ int main(int argc, char** argv) try
// seen as test_one_step() plus the error back propagation.
// seen as test_one_step() plus the error back propagation.
// Forward the fake samples and compute the loss with real labels
// Forward the fake samples and compute the loss with real labels
const
auto
g_loss
=
discriminator
.
compute_loss
(
fake_samples_tensor
,
real_labels
.
begin
());
g_loss
.
add
(
discriminator
.
compute_loss
(
fake_samples_tensor
,
real_labels
.
begin
())
)
;
// Back propagate the error to fill the final data gradient
// Back propagate the error to fill the final data gradient
discriminator
.
back_propagate_error
(
fake_samples_tensor
);
discriminator
.
back_propagate_error
(
fake_samples_tensor
);
// Get the gradient that will tell the generator how to update itself
// Get the gradient that will tell the generator how to update itself
...
@@ -238,10 +239,12 @@ int main(int argc, char** argv) try
...
@@ -238,10 +239,12 @@ int main(int argc, char** argv) try
serialize
(
"dcgan_sync"
)
<<
generator
<<
discriminator
<<
iteration
;
serialize
(
"dcgan_sync"
)
<<
generator
<<
discriminator
<<
iteration
;
std
::
cout
<<
std
::
cout
<<
"step#: "
<<
iteration
<<
"step#: "
<<
iteration
<<
"
\t
discriminator loss: "
<<
d_loss
<<
"
\t
discriminator loss: "
<<
d_loss
.
mean
()
*
2
<<
"
\t
generator loss: "
<<
g_loss
<<
'\n'
;
"
\t
generator loss: "
<<
g_loss
.
mean
()
<<
'\n'
;
win
.
set_image
(
tile_images
(
fake_samples
));
win
.
set_image
(
tile_images
(
fake_samples
));
win
.
set_title
(
"DCGAN step#: "
+
to_string
(
iteration
));
win
.
set_title
(
"DCGAN step#: "
+
to_string
(
iteration
));
d_loss
.
clear
();
g_loss
.
clear
();
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment