You need to sign in or sign up before continuing.
Commit e5ad9590 authored by Davis King
Browse files

Added bias learning rate and weight decay multipliers to bn_ layers

parent b6b83798
......@@ -666,6 +666,8 @@ namespace dlib
running_stats_window_size(window_size),
learning_rate_multiplier(1),
weight_decay_multiplier(0),
bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(1),
eps(eps_)
{}
......@@ -680,6 +682,11 @@ namespace dlib
// Set the per-layer multiplier applied to the global learning rate for this layer's parameters.
void set_learning_rate_multiplier(double val)
{
    learning_rate_multiplier = val;
}
// Set the per-layer multiplier applied to the global weight decay for this layer's parameters.
void set_weight_decay_multiplier(double val)
{
    weight_decay_multiplier = val;
}
// Extra multiplier for the bias parameters' learning rate; the effective bias
// learning rate scale is get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
double get_bias_learning_rate_multiplier () const
{
    return bias_learning_rate_multiplier;
}
// Extra multiplier for the bias parameters' weight decay; the effective bias
// weight decay scale is get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
double get_bias_weight_decay_multiplier () const
{
    return bias_weight_decay_multiplier;
}
// Set the additional multiplier applied to the learning rate of this layer's bias parameters.
void set_bias_learning_rate_multiplier(double val)
{
    bias_learning_rate_multiplier = val;
}
// Set the additional multiplier applied to the weight decay of this layer's bias parameters.
void set_bias_weight_decay_multiplier(double val)
{
    bias_weight_decay_multiplier = val;
}
template <typename SUBNET>
void setup (const SUBNET& sub)
......@@ -765,6 +772,8 @@ namespace dlib
serialize(item.running_stats_window_size, out);
serialize(item.learning_rate_multiplier, out);
serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.eps, out);
}
......@@ -812,6 +821,8 @@ namespace dlib
{
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
deserialize(item.eps, in);
}
else
......@@ -834,6 +845,8 @@ namespace dlib
out << " eps="<<item.eps;
out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
return out;
}
......@@ -849,6 +862,8 @@ namespace dlib
unsigned long running_stats_window_size;
double learning_rate_multiplier;
double weight_decay_multiplier;
double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier;
double eps;
};
......
......@@ -856,9 +856,11 @@ namespace dlib
/*!
ensures
- #get_mode() == mode
- #get_running_stats_window_size() == 1000
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_running_stats_window_size() == 1000
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == tt::DEFAULT_BATCH_NORM_EPS
!*/
......@@ -871,9 +873,11 @@ namespace dlib
- eps > 0
ensures
- #get_mode() == mode
- #get_running_stats_window_size() == window_size
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_running_stats_window_size() == window_size
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == eps
!*/
......@@ -953,6 +957,44 @@ namespace dlib
- #get_weight_decay_multiplier() == val
!*/
double get_bias_learning_rate_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the learning rate used to optimize its bias parameters be
multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
!*/
double get_bias_weight_decay_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the weight decay used to optimize its bias parameters be
multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
!*/
void set_bias_learning_rate_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_learning_rate_multiplier() == val
!*/
void set_bias_weight_decay_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_weight_decay_multiplier() == val
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
......
......@@ -89,6 +89,17 @@ namespace dlib
return v;
}
// Overload of the sgd update rule for bn_ layers.  Delegates to
// update_considering_bias so the bias-specific learning rate / weight decay
// multipliers are honored.  The last argument is the count of trailing
// parameters treated as biases — presumably the beta terms occupy the second
// half of a bn_ layer's parameter tensor (TODO confirm against bn_::setup).
template <layer_mode mode>
const tensor& operator() (
    const float learning_rate,
    const bn_<mode>& l,
    const tensor& params_grad
)
{
    const auto num_bias_params = params_grad.size()/2;
    update_considering_bias(learning_rate, l, params_grad, num_bias_params);
    return v;
}
friend void serialize(const sgd& item, std::ostream& out)
{
serialize("sgd2", out);
......@@ -244,6 +255,17 @@ namespace dlib
return s;
}
// Overload of the adam update rule for bn_ layers.  Delegates to
// update_considering_bias so the bias-specific learning rate / weight decay
// multipliers are honored.  The last argument is the count of trailing
// parameters treated as biases — presumably the beta terms occupy the second
// half of a bn_ layer's parameter tensor (TODO confirm against bn_::setup).
template <layer_mode mode>
const tensor& operator() (
    const float learning_rate,
    const bn_<mode>& l,
    const tensor& params_grad
)
{
    const auto num_bias_params = params_grad.size()/2;
    update_considering_bias(learning_rate, l, params_grad, num_bias_params);
    return s;
}
friend void serialize(const adam& item, std::ostream& out)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment