Unverified Commit e7ec6b77 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add visitor to remove bias from bn_ layer inputs (closes #2155) (#2156)

* add visitor to remove bias from bn_ inputs (closes #2155)

* remove unused parameter and make documentation more clear

* remove bias from bn_ layers too and use better name

* let the batch norm layers keep their bias, use an even better name

* be more consistent with impl naming

* remove default constructor

* do not use method to prevent some errors

* add disable bias method to pertinent layers

* update dcgan example

- grammar
- print number of network parameters to be able to check bias is not allocated
- at the end, give feedback to the user about what the discriminator thinks about each generated sample

* fix fc_ logic

* add documentation

* add bias_is_disabled methods and update to_xml

* print use_bias=false when bias is disabled
parent ed22f040
...@@ -183,6 +183,28 @@ namespace dlib ...@@ -183,6 +183,28 @@ namespace dlib
impl::set_bias_weight_decay_multiplier(obj, special_(), bias_weight_decay_multiplier); impl::set_bias_weight_decay_multiplier(obj, special_(), bias_weight_decay_multiplier);
} }
// ----------------------------------------------------------------------------------------
namespace impl
{
// Overload selected (via SFINAE on the int_<decltype(...)> non-type template
// parameter) only when T has a disable_bias() member function.  The special_
// tag gives this overload priority over the general_ fallback below.
template <typename T, typename int_<decltype(&T::disable_bias)>::type = 0>
void disable_bias(
T& obj,
special_
) { obj.disable_bias(); }
// Fallback: T has no disable_bias() member, so do nothing.
template <typename T>
void disable_bias( const T& , general_) { }
}
// Calls obj.disable_bias() if that member function exists, otherwise does
// nothing.  Dispatches to the impl:: overloads above; special_ converts to
// general_, so the tagged overload wins whenever it is viable.
template <typename T>
void disable_bias(
T& obj
)
{
impl::disable_bias(obj, special_());
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl namespace impl
......
...@@ -157,6 +157,20 @@ namespace dlib ...@@ -157,6 +157,20 @@ namespace dlib
- does nothing - does nothing
!*/ !*/
// ----------------------------------------------------------------------------------------
template <typename T>
void disable_bias(
T& obj
);
/*!
ensures
- if (obj has a disable_bias() member function) then
- calls obj.disable_bias()
- else
- does nothing
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
bool dnn_prefer_fastest_algorithms( bool dnn_prefer_fastest_algorithms(
......
...@@ -59,7 +59,8 @@ namespace dlib ...@@ -59,7 +59,8 @@ namespace dlib
bias_weight_decay_multiplier(0), bias_weight_decay_multiplier(0),
num_filters_(o.num_outputs), num_filters_(o.num_outputs),
padding_y_(_padding_y), padding_y_(_padding_y),
padding_x_(_padding_x) padding_x_(_padding_x),
use_bias(true)
{ {
DLIB_CASSERT(num_filters_ > 0); DLIB_CASSERT(num_filters_ > 0);
} }
...@@ -106,6 +107,8 @@ namespace dlib ...@@ -106,6 +107,8 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
void disable_bias() { use_bias = false; }
bool bias_is_disabled() const { return !use_bias; }
inline dpoint map_input_to_output ( inline dpoint map_input_to_output (
dpoint p dpoint p
...@@ -137,7 +140,8 @@ namespace dlib ...@@ -137,7 +140,8 @@ namespace dlib
bias_weight_decay_multiplier(item.bias_weight_decay_multiplier), bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
num_filters_(item.num_filters_), num_filters_(item.num_filters_),
padding_y_(item.padding_y_), padding_y_(item.padding_y_),
padding_x_(item.padding_x_) padding_x_(item.padding_x_),
use_bias(item.use_bias)
{ {
// this->conv is non-copyable and basically stateless, so we have to write our // this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error. // own copy to avoid trying to copy it and getting an error.
...@@ -162,6 +166,7 @@ namespace dlib ...@@ -162,6 +166,7 @@ namespace dlib
bias_learning_rate_multiplier = item.bias_learning_rate_multiplier; bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
bias_weight_decay_multiplier = item.bias_weight_decay_multiplier; bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
num_filters_ = item.num_filters_; num_filters_ = item.num_filters_;
use_bias = item.use_bias;
return *this; return *this;
} }
...@@ -174,16 +179,18 @@ namespace dlib ...@@ -174,16 +179,18 @@ namespace dlib
long num_inputs = filt_nr*filt_nc*sub.get_output().k(); long num_inputs = filt_nr*filt_nc*sub.get_output().k();
long num_outputs = num_filters_; long num_outputs = num_filters_;
// allocate params for the filters and also for the filter bias values. // allocate params for the filters and also for the filter bias values.
params.set_size(num_inputs*num_filters_ + num_filters_); params.set_size(num_inputs*num_filters_ + static_cast<int>(use_bias) * num_filters_);
dlib::rand rnd(std::rand()); dlib::rand rnd(std::rand());
randomize_parameters(params, num_inputs+num_outputs, rnd); randomize_parameters(params, num_inputs+num_outputs, rnd);
filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc); filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc);
biases = alias_tensor(1,num_filters_); if (use_bias)
{
// set the initial bias values to zero biases = alias_tensor(1,num_filters_);
biases(params,filters.size()) = 0; // set the initial bias values to zero
biases(params,filters.size()) = 0;
}
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -198,9 +205,11 @@ namespace dlib ...@@ -198,9 +205,11 @@ namespace dlib
conv(false, output, conv(false, output,
sub.get_output(), sub.get_output(),
filters(params,0)); filters(params,0));
if (use_bias)
tt::add(1,output,1,biases(params,filters.size())); {
} tt::add(1,output,1,biases(params,filters.size()));
}
}
template <typename SUBNET> template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad) void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
...@@ -211,8 +220,11 @@ namespace dlib ...@@ -211,8 +220,11 @@ namespace dlib
{ {
auto filt = filters(params_grad,0); auto filt = filters(params_grad,0);
conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt); conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt);
auto b = biases(params_grad, filters.size()); if (use_bias)
tt::assign_conv_bias_gradient(b, gradient_input); {
auto b = biases(params_grad, filters.size());
tt::assign_conv_bias_gradient(b, gradient_input);
}
} }
} }
...@@ -221,7 +233,7 @@ namespace dlib ...@@ -221,7 +233,7 @@ namespace dlib
friend void serialize(const con_& item, std::ostream& out) friend void serialize(const con_& item, std::ostream& out)
{ {
serialize("con_4", out); serialize("con_5", out);
serialize(item.params, out); serialize(item.params, out);
serialize(item.num_filters_, out); serialize(item.num_filters_, out);
serialize(_nr, out); serialize(_nr, out);
...@@ -236,6 +248,7 @@ namespace dlib ...@@ -236,6 +248,7 @@ namespace dlib
serialize(item.weight_decay_multiplier, out); serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out); serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
} }
friend void deserialize(con_& item, std::istream& in) friend void deserialize(con_& item, std::istream& in)
...@@ -246,7 +259,7 @@ namespace dlib ...@@ -246,7 +259,7 @@ namespace dlib
long nc; long nc;
int stride_y; int stride_y;
int stride_x; int stride_x;
if (version == "con_4") if (version == "con_4" || version == "con_5")
{ {
deserialize(item.params, in); deserialize(item.params, in);
deserialize(item.num_filters_, in); deserialize(item.num_filters_, in);
...@@ -268,6 +281,10 @@ namespace dlib ...@@ -268,6 +281,10 @@ namespace dlib
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_"); if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_"); if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_"); if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
if (version == "con_5")
{
deserialize(item.use_bias, in);
}
} }
else else
{ {
...@@ -289,8 +306,15 @@ namespace dlib ...@@ -289,8 +306,15 @@ namespace dlib
<< ")"; << ")";
out << " learning_rate_mult="<<item.learning_rate_multiplier; out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier; out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier; if (item.use_bias)
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier; {
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
}
else
{
out << " use_bias=false";
}
return out; return out;
} }
...@@ -307,7 +331,8 @@ namespace dlib ...@@ -307,7 +331,8 @@ namespace dlib
<< " learning_rate_mult='"<<item.learning_rate_multiplier<<"'" << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'" << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'" << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n"; << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
out << mat(item.params); out << mat(item.params);
out << "</con>"; out << "</con>";
} }
...@@ -328,6 +353,7 @@ namespace dlib ...@@ -328,6 +353,7 @@ namespace dlib
// serialized to disk) used different padding settings. // serialized to disk) used different padding settings.
int padding_y_; int padding_y_;
int padding_x_; int padding_x_;
bool use_bias;
}; };
...@@ -373,7 +399,8 @@ namespace dlib ...@@ -373,7 +399,8 @@ namespace dlib
bias_weight_decay_multiplier(0), bias_weight_decay_multiplier(0),
num_filters_(o.num_outputs), num_filters_(o.num_outputs),
padding_y_(_padding_y), padding_y_(_padding_y),
padding_x_(_padding_x) padding_x_(_padding_x),
use_bias(true)
{ {
DLIB_CASSERT(num_filters_ > 0); DLIB_CASSERT(num_filters_ > 0);
} }
...@@ -408,6 +435,8 @@ namespace dlib ...@@ -408,6 +435,8 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
void disable_bias() { use_bias = false; }
bool bias_is_disabled() const { return !use_bias; }
inline dpoint map_output_to_input ( inline dpoint map_output_to_input (
dpoint p dpoint p
...@@ -439,7 +468,8 @@ namespace dlib ...@@ -439,7 +468,8 @@ namespace dlib
bias_weight_decay_multiplier(item.bias_weight_decay_multiplier), bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
num_filters_(item.num_filters_), num_filters_(item.num_filters_),
padding_y_(item.padding_y_), padding_y_(item.padding_y_),
padding_x_(item.padding_x_) padding_x_(item.padding_x_),
use_bias(item.use_bias)
{ {
// this->conv is non-copyable and basically stateless, so we have to write our // this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error. // own copy to avoid trying to copy it and getting an error.
...@@ -464,6 +494,7 @@ namespace dlib ...@@ -464,6 +494,7 @@ namespace dlib
bias_learning_rate_multiplier = item.bias_learning_rate_multiplier; bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
bias_weight_decay_multiplier = item.bias_weight_decay_multiplier; bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
num_filters_ = item.num_filters_; num_filters_ = item.num_filters_;
use_bias = item.use_bias;
return *this; return *this;
} }
...@@ -473,16 +504,18 @@ namespace dlib ...@@ -473,16 +504,18 @@ namespace dlib
long num_inputs = _nr*_nc*sub.get_output().k(); long num_inputs = _nr*_nc*sub.get_output().k();
long num_outputs = num_filters_; long num_outputs = num_filters_;
// allocate params for the filters and also for the filter bias values. // allocate params for the filters and also for the filter bias values.
params.set_size(num_inputs*num_filters_ + num_filters_); params.set_size(num_inputs*num_filters_ + num_filters_ * static_cast<int>(use_bias));
dlib::rand rnd(std::rand()); dlib::rand rnd(std::rand());
randomize_parameters(params, num_inputs+num_outputs, rnd); randomize_parameters(params, num_inputs+num_outputs, rnd);
filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc); filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc);
biases = alias_tensor(1,num_filters_); if (use_bias)
{
// set the initial bias values to zero biases = alias_tensor(1,num_filters_);
biases(params,filters.size()) = 0; // set the initial bias values to zero
biases(params,filters.size()) = 0;
}
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -496,7 +529,10 @@ namespace dlib ...@@ -496,7 +529,10 @@ namespace dlib
output.set_size(gnsamps,gk,gnr,gnc); output.set_size(gnsamps,gk,gnr,gnc);
conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_); conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_);
conv.get_gradient_for_data(false, sub.get_output(),filt,output); conv.get_gradient_for_data(false, sub.get_output(),filt,output);
tt::add(1,output,1,biases(params,filters.size())); if (use_bias)
{
tt::add(1,output,1,biases(params,filters.size()));
}
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -509,8 +545,11 @@ namespace dlib ...@@ -509,8 +545,11 @@ namespace dlib
{ {
auto filt = filters(params_grad,0); auto filt = filters(params_grad,0);
conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt); conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt);
auto b = biases(params_grad, filters.size()); if (use_bias)
tt::assign_conv_bias_gradient(b, gradient_input); {
auto b = biases(params_grad, filters.size());
tt::assign_conv_bias_gradient(b, gradient_input);
}
} }
} }
...@@ -519,7 +558,7 @@ namespace dlib ...@@ -519,7 +558,7 @@ namespace dlib
friend void serialize(const cont_& item, std::ostream& out) friend void serialize(const cont_& item, std::ostream& out)
{ {
serialize("cont_1", out); serialize("cont_2", out);
serialize(item.params, out); serialize(item.params, out);
serialize(item.num_filters_, out); serialize(item.num_filters_, out);
serialize(_nr, out); serialize(_nr, out);
...@@ -534,6 +573,7 @@ namespace dlib ...@@ -534,6 +573,7 @@ namespace dlib
serialize(item.weight_decay_multiplier, out); serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out); serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
} }
friend void deserialize(cont_& item, std::istream& in) friend void deserialize(cont_& item, std::istream& in)
...@@ -544,7 +584,7 @@ namespace dlib ...@@ -544,7 +584,7 @@ namespace dlib
long nc; long nc;
int stride_y; int stride_y;
int stride_x; int stride_x;
if (version == "cont_1") if (version == "cont_1" || version == "cont_2")
{ {
deserialize(item.params, in); deserialize(item.params, in);
deserialize(item.num_filters_, in); deserialize(item.num_filters_, in);
...@@ -566,6 +606,10 @@ namespace dlib ...@@ -566,6 +606,10 @@ namespace dlib
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_"); if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_"); if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_"); if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
if (version == "cont_2")
{
deserialize(item.use_bias, in);
}
} }
else else
{ {
...@@ -587,8 +631,15 @@ namespace dlib ...@@ -587,8 +631,15 @@ namespace dlib
<< ")"; << ")";
out << " learning_rate_mult="<<item.learning_rate_multiplier; out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier; out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier; if (item.use_bias)
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier; {
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
}
else
{
out << " use_bias=false";
}
return out; return out;
} }
...@@ -605,7 +656,8 @@ namespace dlib ...@@ -605,7 +656,8 @@ namespace dlib
<< " learning_rate_mult='"<<item.learning_rate_multiplier<<"'" << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'" << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'" << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'>\n"; << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
out << mat(item.params); out << mat(item.params);
out << "</cont>"; out << "</cont>";
} }
...@@ -625,6 +677,8 @@ namespace dlib ...@@ -625,6 +677,8 @@ namespace dlib
int padding_y_; int padding_y_;
int padding_x_; int padding_x_;
bool use_bias;
}; };
template < template <
...@@ -1522,6 +1576,37 @@ namespace dlib ...@@ -1522,6 +1576,37 @@ namespace dlib
unsigned long new_window_size; unsigned long new_window_size;
}; };
// Visitor that walks a network and, for every layer that feeds directly into a
// bn_ (batch normalization) layer, disables that input layer's bias.  The bias
// is redundant there because bn_ has its own learned offset.
class visitor_bn_input_no_bias
{
public:

// Generic overload: layers that are not bn_ are left untouched.
template <typename T>
void set_input_no_bias(T&) const
{
// ignore other layer types
}

// bn_ overload: disable the bias of the layer feeding into this bn_ and
// zero its bias multipliers so no bias gradient/decay is ever applied.
// disable_bias() is a no-op if the input layer has no such member.
template <layer_mode mode, typename U, typename E>
void set_input_no_bias(add_layer<bn_<mode>, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}

// Input layers have no subnet to modify.
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}

// For every computational layer, dispatch on whether it is a bn_ layer.
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l)
{
set_input_no_bias(l);
}
};
} }
template <typename net_type> template <typename net_type>
...@@ -1533,6 +1618,14 @@ namespace dlib ...@@ -1533,6 +1618,14 @@ namespace dlib
visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size)); visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
} }
// Disables the bias (and zeroes the bias multipliers) of every layer in net
// that feeds directly into a bn_ layer.  See visitor_bn_input_no_bias above.
template <typename net_type>
void set_all_bn_inputs_no_bias (
net_type& net
)
{
visit_layers(net, impl::visitor_bn_input_no_bias());
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -1561,7 +1654,8 @@ namespace dlib ...@@ -1561,7 +1654,8 @@ namespace dlib
learning_rate_multiplier(1), learning_rate_multiplier(1),
weight_decay_multiplier(1), weight_decay_multiplier(1),
bias_learning_rate_multiplier(1), bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(0) bias_weight_decay_multiplier(0),
use_bias(true)
{} {}
fc_() : fc_(num_fc_outputs(num_outputs_)) {} fc_() : fc_(num_fc_outputs(num_outputs_)) {}
...@@ -1575,6 +1669,8 @@ namespace dlib ...@@ -1575,6 +1669,8 @@ namespace dlib
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; } double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; } void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; } void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
void disable_bias() { use_bias = false; }
bool bias_is_disabled() const { return !use_bias; }
unsigned long get_num_outputs ( unsigned long get_num_outputs (
) const { return num_outputs; } ) const { return num_outputs; }
...@@ -1597,7 +1693,7 @@ namespace dlib ...@@ -1597,7 +1693,7 @@ namespace dlib
void setup (const SUBNET& sub) void setup (const SUBNET& sub)
{ {
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(); num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
if (bias_mode == FC_HAS_BIAS) if (bias_mode == FC_HAS_BIAS && use_bias)
params.set_size(num_inputs+1, num_outputs); params.set_size(num_inputs+1, num_outputs);
else else
params.set_size(num_inputs, num_outputs); params.set_size(num_inputs, num_outputs);
...@@ -1607,7 +1703,7 @@ namespace dlib ...@@ -1607,7 +1703,7 @@ namespace dlib
weights = alias_tensor(num_inputs, num_outputs); weights = alias_tensor(num_inputs, num_outputs);
if (bias_mode == FC_HAS_BIAS) if (bias_mode == FC_HAS_BIAS && use_bias)
{ {
biases = alias_tensor(1,num_outputs); biases = alias_tensor(1,num_outputs);
// set the initial bias values to zero // set the initial bias values to zero
...@@ -1624,7 +1720,7 @@ namespace dlib ...@@ -1624,7 +1720,7 @@ namespace dlib
auto w = weights(params, 0); auto w = weights(params, 0);
tt::gemm(0,output, 1,sub.get_output(),false, w,false); tt::gemm(0,output, 1,sub.get_output(),false, w,false);
if (bias_mode == FC_HAS_BIAS) if (bias_mode == FC_HAS_BIAS && use_bias)
{ {
auto b = biases(params, weights.size()); auto b = biases(params, weights.size());
tt::add(1,output,1,b); tt::add(1,output,1,b);
...@@ -1641,7 +1737,7 @@ namespace dlib ...@@ -1641,7 +1737,7 @@ namespace dlib
auto pw = weights(params_grad, 0); auto pw = weights(params_grad, 0);
tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false); tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);
if (bias_mode == FC_HAS_BIAS) if (bias_mode == FC_HAS_BIAS && use_bias)
{ {
// compute the gradient of the bias parameters. // compute the gradient of the bias parameters.
auto pb = biases(params_grad, weights.size()); auto pb = biases(params_grad, weights.size());
...@@ -1683,7 +1779,7 @@ namespace dlib ...@@ -1683,7 +1779,7 @@ namespace dlib
friend void serialize(const fc_& item, std::ostream& out) friend void serialize(const fc_& item, std::ostream& out)
{ {
serialize("fc_2", out); serialize("fc_3", out);
serialize(item.num_outputs, out); serialize(item.num_outputs, out);
serialize(item.num_inputs, out); serialize(item.num_inputs, out);
serialize(item.params, out); serialize(item.params, out);
...@@ -1694,27 +1790,36 @@ namespace dlib ...@@ -1694,27 +1790,36 @@ namespace dlib
serialize(item.weight_decay_multiplier, out); serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out); serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out); serialize(item.bias_weight_decay_multiplier, out);
serialize(item.use_bias, out);
} }
friend void deserialize(fc_& item, std::istream& in) friend void deserialize(fc_& item, std::istream& in)
{ {
std::string version; std::string version;
deserialize(version, in); deserialize(version, in);
if (version != "fc_2") if (version == "fc_2" || version == "fc_3")
{
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
deserialize(item.params, in);
deserialize(item.weights, in);
deserialize(item.biases, in);
int bmode = 0;
deserialize(bmode, in);
if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
if (version == "fc_3")
{
deserialize(item.use_bias, in);
}
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_."); throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
}
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
deserialize(item.params, in);
deserialize(item.weights, in);
deserialize(item.biases, in);
int bmode = 0;
deserialize(bmode, in);
if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
} }
friend std::ostream& operator<<(std::ostream& out, const fc_& item) friend std::ostream& operator<<(std::ostream& out, const fc_& item)
...@@ -1726,8 +1831,15 @@ namespace dlib ...@@ -1726,8 +1831,15 @@ namespace dlib
<< ")"; << ")";
out << " learning_rate_mult="<<item.learning_rate_multiplier; out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier; out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier; if (item.use_bias)
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier; {
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
}
else
{
out << " use_bias=false";
}
} }
else else
{ {
...@@ -1749,7 +1861,8 @@ namespace dlib ...@@ -1749,7 +1861,8 @@ namespace dlib
<< " learning_rate_mult='"<<item.learning_rate_multiplier<<"'" << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
<< " weight_decay_mult='"<<item.weight_decay_multiplier<<"'" << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
<< " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'" << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
<< " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"; << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
<< " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
out << ">\n"; out << ">\n";
out << mat(item.params); out << mat(item.params);
out << "</fc>\n"; out << "</fc>\n";
...@@ -1776,6 +1889,7 @@ namespace dlib ...@@ -1776,6 +1889,7 @@ namespace dlib
double weight_decay_multiplier; double weight_decay_multiplier;
double bias_learning_rate_multiplier; double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier; double bias_weight_decay_multiplier;
bool use_bias;
}; };
template < template <
......
...@@ -573,6 +573,22 @@ namespace dlib ...@@ -573,6 +573,22 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val - #get_bias_weight_decay_multiplier() == val
!*/ !*/
void disable_bias(
);
/*!
ensures
- bias_is_disabled() returns true
!*/
bool bias_is_disabled(
) const;
/*!
ensures
- returns true if bias learning is disabled for this layer. This means the biases will
not be learned during training and they will not be used in the forward or backward
methods either.
!*/
alias_tensor_const_instance get_weights( alias_tensor_const_instance get_weights(
) const; ) const;
/*! /*!
...@@ -903,6 +919,22 @@ namespace dlib ...@@ -903,6 +919,22 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val - #get_bias_weight_decay_multiplier() == val
!*/ !*/
void disable_bias(
);
/*!
ensures
- bias_is_disabled() returns true
!*/
bool bias_is_disabled(
) const;
/*!
ensures
- returns true if bias learning is disabled for this layer. This means the biases will
not be learned during training and they will not be used in the forward or backward
methods either.
!*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
...@@ -1147,6 +1179,22 @@ namespace dlib ...@@ -1147,6 +1179,22 @@ namespace dlib
- #get_bias_weight_decay_multiplier() == val - #get_bias_weight_decay_multiplier() == val
!*/ !*/
void disable_bias(
);
/*!
ensures
- bias_is_disabled() returns true
!*/
bool bias_is_disabled(
) const;
/*!
ensures
- returns true if bias learning is disabled for this layer. This means the biases will
not be learned during training and they will not be used in the forward or backward
methods either.
!*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
...@@ -1616,6 +1664,22 @@ namespace dlib ...@@ -1616,6 +1664,22 @@ namespace dlib
new_window_size. new_window_size.
!*/ !*/
// ----------------------------------------------------------------------------------------
template <typename net_type>
void set_all_bn_inputs_no_bias (
net_type& net
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Disables bias for all bn_ layer inputs.
- Sets the bias learning rate multiplier and the bias weight decay multiplier
to zero for all bn_ layer inputs.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class affine_ class affine_
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
by Alec Radford, Luke Metz, Soumith Chintala. by Alec Radford, Luke Metz, Soumith Chintala.
The main idea is that there are two neural networks training at the same time: The main idea is that there are two neural networks training at the same time:
- the generator is in charge of generating images that look as close as possible as the - the generator is in charge of generating images that look as close as possible to the
ones from the dataset. ones from the dataset.
- the discriminator will decide whether an image is fake (created by the generator) or real - the discriminator will decide whether an image is fake (created by the generator) or real
(selected from the dataset). (selected from the dataset).
...@@ -35,25 +35,6 @@ ...@@ -35,25 +35,6 @@
using namespace std; using namespace std;
using namespace dlib; using namespace dlib;
// We start by defining a simple visitor to disable bias learning in a network. By default,
// biases are initialized to 0, so setting the multipliers to 0 disables bias learning.
class visitor_no_bias
{
public:
template <typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T, U, E>& l) const
{
set_bias_learning_rate_multiplier(l.layer_details(), 0);
set_bias_weight_decay_multiplier(l.layer_details(), 0);
}
};
// Some helper definitions for the noise generation // Some helper definitions for the noise generation
const size_t noise_size = 100; const size_t noise_size = 100;
using noise_t = std::array<matrix<float, 1, 1>, noise_size>; using noise_t = std::array<matrix<float, 1, 1>, noise_size>;
...@@ -149,16 +130,15 @@ int main(int argc, char** argv) try ...@@ -149,16 +130,15 @@ int main(int argc, char** argv) try
// Instantiate both generator and discriminator // Instantiate both generator and discriminator
generator_type generator; generator_type generator;
discriminator_type discriminator( discriminator_type discriminator(leaky_relu_(0.2), leaky_relu_(0.2), leaky_relu_(0.2));
leaky_relu_(0.2), leaky_relu_(0.2), leaky_relu_(0.2)); // Remove the bias learning from all bn_ inputs in both networks
// Remove the bias learning from the networks set_all_bn_inputs_no_bias(generator);
visit_layers(generator, visitor_no_bias()); set_all_bn_inputs_no_bias(discriminator);
visit_layers(discriminator, visitor_no_bias());
// Forward random noise so that we see the tensor size at each layer // Forward random noise so that we see the tensor size at each layer
discriminator(generate_image(generator, make_noise(rnd))); discriminator(generate_image(generator, make_noise(rnd)));
cout << "generator" << endl; cout << "generator (" << count_parameters(generator) << " parameters)" << endl;
cout << generator << endl; cout << generator << endl;
cout << "discriminator" << endl; cout << "discriminator (" << count_parameters(discriminator) << " parameters)" << endl;
cout << discriminator << endl; cout << discriminator << endl;
// The solvers for the generator and discriminator networks. In this example, we are going to // The solvers for the generator and discriminator networks. In this example, we are going to
...@@ -204,7 +184,7 @@ int main(int argc, char** argv) try ...@@ -204,7 +184,7 @@ int main(int argc, char** argv) try
{ {
noises.push_back(make_noise(rnd)); noises.push_back(make_noise(rnd));
} }
// 2. Convert noises into a tensor // 2. Convert noises into a tensor
generator.to_tensor(noises.begin(), noises.end(), noises_tensor); generator.to_tensor(noises.begin(), noises.end(), noises_tensor);
// 3. Forward the noise through the network and convert the outputs into images. // 3. Forward the noise through the network and convert the outputs into images.
const auto fake_samples = get_generated_images(generator.forward(noises_tensor)); const auto fake_samples = get_generated_images(generator.forward(noises_tensor));
...@@ -257,8 +237,11 @@ int main(int argc, char** argv) try ...@@ -257,8 +237,11 @@ int main(int argc, char** argv) try
// output. // output.
while (!win.is_closed()) while (!win.is_closed())
{ {
win.set_image(generate_image(generator, make_noise(rnd))); const auto image = generate_image(generator, make_noise(rnd));
cout << "Hit enter to generate a new image"; const auto real = discriminator(image) > 0;
win.set_image(image);
cout << "The discriminator thinks it's " << (real ? "real" : "fake");
cout << ". Hit enter to generate a new image";
cin.get(); cin.get();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment