Unverified Commit 516b744b authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add missing visitor implementations to visitors.h (#2539)

Notably, set_all_bn_running_stats_window_sizes and fuse_layers.

I also took the chance to remove the superfluous separators and to
change the attribute of upsample layers from stride to scale.
parent 12f1b3a3
...@@ -1708,136 +1708,6 @@ namespace dlib ...@@ -1708,136 +1708,6 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>; using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
namespace impl
{
// Visitor that sets the running-stats window size on every
// batch-normalization (bn_) layer it encounters; all other layer
// types are passed over unchanged.
class visitor_bn_running_stats_window_size
{
public:
    visitor_bn_running_stats_window_size(unsigned long window_size) : new_window_size(window_size) {}

    // Only bn_ layer details respond to the visit: update their window size.
    template <layer_mode mode>
    void set_window_size(bn_<mode>& l) const
    {
        l.set_running_stats_window_size(new_window_size);
    }

    // Catch-all for every other layer detail type: nothing to configure.
    template <typename T>
    void set_window_size(T&) const
    {
    }

    // add_layer nodes: dispatch on the wrapped layer detail object.
    template <typename T, typename U, typename E>
    void operator()(size_t, add_layer<T, U, E>& l) const
    {
        set_window_size(l.layer_details());
    }

    // Input layers carry no bn_ state, so they are skipped.
    template <typename input_layer_type>
    void operator()(size_t, input_layer_type&) const
    {
    }

private:
    unsigned long new_window_size;
};
// Visitor that disables the bias of whatever layer directly feeds a
// normalization layer (bn_ or layer_norm_): the normalization's own
// learned offset makes the preceding bias redundant. Bias learning-rate
// and weight-decay multipliers are zeroed as well so training cannot
// re-introduce the bias.
class visitor_disable_input_bias
{
public:
// Fallback: layers that are not normalization layers are left untouched.
template <typename T>
void disable_input_bias(T&) const
{
// ignore other layer types
}
// handle the standard case
// (a normalization layer whose input is an ordinary add_layer)
template <typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
template <layer_mode mode, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
// handle input repeat layer case
// (the repeated sub-network's first instance is the real input layer,
// so disable the bias there)
template <layer_mode mode, size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
template <size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
// handle input repeat layer with tag case
// (tag layers are pass-through placeholders with no parameters of their
// own; the tagged layer is visited separately, so there is nothing to do)
template <layer_mode mode, unsigned long ID, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
template <unsigned long ID, typename E>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
// handle tag layer case
// (same reasoning: a tag has no bias to disable)
template <layer_mode mode, unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, U>, E>& )
{
}
template <unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, U>, E>& )
{
}
// handle skip layer case
// (skip layers just re-route the data flow, so nothing to disable here)
template <layer_mode mode, template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_skip_layer<TAG, U>, E>& )
{
}
template <template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, add_skip_layer<TAG, U>, E>& )
{
}
// Entry points called by visit_layers:
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
// add_layer nodes dispatch to the overload set above, which picks the
// most specialized match for the layer/input combination.
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l)
{
disable_input_bias(l);
}
};
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
enum fc_bias_mode enum fc_bias_mode
......
...@@ -236,6 +236,45 @@ namespace dlib ...@@ -236,6 +236,45 @@ namespace dlib
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl
{
class visitor_bn_running_stats_window_size
{
public:
visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {}
template <typename T>
void set_window_size(T&) const
{
// ignore other layer detail types
}
template < layer_mode mode >
void set_window_size(bn_<mode>& l) const
{
l.set_running_stats_window_size(new_window_size);
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l) const
{
set_window_size(l.layer_details());
}
private:
unsigned long new_window_size;
};
}
template <typename net_type> template <typename net_type>
void set_all_bn_running_stats_window_sizes ( void set_all_bn_running_stats_window_sizes (
net_type& net, net_type& net,
...@@ -245,6 +284,102 @@ namespace dlib ...@@ -245,6 +284,102 @@ namespace dlib
visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size)); visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
} }
// ----------------------------------------------------------------------------------------
namespace impl
{
// Visitor that disables the bias of whatever layer directly feeds a
// normalization layer (bn_ or layer_norm_): the normalization's own
// learned offset makes the preceding bias redundant. Bias learning-rate
// and weight-decay multipliers are zeroed as well so training cannot
// re-introduce the bias.
class visitor_disable_input_bias
{
public:
// Fallback: layers that are not normalization layers are left untouched.
template <typename T>
void disable_input_bias(T&) const
{
// ignore other layer types
}
// handle the standard case
// (a normalization layer whose input is an ordinary add_layer)
template <typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
template <layer_mode mode, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
// handle input repeat layer case
// (the repeated sub-network's first instance is the real input layer,
// so disable the bias there)
template <layer_mode mode, size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
template <size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
// handle input repeat layer with tag case
// (tag layers are pass-through placeholders with no parameters of their
// own; the tagged layer is visited separately, so there is nothing to do)
template <layer_mode mode, unsigned long ID, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
template <unsigned long ID, typename E>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, impl::repeat_input_layer>, E>& )
{
}
// handle tag layer case
// (same reasoning: a tag has no bias to disable)
template <layer_mode mode, unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, U>, E>& )
{
}
template <unsigned long ID, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, U>, E>& )
{
}
// handle skip layer case
// (skip layers just re-route the data flow, so nothing to disable here)
template <layer_mode mode, template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, add_skip_layer<TAG, U>, E>& )
{
}
template <template <typename> class TAG, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, add_skip_layer<TAG, U>, E>& )
{
}
// Entry points called by visit_layers:
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
// add_layer nodes dispatch to the overload set above, which picks the
// most specialized match for the layer/input combination.
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l)
{
disable_input_bias(l);
}
};
}
template <typename net_type> template <typename net_type>
void disable_duplicative_biases ( void disable_duplicative_biases (
net_type& net net_type& net
...@@ -253,7 +388,6 @@ namespace dlib ...@@ -253,7 +388,6 @@ namespace dlib
visit_layers(net, impl::visitor_disable_input_bias()); visit_layers(net, impl::visitor_disable_input_bias());
} }
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl namespace impl
...@@ -331,7 +465,6 @@ namespace dlib ...@@ -331,7 +465,6 @@ namespace dlib
visit_layers(net, impl::visitor_fuse_layers()); visit_layers(net, impl::visitor_fuse_layers());
} }
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl namespace impl
...@@ -408,7 +541,6 @@ namespace dlib ...@@ -408,7 +541,6 @@ namespace dlib
net_to_xml(net, fout); net_to_xml(net, fout);
} }
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl namespace impl
...@@ -419,8 +551,6 @@ namespace dlib ...@@ -419,8 +551,6 @@ namespace dlib
visitor_net_to_dot(std::ostream& out) : out(out) {} visitor_net_to_dot(std::ostream& out) : out(out) {}
// ----------------------------------------------------------------------------------------
template <typename input_layer_type> template <typename input_layer_type>
void operator()(size_t i, input_layer_type& l) void operator()(size_t i, input_layer_type& l)
{ {
...@@ -429,8 +559,6 @@ namespace dlib ...@@ -429,8 +559,6 @@ namespace dlib
from = i; from = i;
} }
// ----------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T, typename U>
void operator()(size_t i, const add_loss_layer<T, U>&) void operator()(size_t i, const add_loss_layer<T, U>&)
{ {
...@@ -439,8 +567,6 @@ namespace dlib ...@@ -439,8 +567,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class... TAGS, typename U> template <template <typename> class... TAGS, typename U>
void operator()(size_t i, const add_loss_layer<loss_yolo_<TAGS...>, U>&) void operator()(size_t i, const add_loss_layer<loss_yolo_<TAGS...>, U>&)
{ {
...@@ -453,8 +579,6 @@ namespace dlib ...@@ -453,8 +579,6 @@ namespace dlib
out << tag_to_layer.at(std::stoul(tag)) << " -> " << i << '\n'; out << tag_to_layer.at(std::stoul(tag)) << " -> " << i << '\n';
} }
// ----------------------------------------------------------------------------------------
template <unsigned long ID, typename U, typename E> template <unsigned long ID, typename U, typename E>
void operator()(size_t i, const add_tag_layer<ID, U, E>&) void operator()(size_t i, const add_tag_layer<ID, U, E>&)
{ {
...@@ -484,8 +608,6 @@ namespace dlib ...@@ -484,8 +608,6 @@ namespace dlib
// update(i); // update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class TAG, typename U> template <template <typename> class TAG, typename U>
void operator()(size_t i, const add_skip_layer<TAG, U>&) void operator()(size_t i, const add_skip_layer<TAG, U>&)
{ {
...@@ -493,8 +615,6 @@ namespace dlib ...@@ -493,8 +615,6 @@ namespace dlib
from = tag_to_layer.at(t); from = tag_to_layer.at(t);
} }
// ----------------------------------------------------------------------------------------
template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E> template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
void operator()(size_t i, const add_layer<con_<nf, nr, nc, sy, sx, py, px>, U, E>& l) void operator()(size_t i, const add_layer<con_<nf, nr, nc, sy, sx, py, px>, U, E>& l)
{ {
...@@ -509,8 +629,6 @@ namespace dlib ...@@ -509,8 +629,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E> template <long nf, long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
void operator()(size_t i, const add_layer<cont_<nf, nr, nc, sy, sx, py, px>, U, E>& l) void operator()(size_t i, const add_layer<cont_<nf, nr, nc, sy, sx, py, px>, U, E>& l)
{ {
...@@ -525,20 +643,16 @@ namespace dlib ...@@ -525,20 +643,16 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <int sy, int sx, typename U, typename E> template <int sy, int sx, typename U, typename E>
void operator()(size_t i, const add_layer<upsample_<sy, sx>, U, E>&) void operator()(size_t i, const add_layer<upsample_<sy, sx>, U, E>&)
{ {
start_node(i, "upsample"); start_node(i, "upsample");
if (sy != 1 || sx != 1) if (sy != 1 || sx != 1)
out << " | {stride|{" << sy<< "," << sx << "}}"; out << " | {scale|{" << sy<< "," << sx << "}}";
end_node(); end_node();
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <int NR, int NC, typename U, typename E> template <int NR, int NC, typename U, typename E>
void operator()(size_t i, const add_layer<resize_to_<NR, NC>, U, E>&) void operator()(size_t i, const add_layer<resize_to_<NR, NC>, U, E>&)
{ {
...@@ -548,8 +662,6 @@ namespace dlib ...@@ -548,8 +662,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <long nr, long nc, int sy, int sx, int py, int px, typename U, typename E> template <long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
void operator()(size_t i, const add_layer<max_pool_<nr, nc, sy, sx, py, px>, U, E>&) void operator()(size_t i, const add_layer<max_pool_<nr, nc, sy, sx, py, px>, U, E>&)
{ {
...@@ -563,8 +675,6 @@ namespace dlib ...@@ -563,8 +675,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <long nr, long nc, int sy, int sx, int py, int px, typename U, typename E> template <long nr, long nc, int sy, int sx, int py, int px, typename U, typename E>
void operator()(size_t i, const add_layer<avg_pool_<nr, nc, sy, sx, py, px>, U, E>&) void operator()(size_t i, const add_layer<avg_pool_<nr, nc, sy, sx, py, px>, U, E>&)
{ {
...@@ -578,8 +688,6 @@ namespace dlib ...@@ -578,8 +688,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<layer_norm_, U, E>&) void operator()(size_t i, const add_layer<layer_norm_, U, E>&)
{ {
...@@ -588,8 +696,6 @@ namespace dlib ...@@ -588,8 +696,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <layer_mode MODE, typename U, typename E> template <layer_mode MODE, typename U, typename E>
void operator()(size_t i, const add_layer<bn_<MODE>, U, E>&) void operator()(size_t i, const add_layer<bn_<MODE>, U, E>&)
{ {
...@@ -598,8 +704,6 @@ namespace dlib ...@@ -598,8 +704,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <unsigned long no, fc_bias_mode bm, typename U, typename E> template <unsigned long no, fc_bias_mode bm, typename U, typename E>
void operator()(size_t i, const add_layer<fc_<no, bm>, U, E>& l) void operator()(size_t i, const add_layer<fc_<no, bm>, U, E>& l)
{ {
...@@ -609,8 +713,6 @@ namespace dlib ...@@ -609,8 +713,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<dropout_, U, E>&) void operator()(size_t i, const add_layer<dropout_, U, E>&)
{ {
...@@ -619,8 +721,6 @@ namespace dlib ...@@ -619,8 +721,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<multiply_, U, E>&) void operator()(size_t i, const add_layer<multiply_, U, E>&)
{ {
...@@ -629,8 +729,6 @@ namespace dlib ...@@ -629,8 +729,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<affine_, U, E>&) void operator()(size_t i, const add_layer<affine_, U, E>&)
{ {
...@@ -639,8 +737,6 @@ namespace dlib ...@@ -639,8 +737,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class TAG, typename U, typename E> template <template <typename> class TAG, typename U, typename E>
void operator()(size_t i, const add_layer<add_prev_<TAG>, U, E>&) void operator()(size_t i, const add_layer<add_prev_<TAG>, U, E>&)
{ {
...@@ -651,8 +747,6 @@ namespace dlib ...@@ -651,8 +747,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class TAG, typename U, typename E> template <template <typename> class TAG, typename U, typename E>
void operator()(size_t i, const add_layer<mult_prev_<TAG>, U, E>&) void operator()(size_t i, const add_layer<mult_prev_<TAG>, U, E>&)
{ {
...@@ -663,8 +757,6 @@ namespace dlib ...@@ -663,8 +757,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class TAG, typename U, typename E> template <template <typename> class TAG, typename U, typename E>
void operator()(size_t i, const add_layer<resize_prev_to_tagged_<TAG>, U, E>&) void operator()(size_t i, const add_layer<resize_prev_to_tagged_<TAG>, U, E>&)
{ {
...@@ -676,8 +768,6 @@ namespace dlib ...@@ -676,8 +768,6 @@ namespace dlib
from = i; from = i;
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class TAG, typename U, typename E> template <template <typename> class TAG, typename U, typename E>
void operator()(size_t i, const add_layer<scale_<TAG>, U, E>&) void operator()(size_t i, const add_layer<scale_<TAG>, U, E>&)
{ {
...@@ -688,8 +778,6 @@ namespace dlib ...@@ -688,8 +778,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class TAG, typename U, typename E> template <template <typename> class TAG, typename U, typename E>
void operator()(size_t i, const add_layer<scale_prev_<TAG>, U, E>&) void operator()(size_t i, const add_layer<scale_prev_<TAG>, U, E>&)
{ {
...@@ -700,8 +788,6 @@ namespace dlib ...@@ -700,8 +788,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<relu_, U, E>&) void operator()(size_t i, const add_layer<relu_, U, E>&)
{ {
...@@ -710,8 +796,6 @@ namespace dlib ...@@ -710,8 +796,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<prelu_, U, E>&) void operator()(size_t i, const add_layer<prelu_, U, E>&)
{ {
...@@ -720,8 +804,6 @@ namespace dlib ...@@ -720,8 +804,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<leaky_relu_, U, E>&) void operator()(size_t i, const add_layer<leaky_relu_, U, E>&)
{ {
...@@ -730,8 +812,6 @@ namespace dlib ...@@ -730,8 +812,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<sig_, U, E>&) void operator()(size_t i, const add_layer<sig_, U, E>&)
{ {
...@@ -740,8 +820,6 @@ namespace dlib ...@@ -740,8 +820,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<mish_, U, E>&) void operator()(size_t i, const add_layer<mish_, U, E>&)
{ {
...@@ -750,8 +828,6 @@ namespace dlib ...@@ -750,8 +828,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<htan_, U, E>&) void operator()(size_t i, const add_layer<htan_, U, E>&)
{ {
...@@ -760,8 +836,6 @@ namespace dlib ...@@ -760,8 +836,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<clipped_relu_, U, E>&) void operator()(size_t i, const add_layer<clipped_relu_, U, E>&)
{ {
...@@ -770,8 +844,6 @@ namespace dlib ...@@ -770,8 +844,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<elu_, U, E>&) void operator()(size_t i, const add_layer<elu_, U, E>&)
{ {
...@@ -780,8 +852,6 @@ namespace dlib ...@@ -780,8 +852,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<gelu_, U, E>&) void operator()(size_t i, const add_layer<gelu_, U, E>&)
{ {
...@@ -790,8 +860,6 @@ namespace dlib ...@@ -790,8 +860,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<softmax_, U, E>&) void operator()(size_t i, const add_layer<softmax_, U, E>&)
{ {
...@@ -800,8 +868,6 @@ namespace dlib ...@@ -800,8 +868,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<softmax_all_, U, E>&) void operator()(size_t i, const add_layer<softmax_all_, U, E>&)
{ {
...@@ -810,8 +876,6 @@ namespace dlib ...@@ -810,8 +876,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <template <typename> class... TAGS, typename U, typename E> template <template <typename> class... TAGS, typename U, typename E>
void operator()(size_t i, const add_layer<concat_<TAGS...>, U, E>& l) void operator()(size_t i, const add_layer<concat_<TAGS...>, U, E>& l)
{ {
...@@ -825,8 +889,6 @@ namespace dlib ...@@ -825,8 +889,6 @@ namespace dlib
from = i; from = i;
} }
// ----------------------------------------------------------------------------------------
template <typename U, typename E> template <typename U, typename E>
void operator()(size_t i, const add_layer<l2normalize_, U, E>&) void operator()(size_t i, const add_layer<l2normalize_, U, E>&)
{ {
...@@ -835,8 +897,6 @@ namespace dlib ...@@ -835,8 +897,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <long offset, long k, int nr, int nc, typename U, typename E> template <long offset, long k, int nr, int nc, typename U, typename E>
void operator()(size_t i, const add_layer<extract_<offset, k, nr, nc>, U, E>&) void operator()(size_t i, const add_layer<extract_<offset, k, nr, nc>, U, E>&)
{ {
...@@ -849,8 +909,6 @@ namespace dlib ...@@ -849,8 +909,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <long long sy, long long sx, typename U, typename E> template <long long sy, long long sx, typename U, typename E>
void operator()(size_t i, const add_layer<reorg_<sy, sx>, U, E>&) void operator()(size_t i, const add_layer<reorg_<sy, sx>, U, E>&)
{ {
...@@ -861,8 +919,6 @@ namespace dlib ...@@ -861,8 +919,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
template <typename T, typename U, typename E> template <typename T, typename U, typename E>
void operator()(size_t i, const add_layer<T, U, E>& l) void operator()(size_t i, const add_layer<T, U, E>& l)
{ {
...@@ -870,8 +926,6 @@ namespace dlib ...@@ -870,8 +926,6 @@ namespace dlib
update(i); update(i);
} }
// ----------------------------------------------------------------------------------------
private: private:
size_t from; size_t from;
std::ostream& out; std::ostream& out;
...@@ -914,6 +968,9 @@ namespace dlib ...@@ -914,6 +968,9 @@ namespace dlib
std::ofstream fout(filename); std::ofstream fout(filename);
net_to_dot(net, fout); net_to_dot(net, fout);
} }
// ----------------------------------------------------------------------------------------
} }
#endif // DLIB_DNn_VISITORS_H_ #endif // DLIB_DNn_VISITORS_H_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment