Unverified commit ff6b8bb0, authored by Thor Johnsen, committed by GitHub

Merge pull request #383 from NVIDIA/lamb_add_fp16_support_update_term

Add support for fp16 update term (new UPD_T typename in template)
Parents: 18f2eaee 3aeea0d8
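The change threads a new UPD_T template parameter through both LAMB functors so the intermediate 'update' tensor can be stored in fp16 (or fp32) independently of the gradient and parameter types: stage 1 casts its fp32 result to UPD_T when writing the update out, and stage 2 casts it back to the parameter type T before applying it. Below is a minimal, self-contained sketch of that store/load cast pattern; it is not the apex kernels themselves, and the kernel names and the toy update math are made up for illustration.

// Minimal sketch (not the apex kernels): compute an update in fp32, store it
// through a templated UPD_T pointer (here __half), then read it back and cast
// to fp32 before applying it -- the same casts this diff introduces.
#include <cuda_fp16.h>
#include <cstdio>

template<typename T, typename UPD_T>
__global__ void store_update(const T* grad, UPD_T* update, int n)
{
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i < n)
  {
    T u = grad[i] * T(0.001);   // stand-in for the real Adam update math
    update[i] = (UPD_T)u;       // mirrors `update[i] = (UPD_T)r_p[ii];` in stage 1
  }
}

template<typename T, typename UPD_T>
__global__ void apply_update(T* param, const UPD_T* update, T ratio, int n)
{
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i < n)
    param[i] = param[i] - ratio*(T)update[i];   // mirrors the stage 2 cast
}

int main()
{
  const int n = 1024;
  float *grad, *param;
  __half *update;
  cudaMallocManaged(&grad, n*sizeof(float));
  cudaMallocManaged(&param, n*sizeof(float));
  cudaMallocManaged(&update, n*sizeof(__half));
  for(int i = 0; i < n; i++) { grad[i] = 1.0f; param[i] = 1.0f; }

  store_update<float, __half><<<(n + 255)/256, 256>>>(grad, update, n);
  apply_update<float, __half><<<(n + 255)/256, 256>>>(param, update, 1.0f, n);
  cudaDeviceSynchronize();
  printf("param[0] = %f\n", param[0]);   // expect about 0.999, up to fp16 rounding
  return 0;
}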
Stage 1 kernel (LAMBStage1Functor and its launcher):

@@ -14,7 +14,7 @@
 #define ILP 4
 // Step 1 computes the 'update' value of regular Adam optimizer.
-template<typename GRAD_T, typename T>
+template<typename GRAD_T, typename T, typename UPD_T>
 struct LAMBStage1Functor
 {
   __device__ __forceinline__ void operator()(
@@ -52,7 +52,7 @@ struct LAMBStage1Functor
     T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
-    T* update = (T*)tl.addresses[4][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -100,7 +100,7 @@ struct LAMBStage1Functor
         int i = i_start + threadIdx.x + ii*blockDim.x;
         if(i < n && i < chunk_size)
         {
-          update[i] = r_p[ii];
+          update[i] = (UPD_T)r_p[ii];
           m[i] = r_m[ii];
           v[i] = r_v[ii];
         }
@@ -129,19 +129,20 @@ void multi_tensor_lamb_stage1_cuda(
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
     DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
         multi_tensor_apply<5>(
           BLOCK_SIZE,
           chunk_size,
           noop_flag,
           tensor_lists,
-          LAMBStage1Functor<scalar_t_0, scalar_t_1>(),
+          LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
           per_tensor_decay.data<float>(),
           beta1,
           beta2,
           beta1_correction,
           beta2_correction,
           epsilon,
-          clipped_global_grad_norm); ))
+          clipped_global_grad_norm); )))
   AT_CUDA_CHECK(cudaGetLastError());
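The stage-1 launcher above now nests a third DISPATCH_FLOAT_AND_HALF on tensor_lists[4][0] (the update tensors), which selects UPD_T at runtime and is where the extra closing parenthesis at the end of the call comes from. As a rough sketch of what such a dispatch macro does (the real definition elsewhere in the repo may differ in details), assuming the PyTorch at::ScalarType API:

#include <ATen/ATen.h>

// Sketch of a DISPATCH_FLOAT_AND_HALF-style macro (approximation, not the
// repo's definition): switch on the runtime dtype and expose it as a
// compile-time alias scalar_t_<LEVEL>. Nesting one macro per tensor list
// instantiates the functor for every dtype combination.
#define DISPATCH_FLOAT_AND_HALF_SKETCH(TYPE, LEVEL, NAME, ...)        \
  switch (TYPE)                                                       \
  {                                                                   \
    case at::ScalarType::Float:                                       \
    {                                                                 \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half:                                        \
    {                                                                 \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      TORCH_CHECK(false, NAME, ": unsupported dtype");                \
  }

Each nesting level chooses one dtype alias (scalar_t_0, scalar_t_1, scalar_t_2) and contributes one closing parenthesis, so the stage-1 call now ends in ))) and the stage-2 call below ends in )).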
Stage 2 kernel (LAMBStage2Functor and its launcher):

@@ -15,7 +15,7 @@
 // Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
 // It computes new parameter value.
-template<typename T>
+template<typename T, typename UPD_T>
 struct LAMBStage2Functor
 {
   __device__ __forceinline__ void operator()(
@@ -42,7 +42,7 @@ struct LAMBStage2Functor
     T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* update = (T*)tl.addresses[1][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -52,7 +52,7 @@ struct LAMBStage2Functor
         i_start += blockDim.x*ILP)
     {
       T r_p[ILP];
-      T r_update[ILP];
+      UPD_T r_update[ILP];
       #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
@@ -66,7 +66,7 @@ struct LAMBStage2Functor
       #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {
-        r_p[ii] = r_p[ii] - (ratio*r_update[ii]);
+        r_p[ii] = r_p[ii] - (ratio*(T)r_update[ii]);
       }
       #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
@@ -92,15 +92,16 @@ void multi_tensor_lamb_stage2_cuda(
   using namespace at;
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
       multi_tensor_apply<2>(
         BLOCK_SIZE,
         chunk_size,
         noop_flag,
         tensor_lists,
-        LAMBStage2Functor<scalar_t_0>(),
+        LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
         per_tensor_param_norm.data<float>(),
         per_tensor_update_norm.data<float>(),
-        learning_rate); )
+        learning_rate); ))
   AT_CUDA_CHECK(cudaGetLastError());
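Conceptually, the stage-2 inner step with a half-precision update buffer now looks like the sketch below. This is a simplified stand-in, not the real functor: the trust-ratio computation is reduced to a single guarded division, the per-tensor norm lookup, ILP unrolling, and chunking are omitted, and the kernel name is made up.

#include <cuda_fp16.h>

// Simplified stand-in for the stage-2 math: the update buffer is UPD_T
// (e.g. __half) and is cast back to the parameter type T before the scaled
// subtraction, mirroring `r_p[ii] - (ratio*(T)r_update[ii])`.
template<typename T, typename UPD_T>
__global__ void lamb_stage2_sketch(T* p, const UPD_T* update,
                                   float param_norm, float update_norm,
                                   float learning_rate, int n)
{
  // Simplified trust ratio; the real kernel reads per-tensor norms and has
  // additional guards.
  T ratio = (param_norm > 0.f && update_norm > 0.f)
              ? (T)(learning_rate * param_norm / update_norm)
              : (T)learning_rate;
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i < n)
    p[i] = p[i] - ratio*(T)update[i];
}

// Example instantiation matching fp32 parameters with an fp16 update buffer:
//   lamb_stage2_sketch<float, __half><<<blocks, threads>>>(p, upd, pn, un, lr, n);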