Unverified Commit a7627cbd authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Rename function to disable_duplicative_biases (#2246)

* Rename function to disable_duplicative_biases

* rename also the functions in the tests... oops
parent b6bf8aef
...@@ -1791,7 +1791,7 @@ namespace dlib ...@@ -1791,7 +1791,7 @@ namespace dlib
} }
template <typename net_type> template <typename net_type>
void disable_duplicative_bias ( void disable_duplicative_biases (
net_type& net net_type& net
) )
{ {
......
...@@ -1810,7 +1810,7 @@ namespace dlib ...@@ -1810,7 +1810,7 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename net_type> template <typename net_type>
void disable_duplicative_bias ( void disable_duplicative_biases (
const net_type& net const net_type& net
); );
/*! /*!
......
...@@ -3918,7 +3918,7 @@ namespace ...@@ -3918,7 +3918,7 @@ namespace
relu<bn_con<conp<4 * growth_rate, 1, 1, relu<bn_con<conp<4 * growth_rate, 1, 1,
relu<bn_con<tag1<SUBNET>>>>>>>>>; relu<bn_con<tag1<SUBNET>>>>>>>>>;
template <typename SUBNET> using dense_layer_32 = dense_layer<32, 8, SUBNET>; template <typename SUBNET> using dense_layer_32 = dense_layer<32, 8, SUBNET>;
void test_disable_duplicative_bias() void test_disable_duplicative_biases()
{ {
using net_type = fc<10, relu<layer_norm<fc<15, relu<bn_fc<fc<20, using net_type = fc<10, relu<layer_norm<fc<15, relu<bn_fc<fc<20,
relu<layer_norm<conp<32, 3, 1, relu<layer_norm<conp<32, 3, 1,
...@@ -3934,7 +3934,7 @@ namespace ...@@ -3934,7 +3934,7 @@ namespace
DLIB_TEST(layer<21>(net).layer_details().bias_is_disabled() == false); DLIB_TEST(layer<21>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<24>(net).layer_details().bias_is_disabled() == false); DLIB_TEST(layer<24>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<31>(net).layer_details().bias_is_disabled() == false); DLIB_TEST(layer<31>(net).layer_details().bias_is_disabled() == false);
disable_duplicative_bias(net); disable_duplicative_biases(net);
DLIB_TEST(layer<0>(net).layer_details().bias_is_disabled() == false); DLIB_TEST(layer<0>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<3>(net).layer_details().bias_is_disabled() == true); DLIB_TEST(layer<3>(net).layer_details().bias_is_disabled() == true);
DLIB_TEST(layer<6>(net).layer_details().bias_is_disabled() == true); DLIB_TEST(layer<6>(net).layer_details().bias_is_disabled() == true);
...@@ -4130,7 +4130,7 @@ namespace ...@@ -4130,7 +4130,7 @@ namespace
test_loss_multimulticlass_log(); test_loss_multimulticlass_log();
test_loss_mmod(); test_loss_mmod();
test_layers_scale_and_scale_prev(); test_layers_scale_and_scale_prev();
test_disable_duplicative_bias(); test_disable_duplicative_biases();
} }
void perform_test() void perform_test()
......
...@@ -134,8 +134,8 @@ int main(int argc, char** argv) try ...@@ -134,8 +134,8 @@ int main(int argc, char** argv) try
// setup all leaky_relu_ layers in the discriminator to have alpha = 0.2 // setup all leaky_relu_ layers in the discriminator to have alpha = 0.2
visit_computational_layers(discriminator, [](leaky_relu_& l){ l = leaky_relu_(0.2); }); visit_computational_layers(discriminator, [](leaky_relu_& l){ l = leaky_relu_(0.2); });
// Remove the bias learning from all bn_ inputs in both networks // Remove the bias learning from all bn_ inputs in both networks
disable_duplicative_bias(generator); disable_duplicative_biases(generator);
disable_duplicative_bias(discriminator); disable_duplicative_biases(discriminator);
// Forward random noise so that we see the tensor size at each layer // Forward random noise so that we see the tensor size at each layer
discriminator(generate_image(generator, make_noise(rnd))); discriminator(generate_image(generator, make_noise(rnd)));
cout << "generator (" << count_parameters(generator) << " parameters)" << endl; cout << "generator (" << count_parameters(generator) << " parameters)" << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment