Commit 0057461a authored by Davis King


Promote some of the sub-network methods into the add_loss_layer interface so users don't have to write .subnet() so often.
parent c79f64f5
@@ -2461,6 +2461,27 @@ namespace dlib
return results;
}
void back_propagate_error(const tensor& x)
{
subnet().back_propagate_error(x);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
{
subnet().back_propagate_error(x, gradient_input);
}
const tensor& get_final_data_gradient(
) const
{
return subnet().get_final_data_gradient();
}
const tensor& forward(const tensor& x)
{
return subnet().forward(x);
}
template <typename iterable_type>
std::vector<output_label_type> operator() (
const iterable_type& data,
...
@@ -857,6 +857,29 @@ namespace dlib
// -------------
const tensor& forward(const tensor& x
);
/*!
requires
- sample_expansion_factor() != 0
(i.e. to_tensor() must have been called to set sample_expansion_factor()
to something non-zero.)
- x.num_samples()%sample_expansion_factor() == 0
- x.num_samples() > 0
ensures
- Runs x through the network and returns the results as a tensor. In particular,
this function just performs:
return subnet().forward(x);
So if you want to get the outputs as an output_label_type then call one of the
methods below instead, like operator().
- The return value from this function is also available in #subnet().get_output().
i.e. this function returns #subnet().get_output().
- have_same_dimensions(#subnet().get_gradient_input(), #subnet().get_output()) == true
- All elements of #subnet().get_gradient_input() are set to 0.
i.e. calling this function clears out #subnet().get_gradient_input() and ensures
it has the same dimensions as the most recent output.
!*/
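For illustration only (this snippet is not part of the diff), here is a minimal sketch of calling the promoted forward() directly on a loss network; the toy_net_type network, sample shapes, and values are hypothetical:

    #include <dlib/dnn.h>
    #include <iostream>
    using namespace dlib;

    // Hypothetical toy network, just to show the call pattern.
    using toy_net_type = loss_binary_log<fc<1, relu<fc<16, input<matrix<float>>>>>>;

    int main()
    {
        toy_net_type net;

        // A few dummy 8x1 input samples.
        matrix<float> dummy = zeros_matrix<float>(8, 1);
        std::vector<matrix<float>> samples(4, dummy);

        resizable_tensor x;
        net.to_tensor(samples.begin(), samples.end(), x);  // sets sample_expansion_factor()

        // New with this commit: no .subnet() needed to run the data through the net.
        const tensor& out = net.forward(x);

        // The same tensor is also available as net.subnet().get_output().
        std::cout << "output samples: " << out.num_samples() << std::endl;
    }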
template <typename output_iterator>
void operator() (
const tensor& x,
@@ -996,6 +1019,9 @@ namespace dlib
- for all valid k:
- the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
- This function does not update the network parameters.
- For sub-layers that are immediate inputs into the loss layer, we also populate the
sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
to the sub-layer's output.
!*/
template <typename forward_iterator, typename label_iterator>
@@ -1016,6 +1042,9 @@ namespace dlib
- for all valid k:
- the expected label of *(ibegin+k) is *(lbegin+k).
- This function does not update the network parameters.
- For sub-layers that are immediate inputs into the loss layer, we also populate the
sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
to the sub-layer's output.
!*/
// -------------
@@ -1034,6 +1063,9 @@ namespace dlib
ensures
- runs x through the network and returns the resulting loss.
- This function does not update the network parameters.
- For sub-layers that are immediate inputs into the loss layer, we also populate the
sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
to the sub-layer's output.
!*/
template <typename forward_iterator>
@@ -1049,6 +1081,9 @@ namespace dlib
ensures
- runs [ibegin,iend) through the network and returns the resulting loss.
- This function does not update the network parameters.
- For sub-layers that are immediate inputs into the loss layer, we also populate the
sub-layer's get_gradient_input() tensor with the gradient of the loss with respect
to the sub-layer's output.
!*/
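As a hedged illustration (not part of the diff) of how compute_loss() leaves the loss gradient in subnet().get_gradient_input(), so that the promoted back_propagate_error() and update_parameters() can follow immediately, here is a sketch of one manual training step; the toy_net_type network, data, and solver settings are made up:

    #include <dlib/dnn.h>
    #include <iostream>
    using namespace dlib;

    // Hypothetical 3-class toy classifier, illustrative only.
    using toy_net_type = loss_multiclass_log<fc<3, relu<fc<16, input<matrix<float>>>>>>;

    int main()
    {
        toy_net_type net;

        matrix<float> dummy = zeros_matrix<float>(8, 1);
        std::vector<matrix<float>> samples(6, dummy);
        std::vector<unsigned long> labels = {0, 1, 2, 0, 1, 2};

        resizable_tensor x;
        net.to_tensor(samples.begin(), samples.end(), x);

        // compute_loss() returns the loss and, per the spec above, also fills
        // net.subnet().get_gradient_input() with the gradient of the loss with
        // respect to the network's output.
        double loss = net.compute_loss(x, labels.begin());

        // So a back-propagation / parameter-update step can follow directly, and
        // neither call needs .subnet() anymore:
        net.back_propagate_error(x);
        std::vector<sgd> solvers(toy_net_type::num_computational_layers, sgd());
        net.update_parameters(solvers, 0.01);

        std::cout << "loss: " << loss << std::endl;
    }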
// -------------
@@ -1163,12 +1198,72 @@ namespace dlib
!*/
template <typename solver_type>
- void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
- { update_parameters(make_sstack(solvers), learning_rate); }
+ void update_parameters(std::vector<solver_type>& solvers, double learning_rate
+ ) { update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters()
!*/
void back_propagate_error(
const tensor& x
);
/*!
requires
- forward(x) was called to forward propagate x through the network.
Moreover, this was the most recent call to forward() and x has not been
subsequently modified in any way.
- subnet().get_gradient_input() has been set equal to the gradient of this network's
output with respect to the loss function (generally this will be done by calling
compute_loss()).
ensures
- Back propagates the error gradient, subnet().get_gradient_input(), through this
network and computes parameter and data gradients, via backpropagation.
Specifically, this function populates get_final_data_gradient() and also,
for each layer, the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- #get_final_data_gradient() contains the gradient of the network with
respect to x.
!*/
void back_propagate_error(
const tensor& x,
const tensor& gradient_input
);
/*!
requires
- forward(x) was called to forward propagate x through the network.
Moreover, this was the most recent call to forward() and x has not been
subsequently modified in any way.
- have_same_dimensions(gradient_input, subnet().get_output()) == true
ensures
- This function is identical to the version of back_propagate_error()
defined immediately above except that it back-propagates gradient_input
through the network instead of subnet().get_gradient_input(). Therefore, this
version of back_propagate_error() is equivalent to performing:
subnet().get_gradient_input() = gradient_input;
back_propagate_error(x);
Except that calling back_propagate_error(x,gradient_input) avoids the
copy and is therefore slightly more efficient.
- All elements of #subnet().get_gradient_input() are set to 0.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- #get_final_data_gradient() contains the gradient of the network with
respect to x.
!*/
const tensor& get_final_data_gradient(
) const;
/*!
ensures
- if back_propagate_error() has been called to back-propagate a gradient
through this network then you can call get_final_data_gradient() to
obtain the last data gradient computed. That is, this function returns
the gradient of the network with respect to its inputs.
- Note that there is only one "final data gradient" for an entire network,
not one per layer, since there is only one input to the entire network.
!*/
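To show how get_final_data_gradient() fits with the calls above, here is another hypothetical sketch (not from the commit) that reads back the gradient of the loss with respect to the input tensor; the updated dnn_dcgan_train_ex.cpp below uses the same pattern to pass the discriminator's final data gradient into generator.back_propagate_error():

    #include <dlib/dnn.h>
    #include <iostream>
    using namespace dlib;

    // Hypothetical toy classifier, illustrative only.
    using toy_net_type = loss_multiclass_log<fc<3, relu<fc<16, input<matrix<float>>>>>>;

    int main()
    {
        toy_net_type net;

        matrix<float> dummy = zeros_matrix<float>(8, 1);
        std::vector<matrix<float>> samples(2, dummy);
        std::vector<unsigned long> labels = {0, 1};

        resizable_tensor x;
        net.to_tensor(samples.begin(), samples.end(), x);

        // Forward the data, compute the loss, and back-propagate the error.
        net.compute_loss(x, labels.begin());
        net.back_propagate_error(x);

        // The final data gradient has the same dimensions as x and holds the
        // gradient of the loss with respect to every input element.  In the dcgan
        // example, the discriminator's final data gradient (d_grad) plays this role
        // and is fed to generator.back_propagate_error(noises_tensor, d_grad).
        const tensor& dloss_dx = net.get_final_data_gradient();
        std::cout << "have_same_dimensions(dloss_dx, x): "
                  << have_same_dimensions(dloss_dx, x) << std::endl;
    }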
// -------------
void clean (
...
@@ -109,10 +109,9 @@ matrix<unsigned char> generate_image(generator_type& net, const noise_t& noise)
return image;
}
- std::vector<matrix<unsigned char>> get_generated_images(generator_type& net)
+ std::vector<matrix<unsigned char>> get_generated_images(const tensor& out)
{
std::vector<matrix<unsigned char>> images;
- const tensor& out = layer<1>(net).get_output();
for (size_t n = 0; n < out.num_samples(); ++n)
{
matrix<float> output = image_plane(out, n);
@@ -194,8 +193,8 @@ int main(int argc, char** argv) try
// The following lines are equivalent to calling train_one_step(real_samples, real_labels)
discriminator.to_tensor(real_samples.begin(), real_samples.end(), real_samples_tensor);
double d_loss = discriminator.compute_loss(real_samples_tensor, real_labels.begin());
- discriminator.subnet().back_propagate_error(real_samples_tensor);
- discriminator.subnet().update_parameters(d_solvers, learning_rate);
+ discriminator.back_propagate_error(real_samples_tensor);
+ discriminator.update_parameters(d_solvers, learning_rate);
// Train the discriminator with fake images
// 1. generate some random noise
@@ -204,17 +203,16 @@ int main(int argc, char** argv) try
{
noises.push_back(make_noise(rnd));
}
- // 2. forward the noise through the generator
+ // 2. convert noises into a tensor
generator.to_tensor(noises.begin(), noises.end(), noises_tensor);
- generator.subnet().forward(noises_tensor);
- // 3. get the generated images from the generator
- const auto fake_samples = get_generated_images(generator);
+ // 3. Then forward the noise through the network and convert the outputs into images.
+ const auto fake_samples = get_generated_images(generator.forward(noises_tensor));
// 4. finally train the discriminator and wait for the threading to stop. The following
// lines are equivalent to calling train_one_step(fake_samples, fake_labels)
discriminator.to_tensor(fake_samples.begin(), fake_samples.end(), fake_samples_tensor);
d_loss += discriminator.compute_loss(fake_samples_tensor, fake_labels.begin());
- discriminator.subnet().back_propagate_error(fake_samples_tensor);
- discriminator.subnet().update_parameters(d_solvers, learning_rate);
+ discriminator.back_propagate_error(fake_samples_tensor);
+ discriminator.update_parameters(d_solvers, learning_rate);
// Train the generator
// This part is the essence of the Generative Adversarial Networks. Until now, we have
@@ -227,11 +225,11 @@ int main(int argc, char** argv) try
// Forward the fake samples and compute the loss with real labels
const auto g_loss = discriminator.compute_loss(fake_samples_tensor, real_labels.begin());
// Back propagate the error to fill the final data gradient
- discriminator.subnet().back_propagate_error(fake_samples_tensor);
+ discriminator.back_propagate_error(fake_samples_tensor);
// Get the gradient that will tell the generator how to update itself
- const tensor& d_grad = discriminator.subnet().get_final_data_gradient();
- generator.subnet().back_propagate_error(noises_tensor, d_grad);
- generator.subnet().update_parameters(g_solvers, learning_rate);
+ const tensor& d_grad = discriminator.get_final_data_gradient();
+ generator.back_propagate_error(noises_tensor, d_grad);
+ generator.update_parameters(g_solvers, learning_rate);
// At some point, we should see that the generated images start looking like samples from
// the MNIST dataset
...