Unverified Commit 3da3e811 authored by Adrià Arrufat's avatar Adrià Arrufat Committed by GitHub
Browse files

Fix Layer Normalize (#2489)

* Fix Layer Normalize

* remove unneeded temporary variables
parent aaac87a2
...@@ -1273,8 +1273,9 @@ namespace dlib ...@@ -1273,8 +1273,9 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc(); const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT( DLIB_CASSERT(
have_same_dimensions(gamma, beta) && have_same_dimensions(gamma, beta) &&
src.num_samples() == gamma.size() && src.k() == gamma.k() &&
src.num_samples() == beta.size() && src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
eps > 0, eps > 0,
"\ngamma.k(): " << gamma.k() << "\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() << "\ngamma.nr(): " << gamma.nr() <<
...@@ -1282,9 +1283,9 @@ namespace dlib ...@@ -1282,9 +1283,9 @@ namespace dlib
"\nbeta.k(): " << beta.k() << "\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() << "\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() << "\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() << "\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() << "\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() << "\nsrc.nc(): " << src.nc() <<
"\neps: " << eps "\neps: " << eps
); );
...@@ -1329,7 +1330,7 @@ namespace dlib ...@@ -1329,7 +1330,7 @@ namespace dlib
for (long i = 0; i < num; ++i) for (long i = 0; i < num; ++i)
{ {
*p_dest = (*p_src - p_means[n])*p_invstds[n]; *p_dest = (*p_src - p_means[n])*p_invstds[n];
*p_dest = (*p_dest)*p_gamma[n] + p_beta[n]; *p_dest = (*p_dest)*p_gamma[i] + p_beta[i];
++p_src; ++p_src;
++p_dest; ++p_dest;
} }
...@@ -1351,11 +1352,12 @@ namespace dlib ...@@ -1351,11 +1352,12 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc(); const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size()); DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size()); DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.num_samples() == gamma.size()); DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.num_samples() == gamma_grad.size()); DLIB_CASSERT(src.nr() == gamma_grad.nr());
DLIB_CASSERT(src.num_samples() == beta_grad.size()); DLIB_CASSERT(src.nc() == beta_grad.nc());
DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0); DLIB_CASSERT(eps > 0);
beta_grad = 0; beta_grad = 0;
...@@ -1381,12 +1383,12 @@ namespace dlib ...@@ -1381,12 +1383,12 @@ namespace dlib
for (long i = 0; i < num; ++i) for (long i = 0; i < num; ++i)
{ {
const float x_hat = (*p_src - p_means[n])*p_invstds[n]; const float x_hat = (*p_src - p_means[n])*p_invstds[n];
p_beta_grad[n] += *p_grad; p_beta_grad[i] += *p_grad;
p_gamma_grad[n] += (*p_grad)*x_hat; p_gamma_grad[i] += (*p_grad)*x_hat;
const float dx = *p_grad * p_gamma[n]; const float dx = *p_grad * p_gamma[n];
p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*std::pow(p_invstds[n], 3.0f); p_dvars[n] += dx*(*p_src - p_means[n])*-0.5*p_invstds[n]*p_invstds[n]*p_invstds[n];
++p_grad; ++p_grad;
++p_src; ++p_src;
...@@ -1400,7 +1402,7 @@ namespace dlib ...@@ -1400,7 +1402,7 @@ namespace dlib
{ {
for (long i = 0; i < num; ++i) for (long i = 0; i < num; ++i)
{ {
const float dx = *p_grad * p_gamma[n]; const float dx = *p_grad * p_gamma[i];
p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum; p_dmeans[n] += dx*-p_invstds[n] + p_dvars[n] * -2*(*p_src - p_means[n])*invnum;
...@@ -1415,7 +1417,7 @@ namespace dlib ...@@ -1415,7 +1417,7 @@ namespace dlib
{ {
for (long i = 0; i < num; ++i) for (long i = 0; i < num; ++i)
{ {
const float dx = *p_grad * p_gamma[n]; const float dx = *p_grad * p_gamma[i];
*p_src_grad += dx*p_invstds[n] + *p_src_grad += dx*p_invstds[n] +
p_dvars[n] *2*(*p_src - p_means[n])*invnum + p_dvars[n] *2*(*p_src - p_means[n])*invnum +
......
...@@ -1908,7 +1908,7 @@ namespace dlib ...@@ -1908,7 +1908,7 @@ namespace dlib
for (auto i : grid_stride_range(0, num)) for (auto i : grid_stride_range(0, num))
{ {
const float val = (s[n*num+i]-m[n])*v[n]; const float val = (s[n*num+i]-m[n])*v[n];
out[n*num+i] = val*g[n]+b[n]; out[n*num+i] = val*g[i]+b[i];
} }
} }
} }
...@@ -1917,21 +1917,17 @@ namespace dlib ...@@ -1917,21 +1917,17 @@ namespace dlib
{ {
for (auto n : grid_stride_range_y(0, ns)) for (auto n : grid_stride_range_y(0, ns))
{ {
float temp_bg = 0;
float temp_gg = 0;
float temp_dv = 0; float temp_dv = 0;
for (auto i : grid_stride_range(0, num)) for (auto i : grid_stride_range(0, num))
{ {
auto idx = n*num+i; auto idx = n*num+i;
const float x_hat = (s[idx] - m[n])*v[n]; const float x_hat = (s[idx] - m[n])*v[n];
temp_bg += gi[idx]; bg[i] += gi[idx];
temp_gg += gi[idx]*x_hat; gg[i] += gi[idx]*x_hat;
const float dx = gi[idx] * g[n]; const float dx = gi[idx] * g[n];
temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n]; temp_dv += dx*(s[idx] - m[n])*-0.5*v[n]*v[n]*v[n];
} }
warp_reduce_atomic_add(bg[n], temp_bg);
warp_reduce_atomic_add(gg[n], temp_gg);
warp_reduce_atomic_add(dv[n], temp_dv); warp_reduce_atomic_add(dv[n], temp_dv);
} }
__syncthreads(); __syncthreads();
...@@ -1942,7 +1938,7 @@ namespace dlib ...@@ -1942,7 +1938,7 @@ namespace dlib
for (auto i : grid_stride_range(0, num)) for (auto i : grid_stride_range(0, num))
{ {
auto idx = n*num+i; auto idx = n*num+i;
const float dx = gi[idx]*g[n]; const float dx = gi[idx]*g[i];
temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num; temp_dm += dx*-v[n] + dv[n] * -2*(s[idx] - m[n])/num;
} }
warp_reduce_atomic_add(dm[n], temp_dm); warp_reduce_atomic_add(dm[n], temp_dm);
...@@ -1954,7 +1950,7 @@ namespace dlib ...@@ -1954,7 +1950,7 @@ namespace dlib
for (auto i : grid_stride_range(0, num)) for (auto i : grid_stride_range(0, num))
{ {
auto idx = n*num+i; auto idx = n*num+i;
const float dx = gi[idx]*g[n]; const float dx = gi[idx]*g[i];
out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num; out[idx] += dx*v[n] + dv[n] * 2*(s[idx] - m[n])/num + dm[n]/num;
} }
} }
...@@ -1973,8 +1969,9 @@ namespace dlib ...@@ -1973,8 +1969,9 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc(); const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT( DLIB_CASSERT(
have_same_dimensions(gamma, beta) && have_same_dimensions(gamma, beta) &&
src.num_samples() == gamma.size() && src.k() == gamma.k() &&
src.num_samples() == beta.size() && src.nr() == gamma.nr() &&
src.nc() == gamma.nc() &&
eps > 0, eps > 0,
"\ngamma.k(): " << gamma.k() << "\ngamma.k(): " << gamma.k() <<
"\ngamma.nr(): " << gamma.nr() << "\ngamma.nr(): " << gamma.nr() <<
...@@ -1982,9 +1979,9 @@ namespace dlib ...@@ -1982,9 +1979,9 @@ namespace dlib
"\nbeta.k(): " << beta.k() << "\nbeta.k(): " << beta.k() <<
"\nbeta.nr(): " << beta.nr() << "\nbeta.nr(): " << beta.nr() <<
"\nbeta.nc(): " << beta.nc() << "\nbeta.nc(): " << beta.nc() <<
"\nsrc.k(): " << src.k() << "\nsrc.k(): " << src.k() <<
"\nsrc.nr(): " << src.nr() << "\nsrc.nr(): " << src.nr() <<
"\nsrc.nc(): " << src.nc() << "\nsrc.nc(): " << src.nc() <<
"\neps: " << eps "\neps: " << eps
); );
...@@ -2012,11 +2009,13 @@ namespace dlib ...@@ -2012,11 +2009,13 @@ namespace dlib
const long num = src.k() * src.nr() * src.nc(); const long num = src.k() * src.nr() * src.nc();
DLIB_CASSERT(src.num_samples() == means.size()); DLIB_CASSERT(src.num_samples() == means.size());
DLIB_CASSERT(src.num_samples() == invstds.size()); DLIB_CASSERT(src.num_samples() == invstds.size());
DLIB_CASSERT(src.num_samples() == gamma.size()); DLIB_CASSERT(src.k() == gamma.k());
DLIB_CASSERT(src.num_samples() == gamma_grad.size()); DLIB_CASSERT(src.nr() == gamma.nr());
DLIB_CASSERT(src.num_samples() == beta_grad.size()); DLIB_CASSERT(src.nc() == gamma.nc());
DLIB_CASSERT(have_same_dimensions(gradient_input, src)); DLIB_CASSERT(have_same_dimensions(gradient_input, src));
DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad)); DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
DLIB_CASSERT(have_same_dimensions(gamma_grad, gamma));
DLIB_CASSERT(have_same_dimensions(gamma_grad, beta_grad));
DLIB_CASSERT(eps > 0); DLIB_CASSERT(eps > 0);
beta_grad = 0; beta_grad = 0;
......
...@@ -1371,7 +1371,7 @@ namespace dlib ...@@ -1371,7 +1371,7 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& sub) void setup (const SUBNET& sub)
{ {
gamma = alias_tensor(sub.get_output().num_samples()); gamma = alias_tensor(1, sub.get_output().k(), sub.get_output().nr(), sub.get_output().nc());
beta = gamma; beta = gamma;
params.set_size(gamma.size()+beta.size()); params.set_size(gamma.size()+beta.size());
......
...@@ -556,7 +556,7 @@ namespace ...@@ -556,7 +556,7 @@ namespace
tt::tensor_rand rnd(0); tt::tensor_rand rnd(0);
rnd.fill_uniform(x); rnd.fill_uniform(x);
resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples()); resizable_tensor means_cpu(x.num_samples()), invstds_cpu(x.num_samples());
resizable_tensor gamma(x.num_samples()), beta(x.num_samples()); resizable_tensor gamma(1, x.k(), x.nr(), x.nc()), beta(1, x.k(), x.nr(), x.nc());
gamma = 1; gamma = 1;
beta = 0; beta = 0;
const float eps = 1e-5; const float eps = 1e-5;
...@@ -588,8 +588,8 @@ namespace ...@@ -588,8 +588,8 @@ namespace
DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5); DLIB_TEST(max(abs(mat(means_cpu) - mat(means_cuda))) < 1e-5);
DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5); DLIB_TEST(max(abs(mat(invstds_cpu) - mat(invstds_cuda))) < 1e-5);
resizable_tensor gradient_input(x); resizable_tensor gradient_input(x);
resizable_tensor src_grad_cpu(x), gamma_grad_cpu(x.num_samples()), beta_grad_cpu(x.num_samples()); resizable_tensor src_grad_cpu(x), gamma_grad_cpu(1, x.k(), x.nr(), x.nc()), beta_grad_cpu(1, x.k(), x.nr(), x.nc());
resizable_tensor src_grad_cuda(x), gamma_grad_cuda(x.num_samples()), beta_grad_cuda(x.num_samples()); resizable_tensor src_grad_cuda(x), gamma_grad_cuda(1, x.k(), x.nr(), x.nc()), beta_grad_cuda(1, x.k(), x.nr(), x.nc());
rnd.fill_gaussian(gradient_input); rnd.fill_gaussian(gradient_input);
src_grad_cpu = 0; src_grad_cpu = 0;
src_grad_cuda = 0; src_grad_cuda = 0;
......
Markdown is supported
0% — Attach a file by dragging & dropping or selecting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment