Unverified Commit 3c82c225 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Add Layer Normalization (#2213)

* wip: layer normalization on cpu

* wip: add cuda implementation, not working yet

* wip: try to fix cuda implementation

* swap grid_stride_range and grid_stride_range_y: does not work yet

* fix CUDA implementation

* implement cuda gradient

* add documentation, move layer_norm, update bn_visitor

* add tests

* use stddev instead of variance in test (they are both 1, anyway)

* add test for means and invstds on CPU and CUDA

* rename visitor to disable_duplicative_bias

* handle more cases in the visitor_disable_input_bias

* Add tests for visitor_disable_input_bias
parent 50748503
......@@ -1258,6 +1258,177 @@ namespace dlib
}
}
// -----------------------------------------------------------------------------------
// CPU implementation of layer normalization.  Each input sample (a contiguous
// run of k*nr*nc floats) is standardized to zero mean and unit variance and
// then affinely transformed:  dest = (src - mean)*invstd*gamma + beta, using
// one gamma/beta value per sample.  The per-sample means and reciprocal
// standard deviations are written to means/invstds so the backward pass
// (layer_normalize_gradient) can reuse them.
void layer_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const tensor& src,
const tensor& gamma,
const tensor& beta
)
{
// number of elements that are normalized together within one sample
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.num_samples() == gamma.size() &&
src.num_samples() == beta.size() &&
eps > 0,
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
dest.copy_size(src);
means.set_size(src.num_samples());
invstds.set_size(src.num_samples());
// first compute means and invstds
means = 0;
invstds = 0;
const auto p_invstds = invstds.host();
const auto p_means = means.host();
auto p_src = src.host();
// compute means, and sum of squares
// (invstds temporarily holds E[x^2]; it is converted to 1/std below)
for (long n = 0; n < src.num_samples(); ++n)
{
for (long i = 0; i < num; ++i)
{
float val = p_src[n*num+i];
p_means[n] += val;
p_invstds[n] += val*val;
}
}
means /= num;
invstds /= num;
// copy data back to host
// (the /= above go through tensor operators which may run on a device in
// some builds; re-calling host() makes sure the host buffers are current --
// presumably defensive, the p_* pointers still refer to the same storage)
invstds.host(); means.host();
// compute variances, then convert each to the reciprocal standard
// deviation 1/sqrt(var + eps) using Var[x] = E[x^2] - E[x]^2
for (long n = 0; n < src.num_samples(); ++n)
{
auto var = p_invstds[n] - p_means[n] * p_means[n];
p_invstds[n] = 1.0f / std::sqrt(var + eps);
}
p_src = src.host();
auto p_dest = dest.host();
auto p_gamma = gamma.host();
auto p_beta = beta.host();
// normalize and apply the per-sample scale (gamma) and shift (beta)
for (long n = 0; n < src.num_samples(); ++n)
{
for (long i = 0; i < num; ++i)
{
*p_dest = (*p_src - p_means[n])*p_invstds[n];
*p_dest = (*p_dest)*p_gamma[n] + p_beta[n];
++p_src;
++p_dest;
}
}
}
// CPU backward pass for layer_normalize().  Given the upstream gradient
// (gradient_input) and the means/invstds saved by the forward pass, this
// computes the gradients with respect to src, gamma, and beta.  The math
// follows the standard normalization backprop: first accumulate the
// gradients w.r.t. the per-sample variance (dvars), then w.r.t. the
// per-sample mean (dmeans), then combine everything into src_grad.
// NOTE(review): src_grad is accumulated with += (add-to convention), while
// gamma_grad/beta_grad are overwritten -- confirm callers zero src_grad
// appropriately before calling.
void layer_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
)
{
// number of elements normalized together within one sample
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.num_samples() == gamma.size());
DLIB_CASSERT(src.num_samples() == gamma_grad.size());
DLIB_CASSERT(src.num_samples() == beta_grad.size());
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;
gamma_grad = 0;
auto p_grad = gradient_input.host();
auto p_src = src.host();
const auto p_gamma = gamma.host();
const auto p_gamma_grad = gamma_grad.host();
const auto p_beta_grad = beta_grad.host();
const auto p_invstds = invstds.host();
const auto p_means = means.host();
resizable_tensor dvars, dmeans;
dvars.copy_size(invstds);
dmeans.copy_size(means);
dvars = 0;
dmeans = 0;
const auto p_dvars = dvars.host();
const auto p_dmeans = dmeans.host();
// pass 1: accumulate beta_grad, gamma_grad, and the gradient w.r.t. the
// per-sample variance: dvar += dx * (x - mean) * -0.5 * invstd^3
for (long n = 0; n < src.num_samples(); ++n)
{
for (long i = 0; i < num; ++i)
{
const float x_hat = (*p_src - p_means[n])*p_invstds[n];
p_beta_grad[n] += *p_grad;
p_gamma_grad[n] += (*p_grad)*x_hat;
const float dx = *p_grad * p_gamma[n];
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*std::pow(p_invstds[n], 3.0f);
++p_grad;
++p_src;
}
}
const float invnum = 1.0f/num;
p_grad = gradient_input.host();
p_src = src.host();
// pass 2: gradient w.r.t. the per-sample mean:
// dmean += dx * -invstd + dvar * -2*(x - mean)/num
for (long n = 0; n < src.num_samples(); ++n)
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[n];
p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
++p_grad;
++p_src;
}
}
p_grad = gradient_input.host();
p_src = src.host();
auto p_src_grad = src_grad.host();
// pass 3: combine the direct, variance, and mean terms into src_grad
for (long n = 0; n < src.num_samples(); ++n)
{
for (long i = 0; i < num; ++i)
{
const float dx = *p_grad * p_gamma[n];
*p_src_grad += dx*p_invstds[n] +
p_dvars[n] *2*(*p_src - p_means[n])*invnum +
p_dmeans[n]*invnum;
++p_grad;
++p_src;
++p_src_grad;
}
}
}
// -----------------------------------------------------------------------------------
void threshold (
......
......@@ -229,6 +229,30 @@ namespace dlib
tensor& beta_grad
);
// -----------------------------------------------------------------------------------
void layer_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void layer_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// -----------------------------------------------------------------------------------
void threshold (
......
......@@ -1749,6 +1749,169 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
// CUDA kernel for layer normalization.  The grid's y dimension strides over
// samples (ns of them) and the x dimension strides over the num elements
// within one sample.  m and v must be zero-initialized by the caller: they
// are accumulated with warp_reduce_atomic_add and end up holding the
// per-sample mean and the reciprocal standard deviation respectively.
// NOTE(review): __syncthreads() is a block-level barrier, but m[n]/v[n] can
// receive atomic contributions from threads in other blocks along x.  Confirm
// the launch configuration (max_jobs) guarantees all contributing blocks have
// finished before the later phases read m/v, or whether the three phases
// should be separate kernel launches.
__global__ void _cuda_layer_normalize(float* out, const float* s, float* m, float* v, const float* g, const float* b, float eps, size_t ns, size_t num)
{
// compute means and sum of squares
for (auto n : grid_stride_range_y(0, ns))
{
auto p = s + n * num;
float means = 0;
float invstds = 0;
for (auto i : grid_stride_range(0, num))
{
means += p[i];
invstds += p[i] * p[i];
}
// reduce the per-thread partial sums into the per-sample totals
warp_reduce_atomic_add(m[n], means/num);
warp_reduce_atomic_add(v[n], invstds/num);
}
__syncthreads();
// compute variances
for (auto n : grid_stride_range_y(0, ns))
{
// grid_stride_range(0, 1) means only the thread with global x index 0
// runs the body, so v[n] is finalized exactly once per sample:
// v[n] becomes 1/sqrt(E[x^2] - E[x]^2 + eps)
for (auto i : grid_stride_range(0, 1))
{
auto var = v[n] - m[n] * m[n];
v[n] = 1.0f / std::sqrt(var + eps);
}
}
__syncthreads();
// normalize and apply the per-sample scale (g) and shift (b)
for (auto n : grid_stride_range_y(0, ns))
{
for (auto i : grid_stride_range(0, num))
{
const float val = (s[n*num+i]-m[n])*v[n];
out[n*num+i] = val*g[n]+b[n];
}
}
}
// CUDA kernel for the layer normalization backward pass.  Mirrors the CPU
// implementation's three passes: (1) accumulate beta_grad (bg), gamma_grad
// (gg), and the gradient w.r.t. the per-sample variance (dv); (2) accumulate
// the gradient w.r.t. the per-sample mean (dm); (3) combine everything into
// the input gradient out (accumulated with +=).  bg, gg, dm, and dv must be
// zero-initialized by the caller.
// NOTE(review): as in _cuda_layer_normalize, __syncthreads() only synchronizes
// within a block while bg/gg/dv/dm are accumulated grid-wide via atomics --
// confirm the phases are safe under the launch configuration used.
__global__ void _cuda_layer_normalize_gradient(float* out, float* gg, float* bg, const float* s, const float* gi, const float* m, const float* v, const float* g, float* dm, float* dv, float eps, size_t ns, size_t num)
{
// pass 1: per-thread partial sums, then one warp-reduced atomic per warp
for (auto n : grid_stride_range_y(0, ns))
{
float temp_bg = 0;
float temp_gg = 0;
float temp_dv = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float x_hat = (s[idx] - m[n])*v[n];
temp_bg += gi[idx];
temp_gg += gi[idx]*x_hat;
const float dx = gi[idx] * g[n];
// dvar += dx * (x - mean) * -0.5 * invstd^3
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
}
warp_reduce_atomic_add(bg[n], temp_bg);
warp_reduce_atomic_add(gg[n], temp_gg);
warp_reduce_atomic_add(dv[n], temp_dv);
}
__syncthreads();
// pass 2: gradient w.r.t. the per-sample mean
for (auto n : grid_stride_range_y(0, ns))
{
float temp_dm = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[n];
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
// dm[n] += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
}
warp_reduce_atomic_add(dm[n], temp_dm);
}
__syncthreads();
// pass 3: combine the direct, variance, and mean terms into out
for (auto n : grid_stride_range_y(0, ns))
{
float temp = 0;
for (auto i : grid_stride_range(0, num))
{
auto idx = n*num+i;
const float dx = gi[idx]*g[n];
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
// temp += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
}
}
}
// CUDA host wrapper for layer normalization.  Validates inputs, sizes the
// output tensors, zeroes means/invstds (the kernel accumulates into them
// with atomics), and launches _cuda_layer_normalize with a 2D job layout:
// x strides over the num within-sample elements, y over the samples.
void layer_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const tensor& src,
const tensor& gamma,
const tensor& beta
)
{
// number of elements normalized together within one sample
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(
have_same_dimensions(gamma, beta) &&
src.num_samples() == gamma.size() &&
src.num_samples() == beta.size() &&
eps > 0,
"\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() <<
"\ngamma.nc(): " << gamma.nc() <<
"\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() <<
"\neps: " << eps
);
dest.copy_size(src);
means.set_size(src.num_samples());
invstds.set_size(src.num_samples());
// the kernel accumulates into these, so they must start at zero
means = 0;
invstds = 0;
launch_kernel(_cuda_layer_normalize, max_jobs(num, src.num_samples()), dest.device(), src.device(),
means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), num);
}
// CUDA host wrapper for the layer normalization backward pass.  Validates
// inputs, zeroes the accumulated outputs (gamma_grad/beta_grad) and the
// scratch tensors (dvars/dmeans), then launches the gradient kernel.
// src_grad is accumulated into (+=) by the kernel rather than overwritten.
void layer_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
)
{
// number of elements normalized together within one sample
const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.num_samples() == gamma.size());
DLIB_CASSERT(src.num_samples() == gamma_grad.size());
DLIB_CASSERT(src.num_samples() == beta_grad.size());
DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(eps > 0);
beta_grad = 0;
gamma_grad = 0;
// scratch tensors for the per-sample variance and mean gradients; the
// kernel accumulates into them with atomics, so they must start at zero
resizable_tensor dvars, dmeans;
dvars.copy_size(invstds);
dmeans.copy_size(means);
dvars = 0;
dmeans = 0;
launch_kernel(_cuda_layer_normalize_gradient, max_jobs(num, src.num_samples()),
src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(),
gradient_input.device(), means.device(), invstds.device(), gamma.device(),
dmeans.device(), dvars.device(), eps, src.num_samples(), num);
}
// ----------------------------------------------------------------------------------------
__global__ void _cuda_copy_tensor_add_to (float* dest, size_t size, const float* src, size_t dest_stride, size_t src_stride, size_t block_size)
......
......@@ -336,6 +336,30 @@ namespace dlib
const tensor& gradient_input
);
// -----------------------------------------------------------------------------------
void layer_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void layer_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// -----------------------------------------------------------------------------------
void threshold (
......
......@@ -656,6 +656,40 @@ namespace dlib { namespace tt
#endif
}
// ----------------------------------------------------------------------------------------
// Dispatch layer normalization to the CUDA implementation when available,
// otherwise to the CPU implementation.
// NOTE(review): the third output parameter is named `vars` here but holds the
// reciprocal standard deviations (the implementations call it `invstds`) --
// consider renaming for consistency.
void layer_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& vars,
const tensor& src,
const tensor& gamma,
const tensor& beta
)
{
#ifdef DLIB_USE_CUDA
cuda::layer_normalize(eps, dest, means, vars, src, gamma, beta);
#else
cpu::layer_normalize(eps, dest, means, vars, src, gamma, beta);
#endif
}
// Dispatch the layer normalization backward pass to the CUDA implementation
// when dlib is built with CUDA support, otherwise to the CPU implementation.
// Fix: this previously called cpu::layer_normalize_gradient unconditionally,
// so CUDA builds ran the backward pass on the CPU (forcing device->host
// syncs) even though cuda::layer_normalize_gradient exists and the forward
// dispatch (layer_normalize above) already uses it.
void layer_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
)
{
#ifdef DLIB_USE_CUDA
cuda::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
#else
cpu::layer_normalize_gradient(eps, gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
#endif
}
// ----------------------------------------------------------------------------------------
void threshold (
......
......@@ -802,6 +802,30 @@ namespace dlib { namespace tt
// -----------------------------------------------------------------------------------
void layer_normalize (
const double eps,
resizable_tensor& dest,
resizable_tensor& means,
resizable_tensor& invstds,
const tensor& src,
const tensor& gamma,
const tensor& beta
);
void layer_normalize_gradient (
const double eps,
const tensor& gradient_input,
const tensor& means,
const tensor& invstds,
const tensor& src,
const tensor& gamma,
tensor& src_grad,
tensor& gamma_grad,
tensor& beta_grad
);
// -----------------------------------------------------------------------------------
void threshold (
tensor& data,
float thresh
......
......@@ -1302,6 +1302,141 @@ namespace dlib
// ----------------------------------------------------------------------------------------
const double DEFAULT_LAYER_NORM_EPS = 1e-5;
// Layer normalization computational layer (see the Layer Normalization paper
// by Ba, Kiros, and Hinton).  The forward pass standardizes every sample of
// the input tensor to zero mean and unit variance across its k*nr*nc
// elements, then applies a learned scale (gamma) and shift (beta).
// NOTE(review): setup() sizes gamma/beta by sub.get_output().num_samples(),
// i.e. one learned scale/shift value PER SAMPLE IN THE BATCH, so the
// parameter count depends on the batch size used at setup time -- confirm
// this is intended rather than per-channel parameters.
class layer_norm_
{
public:
// eps_ is added to the variance before taking the square root, guarding
// against division by zero.
explicit layer_norm_(
double eps_ = DEFAULT_LAYER_NORM_EPS
) :
learning_rate_multiplier(1),
weight_decay_multiplier(0),
bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(1),
eps(eps_)
{
}
double get_eps() const { return eps; }
// Standard dlib per-layer solver multipliers; "bias" here refers to the
// beta parameters and "weights" to the gamma parameters.
double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
// This layer does not change the tensor geometry, so points map through
// unchanged.
inline dpoint map_input_to_output (const dpoint& p) const { return p; }
inline dpoint map_output_to_input (const dpoint& p) const { return p; }
template <typename SUBNET>
void setup (const SUBNET& sub)
{
// gamma and beta are stored back to back in the single params tensor;
// gamma is initialized to 1 (identity scale), beta to 0 (no shift).
gamma = alias_tensor(sub.get_output().num_samples());
beta = gamma;
params.set_size(gamma.size()+beta.size());
gamma(params,0) = 1;
beta(params,gamma.size()) = 0;
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
auto g = gamma(params,0);
auto b = beta(params,gamma.size());
// means/invstds are cached for the backward pass
tt::layer_normalize(eps, output, means, invstds, sub.get_output(), g, b);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
{
auto g = gamma(params, 0);
auto g_grad = gamma(params_grad, 0);
auto b_grad = beta(params_grad, gamma.size());
tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad);
}
const tensor& get_layer_params() const { return params; };
tensor& get_layer_params() { return params; };
friend void serialize(const layer_norm_& item, std::ostream& out)
{
serialize("layer_norm_", out);
serialize(item.params, out);
serialize(item.gamma, out);
serialize(item.beta, out);
serialize(item.means, out);
serialize(item.invstds, out);
serialize(item.learning_rate_multiplier, out);
serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.eps, out);
}
friend void deserialize(layer_norm_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "layer_norm_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::layer_norm_.");
deserialize(item.params, in);
deserialize(item.gamma, in);
deserialize(item.beta, in);
deserialize(item.means, in);
deserialize(item.invstds, in);
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
deserialize(item.eps, in);
}
friend std::ostream& operator<<(std::ostream& out, const layer_norm_& item)
{
out << "layer_norm";
out << " eps="<<item.eps;
out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
return out;
}
friend void to_xml(const layer_norm_& item, std::ostream& out)
{
out << "layer_norm";
out << " eps='"<<item.eps<<"'";
out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
out << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
out << ">\n";
out << mat(item.params);
out << "</layer_norm>\n";
}
private:
// params holds gamma followed by beta; gamma/beta are alias views into it
resizable_tensor params;
alias_tensor gamma, beta;
// per-sample statistics cached by forward() for use in backward()
resizable_tensor means, invstds;
double learning_rate_multiplier;
double weight_decay_multiplier;
double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier;
double eps;
};
template <typename SUBNET>
using layer_norm = add_layer<layer_norm_, SUBNET>;
// ----------------------------------------------------------------------------------------
enum layer_mode
{
CONV_MODE = 0,
......@@ -1577,24 +1712,61 @@ namespace dlib
unsigned long new_window_size;
};
class visitor_bn_input_no_bias
class visitor_disable_input_bias
{
public:
template <typename T>
void set_input_no_bias(T&) const
void disable_input_bias(T&) const
{
// ignore other layer types
}
// handle the standard case
template <typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
template <layer_mode mode, typename U, typename E>
void set_input_no_bias(add_layer<bn_<mode>, U, E>& l)
void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
{
disable_bias(l.subnet().layer_details());
set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
}
// handle input repeat layer case
template <layer_mode mode, size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<bn_<mode>, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
template <size_t N, template <typename> class R, typename U, typename E>
void disable_input_bias(add_layer<layer_norm_, repeat<N, R, U>, E>& l)
{
disable_bias(l.subnet().get_repeated_layer(0).layer_details());
set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
}
// handle input repeat layer with tag case
template <layer_mode mode, unsigned long ID, typename E, typename F>
void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer, E>, F>& )
{
}
template <unsigned long ID, typename E, typename F>
void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, impl::repeat_input_layer, E>, F>& )
{
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
......@@ -1604,7 +1776,7 @@ namespace dlib
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l)
{
set_input_no_bias(l);
disable_input_bias(l);
}
};
}
......@@ -1619,14 +1791,13 @@ namespace dlib
}
template <typename net_type>
void set_all_bn_inputs_no_bias (
void disable_duplicative_bias (
net_type& net
)
{
visit_layers(net, impl::visitor_bn_input_no_bias());
visit_layers(net, impl::visitor_disable_input_bias());
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
enum fc_bias_mode
......
......@@ -1434,6 +1434,149 @@ namespace dlib
template <typename SUBNET>
using multiply = add_layer<multiply_, SUBNET>;
// ----------------------------------------------------------------------------------------
const double DEFAULT_LAYER_NORM_EPS = 1e-5;
class layer_norm_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above. In particular, it defines a layer normalization layer that
implements the method described in the paper:
Layer Normalization by Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
In particular, this layer produces output tensors with the same
dimensionality as the input tensors, except that the mean and variances of
the elements in each sample have been standardized to 0 and 1 respectively.
This is different from batch normalization, since this layer learns one scaling
factor and one bias for each sample in the batch, independently. As a result,
this layer is batch-size independent.
!*/
public:
layer_norm_(
);
/*!
ensures
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == DEFAULT_LAYER_NORM_EPS
!*/
explicit layer_norm_(
double eps_ = DEFAULT_LAYER_NORM_EPS
)
/*!
requires
- eps > 0
ensures
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == eps
!*/
double get_eps(
) const;
/*!
ensures
- When doing layer normalization, we are dividing by the standard
deviation. This epsilon value returned by this function is added to the
variance to prevent the division from dividing by zero.
!*/
double get_learning_rate_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the learning rate used to optimize its parameters be
multiplied by get_learning_rate_multiplier().
!*/
double get_weight_decay_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the weight decay used to optimize its parameters be
multiplied by get_weight_decay_multiplier().
!*/
void set_learning_rate_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_learning_rate_multiplier() == val
!*/
void set_weight_decay_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_weight_decay_multiplier() == val
!*/
double get_bias_learning_rate_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the learning rate used to optimize its bias parameters be
multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
!*/
double get_bias_weight_decay_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the weight decay used to optimize its bias parameters be
multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
!*/
void set_bias_learning_rate_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_learning_rate_multiplier() == val
!*/
void set_bias_weight_decay_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_weight_decay_multiplier() == val
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
dpoint map_input_to_output(dpoint p) const;
dpoint map_output_to_input(dpoint p) const;
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
!*/
};
// ----------------------------------------------------------------------------------------
enum layer_mode
......@@ -1667,7 +1810,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename net_type>
void set_all_bn_inputs_no_bias (
void disable_duplicative_bias (
const net_type& net
);
/*!
......@@ -1675,9 +1818,9 @@ namespace dlib
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Disables bias for all bn_ layer inputs.
- Disables bias for all bn_ and layer_norm_ inputs.
- Sets the get_bias_learning_rate_multiplier() and get_bias_weight_decay_multiplier()
to zero of all bn_ layer inputs.
to zero of all bn_ and layer_norm_ inputs.
!*/
// ----------------------------------------------------------------------------------------
......
......@@ -470,6 +470,60 @@ namespace
}
// ----------------------------------------------------------------------------------------
// Checks that cpu::layer_normalize produces per-sample outputs with mean 0
// and stddev 1, and (in CUDA builds) that the CUDA forward and backward
// implementations agree with the CPU ones.
// Fix: the CUDA section was guarded with `#if DLIB_USE_CUDA`.  dlib defines
// DLIB_USE_CUDA as an empty macro, so `#if` on it does not compile in CUDA
// builds; every other guard in the codebase (and in this change) uses
// `#ifdef DLIB_USE_CUDA`.
void test_layer_normalize()
{
resizable_tensor x(2, 3, 4, 5);
resizable_tensor y_cpu(x);
tt::tensor_rand rnd(0);
rnd.fill_uniform(x);
resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
resizable_tensor gamma(x.num_samples()), beta(x.num_samples());
gamma = 1;
beta = 0;
const float eps = 1e-5;
cpu::layer_normalize(eps, y_cpu, means_cpu, invstds_cpu, x, gamma, beta);
// check that the mean and var per sample are 0 and 1
const float* p = y_cpu.host();
for (long n = 0; n < y_cpu.num_samples(); ++n)
{
running_stats<float> rs;
for (long k = 0; k < y_cpu.k(); ++k)
{
for (long r = 0; r < y_cpu.nr(); ++r)
{
for (long c = 0; c < y_cpu.nc(); ++c)
{
rs.add(p[tensor_index(y_cpu, n, k, r, c)]);
}
}
}
DLIB_TEST(::std::abs(rs.mean()) < 1e-6);
DLIB_TEST(::std::abs(rs.stddev() - 1.0f) < 0.01);
}
// check that the CPU and the CUDA implementation are equivalent
#ifdef DLIB_USE_CUDA
resizable_tensor y_cuda(x);
resizable_tensor means_cuda(x.num_samples()), invstds_cuda(x.num_samples());
cuda::layer_normalize(eps, y_cuda, means_cuda, invstds_cuda, x, gamma, beta);
DLIB_TEST(max(abs(mat(y_cpu) - mat(y_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
resizable_tensor gradient_input(x);
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(x.num_samples()), beta_grad_cpu(x.num_samples());
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(x.num_samples()), beta_grad_cuda(x.num_samples());
rnd.fill_gaussian(gradient_input);
// src_grad is accumulated into by the gradient functions, so zero it first
src_grad_cpu = 0;
src_grad_cuda = 0;
cpu::layer_normalize_gradient(eps, gradient_input, means_cpu, invstds_cpu, x, gamma, src_grad_cpu, gamma_grad_cpu, beta_grad_cpu);
cuda::layer_normalize_gradient(eps, gradient_input, means_cuda, invstds_cuda, x, gamma, src_grad_cuda, gamma_grad_cuda, beta_grad_cuda);
DLIB_TEST(max(abs(mat(src_grad_cpu) - mat(src_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(gamma_grad_cpu) - mat(gamma_grad_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(beta_grad_cpu) - mat(beta_grad_cuda))) < 1e-5);
#endif
}
// ----------------------------------------------------------------------------------------
void test_basic_tensor_ops()
......@@ -1816,6 +1870,12 @@ namespace
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
layer_norm_ l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
{
print_spinner();
cont_<3,3,3,2,2,0,0> l;
......@@ -3846,6 +3906,46 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
// Helper network building blocks used by test_disable_duplicative_bias below.
// conp: convolution with square kernel ks, stride s, and "same"-style padding.
template <long num_filters, long ks, int s, typename SUBNET>
using conp = add_layer<con_<num_filters, ks, ks, s, s, ks/2, ks/2>, SUBNET>;
// stem: 7x7 stride-2 convolution + batch norm + relu, followed by max pooling.
template <typename INPUT>
using stem = add_layer<max_pool_<3, 3, 2, 2, 1, 1>, relu<bn_con<conp<16, 7, 2, INPUT>>>>;
// dense_layer: DenseNet-style block -- a 1x1 bottleneck then a 3x3 conv,
// concatenated with its own input (tag1 = input, tag2 = new features).
template <long num_filters, long growth_rate, typename SUBNET>
using dense_layer = concat2<tag1, tag2,
tag2<conp<growth_rate, 3, 1,
relu<bn_con<conp<4 * growth_rate, 1, 1,
relu<bn_con<tag1<SUBNET>>>>>>>>>;
template <typename SUBNET> using dense_layer_32 = dense_layer<32, 8, SUBNET>;
// Builds a network mixing layer_norm, bn_fc, bn_con, and repeat blocks, then
// checks that disable_duplicative_bias() disables the bias only on layers
// that feed directly into a normalization layer (whose beta already provides
// a bias), leaving all other layers' biases enabled.
void test_disable_duplicative_bias()
{
using net_type = fc<10, relu<layer_norm<fc<15, relu<bn_fc<fc<20,
relu<layer_norm<conp<32, 3, 1,
repeat<2, dense_layer_32,
stem<input_rgb_image>>>>>>>>>>>>;
net_type net;
// before the visitor runs, every layer still has its bias enabled
DLIB_TEST(layer<0>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<3>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<6>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<9>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<12>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<15>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<21>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<24>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<31>(net).layer_details().bias_is_disabled() == false);
disable_duplicative_bias(net);
// afterwards, only the inputs of layer_norm/bn_* layers are disabled
// (layer indices refer to positions in net_type above)
DLIB_TEST(layer<0>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<3>(net).layer_details().bias_is_disabled() == true);
DLIB_TEST(layer<6>(net).layer_details().bias_is_disabled() == true);
DLIB_TEST(layer<9>(net).layer_details().bias_is_disabled() == true);
DLIB_TEST(layer<12>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<15>(net).layer_details().bias_is_disabled() == true);
DLIB_TEST(layer<21>(net).layer_details().bias_is_disabled() == false);
DLIB_TEST(layer<24>(net).layer_details().bias_is_disabled() == true);
DLIB_TEST(layer<31>(net).layer_details().bias_is_disabled() == true);
}
// ----------------------------------------------------------------------------------------
// This test really just checks if the mmod loss goes negative when a whole lot of overlapping
......@@ -4002,6 +4102,7 @@ namespace
test_gelu();
test_batch_normalize();
test_batch_normalize_conv();
test_layer_normalize();
test_basic_tensor_ops();
test_layers();
test_visit_functions();
......@@ -4029,6 +4130,7 @@ namespace
test_loss_multimulticlass_log();
test_loss_mmod();
test_layers_scale_and_scale_prev();
test_disable_duplicative_bias();
}
void perform_test()
......
......@@ -134,8 +134,8 @@ int main(int argc, char** argv) try
// setup all leaky_relu_ layers in the discriminator to have alpha = 0.2
visit_computational_layers(discriminator, [](leaky_relu_& l){ l = leaky_relu_(0.2); });
// Remove the bias learning from all bn_ inputs in both networks
set_all_bn_inputs_no_bias(generator);
set_all_bn_inputs_no_bias(discriminator);
disable_duplicative_bias(generator);
disable_duplicative_bias(discriminator);
// Forward random noise so that we see the tensor size at each layer
discriminator(generate_image(generator, make_noise(rnd)));
cout << "generator (" << count_parameters(generator) << " parameters)" << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment