Commit c79f64f5 authored by Davis King's avatar Davis King
Browse files

make update_parameters() a little more uniform

parent fd014534
...@@ -1003,7 +1003,7 @@ namespace dlib ...@@ -1003,7 +1003,7 @@ namespace dlib
template <typename solver_type> template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate) void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ {
subnetwork->update_parameters(make_sstack(solvers), learning_rate); update_parameters(make_sstack(solvers), learning_rate);
} }
const tensor& get_parameter_gradient( const tensor& get_parameter_gradient(
...@@ -1369,6 +1369,12 @@ namespace dlib ...@@ -1369,6 +1369,12 @@ namespace dlib
} }
} }
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient( const tensor& get_parameter_gradient(
) const { return params_grad; } ) const { return params_grad; }
...@@ -1609,6 +1615,12 @@ namespace dlib ...@@ -1609,6 +1615,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate); subnetwork.update_parameters(solvers, learning_rate);
} }
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient( const tensor& get_parameter_gradient(
) const { return params_grad; } ) const { return params_grad; }
...@@ -1905,6 +1917,12 @@ namespace dlib ...@@ -1905,6 +1917,12 @@ namespace dlib
subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate); subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate);
} }
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return subnetwork; } const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; } subnet_type& subnet() { return subnetwork; }
...@@ -2135,6 +2153,12 @@ namespace dlib ...@@ -2135,6 +2153,12 @@ namespace dlib
// nothing to do // nothing to do
} }
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return input_layer; } const subnet_type& subnet() const { return input_layer; }
subnet_type& subnet() { return input_layer; } subnet_type& subnet() { return input_layer; }
...@@ -2550,6 +2574,12 @@ namespace dlib ...@@ -2550,6 +2574,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate); subnetwork.update_parameters(solvers, learning_rate);
} }
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const subnet_type& subnet() const { return subnetwork; } const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; } subnet_type& subnet() { return subnetwork; }
const loss_details_type& loss_details() const { return loss; } const loss_details_type& loss_details() const { return loss; }
...@@ -2940,6 +2970,12 @@ namespace dlib ...@@ -2940,6 +2970,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate); subnetwork.update_parameters(solvers, learning_rate);
} }
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient( const tensor& get_parameter_gradient(
) const { return params_grad; } ) const { return params_grad; }
......
...@@ -639,6 +639,13 @@ namespace dlib ...@@ -639,6 +639,13 @@ namespace dlib
- The solvers use the given learning rate. - The solvers use the given learning rate.
!*/ !*/
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters()
!*/
void clean( void clean(
); );
/*! /*!
...@@ -1155,6 +1162,13 @@ namespace dlib ...@@ -1155,6 +1162,13 @@ namespace dlib
- The solvers use the given learning rate. - The solvers use the given learning rate.
!*/ !*/
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters()
!*/
// ------------- // -------------
void clean ( void clean (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment