Unverified commit 994df341, authored by Adrià Arrufat, committed by GitHub

add option to not zero out gradients and method to do it (#2477)

parent a54cea44
@@ -729,6 +729,14 @@ namespace dlib
};
}
// ----------------------------------------------------------------------------------------
enum class zero_gradients : uint8_t
{
no = 0,
yes = 1
};
// ----------------------------------------------------------------------------------------
template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void>
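The zero_gradients enum added above is a strongly typed two-state flag: call sites spell out whether the backward pass should leave the gradient inputs zeroed, instead of passing a bare bool that reads ambiguously. A standalone sketch of this named-flag pattern (the backward_pass function is hypothetical, not part of this commit):

    #include <cstdint>
    #include <iostream>

    // Same shape as the enum added in this commit.
    enum class zero_gradients : std::uint8_t { no = 0, yes = 1 };

    // Hypothetical stand-in for back_propagate_error: the flag defaults to
    // the pre-existing behavior of clearing the gradient inputs.
    void backward_pass(zero_gradients zero_grads = zero_gradients::yes)
    {
        if (zero_grads == zero_gradients::yes)
            std::cout << "gradient inputs will read as zero afterwards\n";
        else
            std::cout << "gradient inputs are kept for a later pass\n";
    }

    int main()
    {
        backward_pass();                    // default keeps the old semantics
        backward_pass(zero_gradients::no);  // opting out is explicit at the call site
    }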
@@ -1003,21 +1011,28 @@ namespace dlib
const tensor& get_final_data_gradient(
) const { return subnetwork->get_final_data_gradient(); }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
back_propagate_error(x, private_get_gradient_input());
back_propagate_error(x, private_get_gradient_input(), zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
params_grad.copy_size(details.get_layer_params());
impl::call_layer_backward(details, private_get_output(),
gradient_input, wsub, static_cast<tensor&>(params_grad));
subnetwork->back_propagate_error(x);
subnetwork->back_propagate_error(x, zero_grads);
// zero out get_gradient_input()
gradient_input_is_stale = true;
gradient_input_is_stale = zero_grads == zero_gradients::yes;
}
template <typename solver_type>
@@ -1057,6 +1072,12 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
gradient_input_is_stale = true;
subnetwork->set_gradient_inputs_to_zero();
}
void clean()
{
x_grad.clear();
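The two back_propagate_error overloads above, together with the new set_gradient_inputs_to_zero, let a caller keep each layer's get_gradient_input() alive across backward passes and reset it explicitly later. A minimal sketch of that sequence, assuming dlib is installed; the three-layer toy network and the all-ones output gradient are placeholders:

    #include <dlib/dnn.h>
    #include <iostream>

    using namespace dlib;

    int main()
    {
        // A tiny stack of computational layers over an input layer.  There is
        // no loss layer, so we drive back_propagate_error by hand.
        fc<1, relu<fc<2, input<matrix<float, 2, 1>>>>> net;

        matrix<float, 2, 1> sample;
        sample = 1, 2;

        resizable_tensor x;
        net.to_tensor(&sample, &sample + 1, x);
        net.forward(x);

        // Seed the top layer's gradient input with a hand-made output gradient.
        net.get_gradient_input() = 1;

        // zero_gradients::no: the gradient inputs are not marked for clearing,
        // so a later backward pass would add its data gradients on top.
        net.back_propagate_error(x, zero_gradients::no);
        std::cout << "kept:  " << max(abs(mat(net.subnet().get_gradient_input()))) << '\n';

        // Restore the usual all-zero invariant before the next training step.
        net.set_gradient_inputs_to_zero();
        std::cout << "reset: " << max(abs(mat(net.subnet().get_gradient_input()))) << '\n';
    }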
@@ -1374,11 +1395,18 @@ namespace dlib
const tensor& get_final_data_gradient(
) const { return grad_final; }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
back_propagate_error(x, private_get_gradient_input());
back_propagate_error(x, private_get_gradient_input(), zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
// make sure grad_final is initialized to 0
if (!have_same_dimensions(x, grad_final))
@@ -1391,7 +1419,7 @@ namespace dlib
gradient_input, wsub, static_cast<tensor&>(params_grad));
// zero out get_gradient_input()
gradient_input_is_stale = true;
gradient_input_is_stale = zero_grads == zero_gradients::yes;
}
template <typename solver_type>
@@ -1430,6 +1458,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return _sample_expansion_factor; }
void set_gradient_inputs_to_zero()
{
gradient_input_is_stale = true;
}
void clean()
{
x_grad.clear();
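Here the recursion bottoms out: the method just flips gradient_input_is_stale. No tensor memory is touched immediately; the actual zeroing happens on the next read of get_gradient_input(), the same lazy mechanism the pre-existing code used to clear gradients after a backward pass. A standalone sketch of that lazy-reset idiom (illustrative only, not dlib's code):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Illustrative stand-in for a layer's gradient-input buffer.
    struct gradient_buffer
    {
        std::vector<float> grad;
        bool stale = true;  // "pretend the buffer is all zeros"

        // Mirrors private_get_gradient_input(): zero only when someone looks.
        std::vector<float>& get_gradient_input(std::size_t n)
        {
            grad.resize(n, 0.0f);
            if (stale)
            {
                std::fill(grad.begin(), grad.end(), 0.0f);
                stale = false;
            }
            return grad;
        }

        // Mirrors set_gradient_inputs_to_zero(): O(1), no memory touched yet.
        void set_gradient_inputs_to_zero() { stale = true; }
    };

    int main()
    {
        gradient_buffer buf;
        buf.get_gradient_input(3)[0] = 1.5f;  // accumulate something
        buf.set_gradient_inputs_to_zero();    // cheap, deferred reset
        assert(buf.get_gradient_input(3)[0] == 0.0f);  // zeroed lazily on access
    }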
@@ -1642,13 +1675,20 @@ namespace dlib
const tensor& get_final_data_gradient(
) const { return subnetwork.get_final_data_gradient(); }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.back_propagate_error(x,gradient_input);
subnetwork.back_propagate_error(x,gradient_input, zero_grads);
}
template <typename solver_type>
@@ -1677,6 +1717,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean()
{
subnetwork.clean();
@@ -1934,28 +1979,35 @@ namespace dlib
tensor& get_parameter_gradient (
) { return details[0].get_parameter_gradient(); }
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
back_propagate_error(x, private_get_gradient_input());
back_propagate_error(x, private_get_gradient_input(), zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
if (details.size() > 1)
{
details[0].back_propagate_error(details[1].get_output(), gradient_input);
details[0].back_propagate_error(details[1].get_output(), gradient_input, zero_grads);
for (size_t i = 1; i < details.size(); ++i)
{
if (i+1 < details.size())
details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient());
details[i].back_propagate_error(details[i+1].get_output(), details[i-1].get_final_data_gradient(), zero_grads);
else
details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient());
details[i].back_propagate_error(subnetwork.get_output(), details[i-1].get_final_data_gradient(), zero_grads);
}
}
else
{
details[0].back_propagate_error(subnetwork.get_output(), gradient_input);
details[0].back_propagate_error(subnetwork.get_output(), gradient_input, zero_grads);
}
subnetwork.back_propagate_error(x, details.back().get_final_data_gradient());
subnetwork.back_propagate_error(x, details.back().get_final_data_gradient(), zero_grads);
}
template <typename solver_type>
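The loop above walks the repeated layer copies from outermost to innermost: each copy back-propagates using the next copy's forward output as its input and the previous copy's final data gradient as its incoming gradient, with the innermost copy falling back to the subnetwork's output, and the zero_grads flag forwarded to every call. A standalone sketch of that chaining order using toy scalar layers (y = w * x per copy; illustrative stand-ins, not dlib's types):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // One repeated copy: y = w * x, so the data gradient through a copy is
    // w * (incoming gradient) and the parameter gradient is x * (incoming gradient).
    struct scalar_layer
    {
        float w = 2.0f;
        float final_data_gradient = 0;
        float param_gradient = 0;

        void back_propagate_error(float input, float gradient_in)
        {
            param_gradient = gradient_in * input;   // dL/dw
            final_data_gradient = gradient_in * w;  // dL/dx, handed to the copy below
        }
    };

    int main()
    {
        std::vector<scalar_layer> details(3);  // three repeated copies
        const float subnet_output = 1.0f;      // output of the layers below
        const float gradient_input = 1.0f;     // dL/d(top output)

        // forward pass: copy i consumes copy i+1's output,
        // and the innermost copy consumes subnet_output
        std::vector<float> outputs(details.size() + 1);
        outputs[details.size()] = subnet_output;
        for (std::size_t i = details.size(); i-- > 0;)
            outputs[i] = details[i].w * outputs[i + 1];

        // backward pass, mirroring the loop in the diff above
        details[0].back_propagate_error(outputs[1], gradient_input);
        for (std::size_t i = 1; i < details.size(); ++i)
            details[i].back_propagate_error(outputs[i + 1], details[i - 1].final_data_gradient);

        std::cout << "dL/d(subnet output) = "
                  << details.back().final_data_gradient << '\n';  // 2*2*2 = 8
    }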
@@ -1980,6 +2032,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean()
{
temp_tensor.clear();
@@ -2191,11 +2248,19 @@ namespace dlib
return grad_final;
}
void back_propagate_error(const tensor& /*x*/)
void back_propagate_error(
const tensor& /*x*/,
zero_gradients zero_grads = zero_gradients::yes
)
{
// nothing to do
}
void back_propagate_error(const tensor& /*x*/, const tensor& /*gradient_input*/)
void back_propagate_error(
const tensor& /*x*/,
const tensor& /*gradient_input*/,
zero_gradients zero_grads = zero_gradients::yes
)
{
// nothing to do
}
@@ -2218,6 +2283,11 @@ namespace dlib
const input_layer_type& input_layer() const { return input_layer_; }
input_layer_type& input_layer() { return input_layer_; }
void set_gradient_inputs_to_zero()
{
// nothing to do
}
void clean()
{
grad_final.clear();
@@ -2518,14 +2588,21 @@ namespace dlib
return results;
}
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnet().back_propagate_error(x);
subnet().back_propagate_error(x, zero_grads);
}
void back_propagate_error(const tensor& x, const tensor& gradient_input)
void back_propagate_error(
const tensor& x,
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnet().back_propagate_error(x, gradient_input);
subnet().back_propagate_error(x, gradient_input, zero_grads);
}
const tensor& get_final_data_gradient(
@@ -2604,43 +2681,47 @@ namespace dlib
template <typename label_iterator>
double compute_parameter_gradients (
const tensor& x,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.forward(x);
dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
double l = loss.compute_loss_value_and_gradient(x, lbegin, wsub);
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
return l;
}
template <typename forward_iterator, typename label_iterator>
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
)
{
to_tensor(ibegin,iend,temp_tensor);
return compute_parameter_gradients(temp_tensor, lbegin);
return compute_parameter_gradients(temp_tensor, lbegin, zero_grads);
}
double compute_parameter_gradients (
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.forward(x);
dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
double l = loss.compute_loss_value_and_gradient(x, wsub);
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
return l;
}
template <typename forward_iterator>
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend
forward_iterator iend,
zero_gradients zero_grads = zero_gradients::yes
)
{
to_tensor(ibegin,iend,temp_tensor);
return compute_parameter_gradients(temp_tensor);
return compute_parameter_gradients(temp_tensor, zero_grads);
}
template <typename solver_type>
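Each compute_parameter_gradients overload now threads the flag through to back_propagate_error, so a caller can run several gradient passes before a single solver step. A sketch of that call sequence, assuming dlib is installed; the toy regressor, the constant data, and the learning rate are placeholders, and exactly how kept gradients combine across passes depends on each layer's backward() implementation:

    #include <dlib/dnn.h>

    using namespace dlib;

    using net_type = loss_mean_squared<fc<1, input<matrix<float, 2, 1>>>>;

    int main()
    {
        net_type net;
        std::vector<adam> solvers(net_type::num_computational_layers, adam());

        std::vector<matrix<float, 2, 1>> samples(4);
        std::vector<float> labels(4, 0.5f);
        for (auto& s : samples)
            s = 1, 2;

        resizable_tensor x;
        net.to_tensor(samples.begin(), samples.end(), x);

        // Two gradient passes; neither marks the gradient inputs for clearing.
        net.compute_parameter_gradients(x, labels.begin(), zero_gradients::no);
        net.compute_parameter_gradients(x, labels.begin(), zero_gradients::no);

        // One parameter update, then restore the all-zero gradient inputs,
        // since no pass used the defaulted zero_gradients::yes.
        net.update_parameters(make_sstack(solvers), 0.01);
        net.set_gradient_inputs_to_zero();
    }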
@@ -2667,6 +2748,12 @@ namespace dlib
const loss_details_type& loss_details() const { return loss; }
loss_details_type& loss_details() { return loss; }
void set_gradient_inputs_to_zero (
)
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean (
)
{
@@ -3022,9 +3109,12 @@ namespace dlib
return subnetwork.get_final_data_gradient();
}
void back_propagate_error(const tensor& x)
void back_propagate_error(
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
)
{
subnetwork.back_propagate_error(x);
subnetwork.back_propagate_error(x, zero_grads);
}
template <typename solver_type>
@@ -3061,6 +3151,11 @@ namespace dlib
unsigned int sample_expansion_factor() const { return subnet().sample_expansion_factor(); }
void set_gradient_inputs_to_zero()
{
subnetwork.set_gradient_inputs_to_zero();
}
void clean()
{
subnetwork.clean();
@@ -275,6 +275,14 @@ namespace dlib
- returns a sstack that sits on top of the given std::vector.
!*/
// ----------------------------------------------------------------------------------------
enum class zero_gradients : uint8_t
{
no,
yes
};
// ----------------------------------------------------------------------------------------
template <
@@ -603,7 +611,8 @@ namespace dlib
!*/
void back_propagate_error(
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -617,7 +626,7 @@ namespace dlib
network and computes parameter and data gradients, via backpropagation.
Specifically, this function populates get_final_data_gradient() and also,
for each layer, the tensor returned by get_parameter_gradient().
- All elements of #get_gradient_input() are set to 0.
- All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
- #get_final_data_gradient() contains the gradient of the network with
@@ -626,7 +635,8 @@ namespace dlib
void back_propagate_error(
const tensor& x,
const tensor& gradient_input
const tensor& gradient_input,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -643,7 +653,7 @@ namespace dlib
back_propagate_error(x);
Except that calling back_propagate_error(x,gradient_input) avoids the
copy and is therefore slightly more efficient.
- All elements of #get_gradient_input() are set to 0.
- All elements of #get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- have_same_dimensions(#get_parameter_gradient(), layer_details().get_layer_params()) == true.
- #get_final_data_gradient() contains the gradient of the network with
@@ -681,6 +691,20 @@ namespace dlib
Convenience method for calling update_parameters()
!*/
void set_gradient_inputs_to_zero(
);
/*!
ensures
- Sets all elements in all gradient inputs in the network to 0.
That is, for each layer, we will have:
- get_gradient_input() == 0
- Note that you only need to call this method if you manually called either
- back_propagate_error
- compute_parameter_gradients
with the zero_grads parameter set to zero_gradients::no.
- invokes subnet().set_gradient_inputs_to_zero()
!*/
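Per the note above, there are two ways to end a run of passes that kept their gradient inputs: finish with a pass that uses the defaulted zero_gradients::yes, or call set_gradient_inputs_to_zero explicitly. A sketch of both, assuming dlib is installed; the one-layer network and the seeded gradient are placeholders:

    #include <dlib/dnn.h>

    using namespace dlib;

    int main()
    {
        fc<1, input<matrix<float, 3, 1>>> net;
        matrix<float, 3, 1> sample;
        sample = 1, 2, 3;

        resizable_tensor x;
        net.to_tensor(&sample, &sample + 1, x);

        // Option 1: let the final backward pass clear the gradient inputs.
        net.forward(x);
        net.get_gradient_input() = 1;
        net.back_propagate_error(x, zero_gradients::no);  // kept...
        net.back_propagate_error(x);                      // ...cleared by the default

        // Option 2: keep them on every pass and reset explicitly.
        net.forward(x);
        net.get_gradient_input() = 1;
        net.back_propagate_error(x, zero_gradients::no);
        net.set_gradient_inputs_to_zero();
    }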
void clean(
);
/*!
@@ -1147,7 +1171,8 @@ namespace dlib
template <typename label_iterator>
double compute_parameter_gradients (
const tensor& x,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1164,6 +1189,7 @@ namespace dlib
respect to the loss, via backpropagation. Specifically, this function
updates get_final_data_gradient() and also, for each layer, the tensor
returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- for all valid k:
- the expected label of the kth sample in x is *(lbegin+k/sample_expansion_factor()).
- returns compute_loss(x,lbegin)
@@ -1173,7 +1199,8 @@ namespace dlib
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend,
label_iterator lbegin
label_iterator lbegin,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1187,13 +1214,15 @@ namespace dlib
gradients with respect to the loss, via backpropagation. Specifically,
this function updates get_final_data_gradient() and also, for each layer,
the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- for all valid k:
- the expected label of *(ibegin+k) is *(lbegin+k).
- returns compute_loss(ibegin,iend,lbegin)
!*/
double compute_parameter_gradients (
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1208,13 +1237,15 @@ namespace dlib
respect to the loss, via backpropagation. Specifically, this function
updates get_final_data_gradient() and also, for each layer, the tensor
returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- returns compute_loss(x)
!*/
template <typename forward_iterator>
double compute_parameter_gradients (
forward_iterator ibegin,
forward_iterator iend
forward_iterator iend,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1226,6 +1257,7 @@ namespace dlib
gradients with respect to the loss, via backpropagation. Specifically,
this function updates get_final_data_gradient() and also, for each layer,
the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- returns compute_loss(ibegin,iend)
!*/
@@ -1262,6 +1294,7 @@ namespace dlib
void back_propagate_error(
const tensor& x
const tensor& x,
zero_gradients zero_grads = zero_gradients::yes
);
/*!
requires
@@ -1276,7 +1309,7 @@ namespace dlib
network and computes parameter and data gradients, via backpropagation.
Specifically, this function populates get_final_data_gradient() and also,
for each layer, the tensor returned by get_parameter_gradient().
- All elements of #subnet().get_gradient_input() are set to 0.
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- #get_final_data_gradient() contains the gradient of the network with
respect to x.
@@ -1301,7 +1334,7 @@ namespace dlib
back_propagate_error(x);
Except that calling back_propagate_error(x,gradient_input) avoids the
copy and is therefore slightly more efficient.
- All elements of #subnet.get_gradient_input() are set to 0.
- All elements of #subnet().get_gradient_input() are set to 0 if zero_grads == zero_gradients::yes.
- have_same_dimensions(#get_final_data_gradient(), x) == true.
- #get_final_data_gradient() contains the gradient of the network with
respect to x.
@@ -1319,6 +1352,13 @@ namespace dlib
not one per layer, since there is only one input to the entire network.
!*/
void set_gradient_inputs_to_zero(
);
/*!
ensures
- Sets all elements in all gradient inputs in the network to 0.
- invokes subnet().set_gradient_inputs_to_zero()
!*/
// -------------