Unverified commit ff6b8bb0, authored by Thor Johnsen, committed by GitHub

Merge pull request #383 from NVIDIA/lamb_add_fp16_support_update_term

Add support for fp16 update term (new UPD_T typename in template)
parents 18f2eaee 3aeea0d8
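
The new `UPD_T` template parameter means the host-side launchers below must dispatch on a third runtime dtype (that of the `update` tensors), hence the extra nested `DISPATCH_FLOAT_AND_HALF` level in the diff. As a reading aid, here is a rough approximation of how that macro family is typically defined in apex's `csrc/type_shim.h` (a sketch, not the verbatim macro): each level switches on an `at::ScalarType` and binds `scalar_t_<LEVEL>` for its body, so nesting the macro instantiates the functor for every float/half combination of the dispatched dtypes.

```cpp
// Approximate sketch of apex's DISPATCH_FLOAT_AND_HALF (see csrc/type_shim.h
// for the real definition). Each invocation switches on a runtime dtype and
// defines an alias scalar_t_<LEVEL> visible to the nested body.
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                \
  switch (TYPE) {                                                      \
    case at::ScalarType::Float: {                                      \
      using scalar_t_##LEVEL = float;                                  \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Half: {                                       \
      using scalar_t_##LEVEL = at::Half;                               \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
  }
```

With three nested levels in stage 1, every (grad, param, update) float/half combination gets its own kernel instantiation.
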
@@ -14,7 +14,7 @@
 #define ILP 4
 // Step 1 computes the 'update' value of regular Adam optimizer.
-template<typename GRAD_T, typename T>
+template<typename GRAD_T, typename T, typename UPD_T>
 struct LAMBStage1Functor
 {
   __device__ __forceinline__ void operator()(
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
     T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
-    T* update = (T*)tl.addresses[4][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -100,7 +100,7 @@ struct LAMBStage1Functor
       int i = i_start + threadIdx.x + ii*blockDim.x;
       if(i < n && i < chunk_size)
       {
-        update[i] = r_p[ii];
+        update[i] = (UPD_T)r_p[ii];
         m[i] = r_m[ii];
         v[i] = r_v[ii];
       }
@@ -129,19 +129,20 @@ void multi_tensor_lamb_stage1_cuda(
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
     DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      multi_tensor_apply<5>(
-        BLOCK_SIZE,
-        chunk_size,
-        noop_flag,
-        tensor_lists,
-        LAMBStage1Functor<scalar_t_0, scalar_t_1>(),
-        per_tensor_decay.data<float>(),
-        beta1,
-        beta2,
-        beta1_correction,
-        beta2_correction,
-        epsilon,
-        clipped_global_grad_norm); ))
+      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+        multi_tensor_apply<5>(
+          BLOCK_SIZE,
+          chunk_size,
+          noop_flag,
+          tensor_lists,
+          LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
+          per_tensor_decay.data<float>(),
+          beta1,
+          beta2,
+          beta1_correction,
+          beta2_correction,
+          epsilon,
+          clipped_global_grad_norm); )))
   AT_CUDA_CHECK(cudaGetLastError());
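
The kernel plumbing (chunking, ILP register tiles) can obscure the arithmetic, so here is an illustrative per-element sketch of what stage 1 computes, assuming the standard bias-corrected Adam update with global gradient clipping and weight decay that the comment above describes. Names and the exact formula are reconstructions, not the kernel's code:

```cpp
// Illustrative per-element sketch of stage 1 (not the actual kernel body).
// GRAD_T: gradient dtype; T: param/m/v dtype; UPD_T: dtype of the 'update'
// buffer -- the new template parameter that lets the final store be fp16.
template <typename GRAD_T, typename T, typename UPD_T>
__device__ void lamb_stage1_element(
    GRAD_T g, T p, T& m, T& v, UPD_T& update,
    float beta1, float beta2,
    float beta1_correction, float beta2_correction,
    float epsilon, float decay, float clipped_global_grad_norm)
{
  // Math is done in fp32 regardless of the storage dtypes.
  float scaled_g = (float)g / clipped_global_grad_norm;  // global grad clipping
  float m_f = beta1 * (float)m + (1.f - beta1) * scaled_g;            // 1st moment
  float v_f = beta2 * (float)v + (1.f - beta2) * scaled_g * scaled_g; // 2nd moment
  float m_hat = m_f / beta1_correction;  // bias-corrected moments
  float v_hat = v_f / beta2_correction;
  float upd = m_hat / (sqrtf(v_hat) + epsilon) + decay * (float)p;
  m = (T)m_f;
  v = (T)v_f;
  update = (UPD_T)upd;  // the cast this PR adds: the store may narrow to fp16
}
```
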
@@ -15,7 +15,7 @@
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template<typename T>
+template<typename T, typename UPD_T>
 struct LAMBStage2Functor
 {
   __device__ __forceinline__ void operator()(
@@ -42,7 +42,7 @@ struct LAMBStage2Functor
     T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* update = (T*)tl.addresses[1][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -52,7 +52,7 @@ struct LAMBStage2Functor
         i_start += blockDim.x*ILP)
     {
       T r_p[ILP];
-      T r_update[ILP];
+      UPD_T r_update[ILP];
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
@@ -66,7 +66,7 @@ struct LAMBStage2Functor
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
-        r_p[ii] = r_p[ii] - (ratio*r_update[ii]);
+        r_p[ii] = r_p[ii] - (ratio*(T)r_update[ii]);
       }
 #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
@@ -92,15 +92,16 @@ void multi_tensor_lamb_stage2_cuda(
   using namespace at;
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
     multi_tensor_apply<2>(
       BLOCK_SIZE,
       chunk_size,
       noop_flag,
       tensor_lists,
-      LAMBStage2Functor<scalar_t_0>(),
+      LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
       per_tensor_param_norm.data<float>(),
       per_tensor_update_norm.data<float>(),
-      learning_rate); )
+      learning_rate); ))
   AT_CUDA_CHECK(cudaGetLastError());
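
For symmetry, an equally rough per-element sketch of stage 2. The trust-ratio guard for zero norms follows the usual LAMB formulation and is an assumption here; what is directly visible in the diff is that the possibly-fp16 update value is widened back to `T` before it touches the parameter:

```cpp
// Illustrative per-element sketch of stage 2 (not the actual kernel body).
// param_norm / update_norm are the per-tensor values read from
// per_tensor_param_norm and per_tensor_update_norm.
template <typename T, typename UPD_T>
__device__ void lamb_stage2_element(
    T& p, UPD_T update,
    float param_norm, float update_norm, float learning_rate)
{
  // LAMB trust ratio: scale lr by ||p|| / ||update||, falling back to plain
  // lr when either norm is zero (assumed guard, per the LAMB formulation).
  float ratio = (param_norm != 0.f && update_norm != 0.f)
                    ? learning_rate * (param_norm / update_norm)
                    : learning_rate;
  // Mirrors the diff's r_p[ii] - (ratio*(T)r_update[ii]): widen first,
  // then subtract, so fp16 update terms lose no further precision here.
  p = (T)((float)p - ratio * (float)update);
}
```
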