"vscode:/vscode.git/clone" did not exist on "05d597fa95d3b3b61a59cecc64006dd723e4890c"
Commit c79f64f5 authored by Davis King
Browse files

make update_parameters() a little more uniform

parent fd014534
......@@ -1003,7 +1003,7 @@ namespace dlib
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // Convenience overload: wrap the std::vector of solvers in an sstack and
    // forward to the sstack overload of update_parameters().
    //
    // NOTE(review): the original text contained BOTH the old call
    // (subnetwork->update_parameters(make_sstack(solvers), learning_rate);)
    // and this forwarding call, which would have applied the parameter
    // update twice per invocation.  Only the single forwarding call is kept.
    update_parameters(make_sstack(solvers), learning_rate);
}
const tensor& get_parameter_gradient(
......@@ -1369,6 +1369,12 @@ namespace dlib
}
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // Convenience overload: adapt the vector of solvers into an sstack and
    // delegate to the primary update_parameters() overload.
    auto solver_stack = make_sstack(solvers);
    update_parameters(solver_stack, learning_rate);
}
// Accessor: returns a const reference to the stored parameter gradient
// tensor (params_grad) — presumably populated during back propagation;
// confirm against the dlib dnn documentation.
const tensor& get_parameter_gradient(
) const { return params_grad; }
......@@ -1609,6 +1615,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // Forward to the sstack-based overload; this overload exists only so
    // callers holding a std::vector of solvers don't have to build the
    // sstack themselves.
    auto stacked = make_sstack(solvers);
    update_parameters(stacked, learning_rate);
}
// Accessor: returns a const reference to the stored parameter gradient
// tensor (params_grad).
const tensor& get_parameter_gradient(
) const { return params_grad; }
......@@ -1905,6 +1917,12 @@ namespace dlib
subnetwork.update_parameters(solvers.pop(comp_layers_in_each_group*details.size()),learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // Convenience wrapper: view the given vector of solvers as an sstack
    // and hand it off to the main update_parameters() overload.
    auto stack_view = make_sstack(solvers);
    update_parameters(stack_view, learning_rate);
}
// Accessors for the wrapped subnetwork object (const and non-const).
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
......@@ -2135,6 +2153,12 @@ namespace dlib
// nothing to do
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // Kept for interface uniformity with the other layers: wrap the vector
    // in an sstack and forward to the sstack overload.
    auto wrapped = make_sstack(solvers);
    update_parameters(wrapped, learning_rate);
}
// Accessors: here the "subnet" is the input layer itself (const and
// non-const variants).
const subnet_type& subnet() const { return input_layer; }
subnet_type& subnet() { return input_layer; }
......@@ -2550,6 +2574,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // std::vector convenience entry point: build the sstack adapter first,
    // then call the canonical overload with it.
    auto adapter = make_sstack(solvers);
    update_parameters(adapter, learning_rate);
}
// Accessors for the wrapped subnetwork (const and non-const) and for the
// loss layer's details object.
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
const loss_details_type& loss_details() const { return loss; }
......@@ -2940,6 +2970,12 @@ namespace dlib
subnetwork.update_parameters(solvers, learning_rate);
}
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{
    // Uniform convenience overload: turn the solver vector into an sstack
    // and dispatch to the primary implementation.
    auto solvers_as_stack = make_sstack(solvers);
    update_parameters(solvers_as_stack, learning_rate);
}
// Accessor: returns a const reference to the stored parameter gradient
// tensor (params_grad).
const tensor& get_parameter_gradient(
) const { return params_grad; }
......
......@@ -639,6 +639,13 @@ namespace dlib
- The solvers use the given learning rate.
!*/
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters().  Wraps the given
std::vector of solvers in an sstack (via make_sstack) and forwards it,
together with the given learning rate, to the sstack overload of
update_parameters().
!*/
void clean(
);
/*!
......@@ -1155,6 +1162,13 @@ namespace dlib
- The solvers use the given learning rate.
!*/
template <typename solver_type>
void update_parameters(std::vector<solver_type>& solvers, double learning_rate)
{ update_parameters(make_sstack(solvers), learning_rate); }
/*!
Convenience method for calling update_parameters().  Adapts the given
std::vector of solvers into an sstack (via make_sstack) and forwards it,
together with the given learning rate, to the sstack overload of
update_parameters().
!*/
// -------------
void clean (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment