Unverified commit 8a80d478 authored by Kexin Yu, committed by GitHub

Distributed LAMB fixes (#1007)



* add flag for DistributedAdam: step_support_amp_scaling
Co-authored-by: Kexin Yu <kexiny@nvidia.com>
Co-authored-by: Kexin Yu <kexinznzn@gmail.com>
parent 3fe10b55
@@ -12,8 +12,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float global_grad_norm,
-  const float max_global_grad_norm);
+  const float grad_scale);
 
 void multi_tensor_lamb_update_weights_cuda(
   int chunk_size,
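For reference, the removed parameters drove an in-kernel clipping factor, while the new grad_scale is a single divisor computed on the host. A minimal Python sketch (illustrative only, not apex API) of the factor the kernel used to derive for itself:

```python
# Sketch of the clipping factor DistOptLAMBStage1Functor previously computed from
# (global_grad_norm, max_global_grad_norm); after this change the host passes one
# precomputed grad_scale and the kernel simply divides each gradient by it.
def old_clip_factor(global_grad_norm, max_global_grad_norm):
    return (global_grad_norm / max_global_grad_norm
            if global_grad_norm > max_global_grad_norm else 1.0)
```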
@@ -120,8 +120,7 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const MATH_T global_grad_norm,
-    const MATH_T max_global_grad_norm)
+    const float grad_scale)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
@@ -132,15 +131,13 @@ struct DistOptLAMBStage1Functor
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
-    MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
     MATH_T beta1 = per_tensor_beta1[tensor_num];
-    MATH_T beta2 = per_tensor_beta1[tensor_num];
-    MATH_T beta3 = per_tensor_beta1[tensor_num];
+    MATH_T beta2 = per_tensor_beta2[tensor_num];
+    MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, (MATH_T) step);
-      beta2_correction = 1 - pow(beta2, (MATH_T) step);
+      beta1_correction = 1 - pow(beta1, step);
+      beta2_correction = 1 - pow(beta2, step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
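The hunk above fixes the beta handling: beta2 and beta3 were previously read from per_tensor_beta1, and the bias-correction terms no longer cast step to MATH_T. A small numeric sketch of the corrected bias-correction terms (the values are illustrative, not from the diff):

```python
import math

beta1, beta2 = 0.9, 0.999                      # per-tensor betas
beta3 = 1 - beta1                              # now derived from beta1, per the fix
step = 10                                      # illustrative step count
beta1_correction = 1 - math.pow(beta1, step)   # ~0.6513
beta2_correction = 1 - math.pow(beta2, step)   # ~0.00996
```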
@@ -207,7 +204,7 @@ struct DistOptLAMBStage1Functor
         for(int ii = 0; ii < ILP; ii++)
         {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -218,7 +215,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           }
           else {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
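Across the Stage1 hunks only the divisor of the raw gradient changes. A per-element Python sketch of the MOMENT_MODE_0 branch; the r_v update and the denom line are not visible in the hunks above and are assumed to follow the standard Adam form used elsewhere in this kernel:

```python
def stage1_update(g, p, m, v, beta1, beta2, beta3,
                  beta1_correction, beta2_correction, epsilon, decay, grad_scale):
    scaled_grad = g / grad_scale               # was: g / clipped_global_grad_norm
    scaled_grad = scaled_grad + decay * p      # L2 on scaled grad (MOMENT_MODE_0)
    m = m * beta1 + beta3 * scaled_grad
    v = v * beta2 + (1 - beta2) * scaled_grad * scaled_grad   # assumed Adam form
    next_m_unbiased = m / beta1_correction
    next_v_unbiased = v / beta2_correction
    denom = next_v_unbiased ** 0.5 + epsilon                  # assumed Adam form
    return next_m_unbiased / denom, m, v       # update term (r_p) and new moments
```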
@@ -277,7 +274,7 @@ struct DistOptLAMBStage1Functor
         for(int ii = 0; ii < ILP; ii++)
         {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -288,7 +285,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           }
           else {
-            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
+            MATH_T scaled_grad = r_g[ii] / grad_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -346,7 +343,7 @@ struct DistOptLAMBStage2Functor
     {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
       MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }
 
     MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
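The Stage2 change only drops the explicit (MATH_T) casts; the LAMB trust-ratio logic itself is unchanged. In plain Python it reads:

```python
def lamb_trust_ratio(learning_rate, param_norm, update_norm):
    # Adaptive LR: scale by ||param|| / ||update|| when both norms are nonzero.
    if update_norm != 0.0 and param_norm != 0.0:
        return learning_rate * (param_norm / update_norm)
    return learning_rate
```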
@@ -434,8 +431,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float global_grad_norm,
-  const float max_global_grad_norm)
+  const float grad_scale)
 {
   using namespace at;
 
@@ -456,8 +452,7 @@ void multi_tensor_lamb_compute_update_term_cuda(
         per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
         (adamMode_t) mode,
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
-        (scalar_t_2) global_grad_norm,
-        (scalar_t_2) max_global_grad_norm); )))
+        grad_scale); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
@@ -56,6 +56,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             (default: 1.0)
         use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
+        clip_grad_norm (boolean, optional): whether to handle gradient clipping
+            (default: True)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
@@ -67,7 +69,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False,
+                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
@@ -89,6 +91,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 'max_grad_norm': max_grad_norm,
                 'adam_w_mode': adam_w_mode,
                 'use_nvlamb': use_nvlamb,
+                'clip_grad_norm': clip_grad_norm,
                 'amp_scale_adjustment': amp_scale_adjustment,
                 'overlap_reductions': overlap_reductions,
                 'dwu_group_size': dwu_group_size,
@@ -107,7 +110,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False,
+                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
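A hypothetical construction showing the new keyword; the import path and the other argument values are assumptions for illustration, not part of the diff. Passing clip_grad_norm=False means clipping is expected to be folded into the scale handed to set_global_scale (see the hunks below):

```python
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

def build_optimizer(model):
    return DistributedFusedLAMB(
        model.parameters(),
        lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
        weight_decay=0.01, max_grad_norm=1.0,   # illustrative values
        clip_grad_norm=False,                   # clipping handled outside the optimizer
    )
```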
@@ -127,9 +130,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._adam_w_mode = 1 if adam_w_mode else 0
         self._use_nvlamb = use_nvlamb
+        self._clip_grad_norm = clip_grad_norm
         self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
+        self._global_scale = None
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
@@ -468,9 +473,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
         l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
         torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
         l2_norm = torch.sqrt(l2_norm)
         return l2_norm.masked_select(self._model_param_is_contrib)
 
     def _pipeline_step(self):
+        # If self._clip_grad_norm is False, we assume gradient clipping has
+        # already happened outside the optimizer and self._global_scale has
+        # already been set to the combined scale, i.e. it is no longer the raw
+        # loss scale used by the loss scaler. This covers model-parallel cases
+        # where the global gradient norm must be obtained via an all-reduce
+        # outside the optimizer in order to do the clipping there.
+        combined_scale = self.global_scale
+        max_grad_norm = self.defaults['max_grad_norm']
+        global_grad_norm = self.L2_grad_norm
+        if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
+            combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
+            combined_scale = self.global_scale / min(1, combined_scale)
+
         # Call step kernel once per step
         # Call all-gather once per step
         with torch.cuda.stream(self._completion_st):
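A worked example of the combined_scale arithmetic added above, with assumed numbers (loss scale 1024, max_grad_norm 1.0, unscaled gradient norm 5.0); the formula is the one in the hunk:

```python
global_scale = 1024.0                     # current loss scale
max_grad_norm = 1.0
global_grad_norm = 5.0 * global_scale     # L2 norm of the still loss-scaled gradients

clip = max_grad_norm / (global_grad_norm / global_scale + 1e-6)   # ~0.2
combined_scale = global_scale / min(1, clip)                      # ~5120.0
# Dividing the scaled gradients (norm 5120) by combined_scale removes the loss
# scale and clips the global gradient norm to max_grad_norm (1.0) in one step.
```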
@@ -478,7 +497,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 for chunk_id in range(self._num_chunks):
                     self._reductions_works[block_id][chunk_id].wait()
             param_norm = self.__compute_contrib_param_norm()
-            max_grad_norm = self.defaults['max_grad_norm']
             multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
                     self._overflow_buf,
                     self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
@@ -490,8 +508,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._contrib_epsilon,
                     self._adam_w_mode,
                     self._contrib_weight_decay,
-                    self.L2_grad_norm,
-                    max_grad_norm)
+                    combined_scale)
             upd_norm = self.__compute_contrib_update_norm()
             multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                     self._overflow_buf,
@@ -537,6 +554,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._pipeline_block_reductions(block_id)
             flush_block = self._get_flush_block()
 
+    def set_global_scale(self, global_scale):
+        """Set global scale.
+        """
+        self._global_scale = global_scale
+
+    @property
+    def global_scale(self):
+        return self._global_scale
+
     @property
     def L2_grad_norm(self):
         torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
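Finally, a hypothetical call pattern for the new hooks, continuing the construction sketch above (the loss-scale value is an assumption): with clip_grad_norm=True the optimizer derives the combined scale itself from global_scale, max_grad_norm, and L2_grad_norm; with clip_grad_norm=False the value passed here must already include any clipping factor.

```python
current_loss_scale = 2.0 ** 15                  # e.g. read from the AMP loss scaler
optimizer.set_global_scale(current_loss_scale)  # stored in self._global_scale
optimizer.step()                                # the step kernel divides gradients by
                                                # the combined scale built in _pipeline_step
```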