Commit 121a2500 authored by Thor Johnsen, committed by mcarilli
Browse files

Separate LDG/STG from compute loop (#359)

parent 4a9c2a53
......@@ -62,19 +62,47 @@ struct LAMBStage1Functor
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
GRAD_T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
T scaled_grad = g[i] / clipped_global_grad_norm;
m[i] = m[i] * beta1 + (1-beta1) * scaled_grad;
v[i] = v[i] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
T next_m_unbiased = m[i] / beta1_correction;
T next_v_unbiased = v[i] / beta2_correction;
T denom = std::sqrt(next_v_unbiased) + epsilon;
update[i] = (next_m_unbiased/denom) + (decay*p[i]);
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = GRAD_T(0);
r_p[ii] = T(0);
r_m[ii] = T(0);
r_v[ii] = T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
T next_m_unbiased = r_m[ii] / beta1_correction;
T next_v_unbiased = r_v[ii] / beta2_correction;
T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
update[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
......
......@@ -51,14 +51,30 @@ struct LAMBStage2Functor
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
// see note in multi_tensor_scale_kernel.cu
T r_p[ILP];
T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = p[i] - (ratio*update[i]);
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio*r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
}
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment