Unverified Commit 0ffe9c4c authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Fix input/output mappings with repeat layers (#2337)

* Fix input/output mappings with repeat layers

* add test for input/output tensor mappers

* fix output to input order
parent a4713b59
......@@ -141,6 +141,15 @@ namespace dlib
(*this)(net.subnet());
p = net.layer_details().map_input_to_output(p);
}
template <size_t N, template <typename> class R, typename U>
void operator()(const repeat<N, R, U>& net)
{
(*this)(net.subnet());
for (size_t i = 0; i < N; ++i)
{
(*this)(net.get_repeated_layer(N-1-i).subnet());
}
}
template <unsigned long ID, typename U, typename E>
......@@ -201,6 +210,15 @@ namespace dlib
p = net.layer_details().map_output_to_input(p);
(*this)(net.subnet());
}
template <size_t N, template <typename> class R, typename U>
void operator()(const repeat<N, R, U>& net)
{
for (size_t i = 0; i < N; ++i)
{
(*this)(net.get_repeated_layer(i).subnet());
}
(*this)(net.subnet());
}
template <unsigned long ID, typename U, typename E>
......
......@@ -3970,6 +3970,20 @@ namespace
DLIB_TEST(layer<4>(net).layer_details().get_learning_rate_multiplier() == 0.01);
}
// ----------------------------------------------------------------------------------------
template <typename SUBNET>
using conblock = relu<bn_con<add_layer<con_<16, 3, 3, 2, 2, 1, 1>, SUBNET>>>;
void test_input_ouput_mappers()
{
using net_type = loss_binary_log_per_pixel<con<1, 1, 1, 1, 1,repeat<3, conblock, tag1<input_rgb_image>>>>;
net_type net;
point p(32, 32);
DLIB_TEST(input_tensor_to_output_tensor(net, p) == p / 8);
DLIB_TEST(output_tensor_to_input_tensor(net, p) == p * 8);
}
// ----------------------------------------------------------------------------------------
// This test really just checks if the mmod loss goes negative when a whole lot of overlapping
......@@ -4157,6 +4171,7 @@ namespace
test_layers_scale_and_scale_prev();
test_disable_duplicative_biases();
test_set_learning_rate_multipliers();
test_input_ouput_mappers();
}
void perform_test()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment