Unverified Commit 516b744b authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add missing vistor implementations to visitors.h (#2539)

Notably, set_all_bn_running_stats_window_sizes and fuse_layers.

But also I took the chance to remove the superflous separators and
change the attribute of upsample layers from stride to scale.
parent 12f1b3a3
...@@ -1708,136 +1708,6 @@ namespace dlib ...@@ -1708,136 +1708,6 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>; using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
namespace impl
{
class visitor_bn_running_stats_window_size
{
public:
visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {}
template <typename T>
void set_window_size(T&) const
{
// ignore other layer detail types
}
template < layer_mode mode >
void set_window_size(bn_<mode>& l) const
{
l.set_running_stats_window_size(new_window_size);
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l) const
{
set_window_size(l.layer_details());
}
private:
unsigned long new_window_size;
};
class visitor_disable_input_bias
{
public:
template <typename T>
void disable_input_bias(T&) const
{
// ignore other layer types
}
// handle the standard case
template <typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
template <layer_mode mode, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
// handle input repeat layer case
template <layer_mode mode, size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
template <size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
// handle input repeat layer with tag case
template <layer_mode mode, unsigned long ID, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
template <unsigned long ID, typename E>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
// handle tag layer case
template <layer_mode mode, unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, U>, E>& )
{
}
template <unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, U>, E>& )
{
}
// handle skip layer case
template <layer_mode mode, template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_skip_layer<TAG, U>, E>& )
{
}
template <template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, add_skip_layer<TAG, U>, E>& )
{
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l)
{
disable_input_bias(l);
}
};
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
enum fc_bias_mode enum fc_bias_mode
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment