Unverified commit a47d1a76 authored by chochowski, committed by GitHub

fix group range to compute l2_norm (#1266)

* fix graph capture failure, fix norm computation with full_ar and clip_after
* fix group range to compute l2_norm

Co-authored-by: seryilmaz <seryilmaz@nvidia.com>
Co-authored-by: mchochowski <mchochowski@nvidia.com>
parent 2eafdb3d
@@ -127,7 +127,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._L2_grad_norm = None
         self._set_flat_param_view = set_param_views_to_flat_buffer
         self._skip_ag = skip_allgather
-        self._fused_norm = fused_norm
+        self._fused_norm = fused_norm if not clip_after_ar else False
         self._current_process_group = c10d._get_default_group()
         self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
@@ -151,12 +151,16 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._num_ag_pg = dwu_num_ag_pg
         if self._full_ar: # full all reduce, only need AR and AG groups
+            # l2_grad_norm may be reduced within a node to limit from memory reads
+            for group_i in range(self._num_groups):
+                ranks = [group_i*self._group_size+j for j in range(self._group_size)]
+                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                if torch.distributed.get_rank() in ranks:
+                    self._l2_grad_norm_pg = l2_grad_norm_pg
             self._ar_pg = []
             # consider all the ranks
             ranks = list(range(0, self._world_size))
-            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
-            if torch.distributed.get_rank() in ranks:
-                self._l2_grad_norm_pg = l2_grad_norm_pg
             for i in range(self._num_ar_pg):
                 if self._verbose:
                     print(f"creating new AR group {i}: {ranks}")
@@ -651,7 +655,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             glob_chunk_id = block_id * self._num_chunks + chunk_id
             rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
             rs_stream.wait_stream(torch.cuda.current_stream())
-            rs_stream.wait_stream(self._l2_grad_norm_st)
             with torch.cuda.stream(rs_stream):
                 works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
@@ -682,8 +685,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                         self._reductions_works[block_id][chunk_id].wait()
                 # Since the packed format is contiguous after reductions, only one norm is needed
                 l2_grad_norm_sq = torch.empty([1], device='cuda')
-                if 0:#self._full_ar:
-                    l2_grad_norm_sq = self._flat_grads_shards[self._rank_in_group].norm(dtype=torch.float32, p=2)**2
+                if self._full_ar:
+                    # this flattening of lists is to keep multi_tensor_apply function happy, it wants depth=1 for l2 norm computation
+                    flat_list = [item for sublist in self._fp16_g_chunks for item in sublist]
+                    l2_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [flat_list], False)[0]**2
                 else:
                     l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                 torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
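And a minimal sketch of the `multi_tensor_applier` L2-norm call used by the new `full_ar` branch above, checked against a plain concatenated-norm reference. It assumes apex is installed with its `amp_C` CUDA extension and a GPU is available; the chunk shapes are made up and only mimic the nested layout of `self._fp16_g_chunks`.

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Nested chunk lists roughly shaped like self._fp16_g_chunks (sizes are illustrative).
chunks = [[torch.randn(1024, device='cuda') for _ in range(4)] for _ in range(2)]

# multi_tensor_l2norm expects a depth-1 list of tensors, hence the flattening.
flat_list = [t for sublist in chunks for t in sublist]
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')

# Index 0 of the result is the total L2 norm; False disables per-tensor norms.
norm = multi_tensor_applier(amp_C.multi_tensor_l2norm, overflow_buf, [flat_list], False)[0]

# Reference: norm of all chunk values concatenated into one flat vector.
reference = torch.cat([t.view(-1) for t in flat_list]).norm(p=2)
print(torch.allclose(norm, reference, rtol=1e-4))
```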