Commit a9f5f711 authored by Michael Carilli's avatar Michael Carilli
Browse files

Merge branch 'master' of https://github.com/NVIDIA/apex

parents 41c98511 121a2500
...@@ -62,19 +62,47 @@ struct LAMBStage1Functor ...@@ -62,19 +62,47 @@ struct LAMBStage1Functor
i_start < n && i_start < chunk_size; i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) i_start += blockDim.x*ILP)
{ {
GRAD_T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
T r_v[ILP];
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
{ {
T scaled_grad = g[i] / clipped_global_grad_norm; r_g[ii] = g[i];
m[i] = m[i] * beta1 + (1-beta1) * scaled_grad; r_p[ii] = p[i];
v[i] = v[i] * beta2 + (1-beta2) * scaled_grad * scaled_grad; r_m[ii] = m[i];
T next_m_unbiased = m[i] / beta1_correction; r_v[ii] = v[i];
T next_v_unbiased = v[i] / beta2_correction; } else {
T denom = std::sqrt(next_v_unbiased) + epsilon; r_g[ii] = GRAD_T(0);
update[i] = (next_m_unbiased/denom) + (decay*p[i]); r_p[ii] = T(0);
r_m[ii] = T(0);
r_v[ii] = T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
T next_m_unbiased = r_m[ii] / beta1_correction;
T next_v_unbiased = r_v[ii] / beta2_correction;
T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
update[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
} }
} }
} }
......
...@@ -51,14 +51,30 @@ struct LAMBStage2Functor ...@@ -51,14 +51,30 @@ struct LAMBStage2Functor
i_start < n && i_start < chunk_size; i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) i_start += blockDim.x*ILP)
{ {
// see note in multi_tensor_scale_kernel.cu T r_p[ILP];
T r_update[ILP];
#pragma unroll #pragma unroll
for(int ii = 0; ii < ILP; ii++) for(int ii = 0; ii < ILP; ii++)
{ {
int i = i_start + threadIdx.x + ii*blockDim.x; int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size) if(i < n && i < chunk_size)
{ {
p[i] = p[i] - (ratio*update[i]); r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio*r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment