Unverified Commit ebcd7f08 authored by Kexin Yu, committed by GitHub

Distributed LAMB: Clip grads before reduce_scatter/all_reduce (#1099)

* clip before reduce scatter

* provide clip before/after RS option

* change to clip after ar (avoid confusion)

* fix comments
parent 00c1e56d
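To make the change easier to follow, here is a minimal standalone sketch of the two clipping modes, assuming a flat fp32 gradient tensor whose length is divisible by the world size, an already-initialized default process group, and plain torch.distributed collectives. The helper names are hypothetical and this is not the optimizer's actual code path:

import torch
import torch.distributed as dist

def clip_before_reduce_scatter(flat_grads, shard_out, max_grad_norm, world_size):
    # clip_after_ar=False: clip using the *rank-local* gradient norm, fold the
    # 1/world_size averaging into the same multiply, then reduce-scatter the
    # already-scaled gradients.
    local_norm = flat_grads.norm(p=2)
    coeff = torch.clamp(max_grad_norm / (1e-6 + local_norm), max=1.0)
    flat_grads.mul_(coeff / world_size)
    dist.reduce_scatter(shard_out, list(flat_grads.chunk(world_size)))

def clip_after_all_reduce(flat_grads, shard_out, max_grad_norm, world_size):
    # clip_after_ar=True (default): average and reduce-scatter first, compute the
    # *global* norm of the reduced gradients, then clip the local shard.
    flat_grads.mul_(1.0 / world_size)
    dist.reduce_scatter(shard_out, list(flat_grads.chunk(world_size)))
    norm_sq = (shard_out.norm(p=2) ** 2).reshape(1)
    dist.all_reduce(norm_sq)  # sum of per-shard squared norms = global norm^2
    coeff = torch.clamp(max_grad_norm / (1e-6 + norm_sq.sqrt()), max=1.0)
    shard_out.mul_(coeff)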
@@ -87,7 +87,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 step_supports_amp_scaling=True, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                 e5m2_allgather=False, verbose=False):
                 e5m2_allgather=False, verbose=False, clip_after_ar=True):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
@@ -118,6 +118,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        self._num_chunks = dwu_num_chunks
        self._e5m2_allgather = e5m2_allgather
        self._verbose = verbose
        self._clip_after_ar = clip_after_ar
        self._L2_grad_norm = None
        self._current_process_group = c10d._get_default_group()
@@ -470,6 +471,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        return flush_block

    def _pipeline_block_reductions(self, block_id):
        if self._clip_after_ar:
            self._flatten_grad_mt(1.0/self._world_size)
            # Reduction within each node
@@ -504,6 +506,52 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()
        else:
            # Copy model grads to flat grads buffer
            self._flatten_grad_mt(1.0)

            # Compute L2 grad norm
            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._l2_grad_norm_st):
                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
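            # (the wait above makes the L2 norm computed on the side stream visible
            # before it is consumed by the clipping coefficient below)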
            # Apply clipping & pre-reduction scaling on grads
            loss_scale = self.global_scale
            max_grad_norm = loss_scale*self.defaults['max_grad_norm']
            coeff = max_grad_norm /(1e-6+self.L2_grad_norm)
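            # branchless clamp to min(coeff, 1) plus a NaN/inf guard, all on the GPU
            # (no host sync): tmp holds [1, coeff]; (coeff+1 > coeff) is False for
            # NaN/inf, so index 0 picks a scale of 1.0 on overflow and index 1 the
            # clipped coeff otherwise; dividing by world_size folds the pre-reduction
            # averaging into the same multiply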
            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
            tmp = torch.cat(((self._one), (coeff)))
            index = (coeff+1>coeff).int()
            scale = tmp.index_select(0, index).half()/self._world_size
            self._flat_grads.mul_(scale)

            # Reduction within each node
            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
            # The output format is the same as the fp32 master parameters
            works = [None]*self._num_chunks
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
                rs_stream.wait_stream(torch.cuda.current_stream())
                rs_stream.wait_stream(self._l2_grad_norm_st)
                with torch.cuda.stream(rs_stream):
                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)

            # Reduction across nodes for each rank
            if self._num_groups > 1:
                for chunk_id in range(self._num_chunks):
                    glob_chunk_id = block_id * self._num_chunks + chunk_id
                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                    with torch.cuda.stream(ar_stream):
                        works[chunk_id].wait()
                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
            self._reductions_works[block_id] = works

            if block_id == 0:
                for block_id in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[block_id][chunk_id].wait()

    def __compute_contrib_param_norm(self):
        if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
@@ -528,12 +576,20 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
    def _pipeline_step(self):
        global_scale = self.global_scale
        max_grad_norm = self.defaults['max_grad_norm']
        # if clipping already happened before the reduce-scatter (clip_after_ar=False),
        # disable it here by zeroing max_grad_norm
        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
        self._completion_st.wait_stream(self._l2_grad_norm_st)
        global_grad_norm = self.L2_grad_norm

        # check global_grad_norm and fill overflow_buf
        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
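        # agree on the overflow decision across ranks: when clipping happens before the
        # reduce-scatter, the L2 norm above is computed on rank-local gradients, so
        # is_finite can differ from rank to rank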
        torch.distributed.all_reduce(is_finite,
                                     op=torch.distributed.ReduceOp.MIN,
                                     group=self._current_process_group)
        torch.distributed.all_reduce(self._overflow_buf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self._current_process_group)
        # increment step counter if no overflow
        self._step += is_finite
......
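For reference, a user would opt into the new behavior roughly as follows. This is a hypothetical construction: model and the learning-rate/regularization values are placeholders, only clip_after_ar is the option added by this commit, and the import assumes the usual apex.contrib layout:

from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

optimizer = DistributedFusedLAMB(model.parameters(),
                                 lr=4e-3, betas=(0.9, 0.999), eps=1e-8,
                                 weight_decay=0.01, max_grad_norm=1.0,
                                 clip_after_ar=False)  # clip before reduce_scatter instead of after all_reduce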