Unverified Commit bb791585 authored by Kexin Yu, committed by GitHub

DistributedFusedLAMB: enable no_copy and add barrier for SHARP (#1075)



* enable no_copy

* barrier for SHARP

* set verbose=False by default
Co-authored-by: Kexin Yu <kexiny@nvidia.com>
parent 59d2f7ac
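
The constructor change in the diff below drops the commented-out guard and now hard-requires the no_copy keyword on torch.distributed.reduce_scatter. As a standalone, hedged sketch of that probe: the helper name c10d_supports_no_copy is illustrative, and no_copy is a non-standard keyword that only exists in some PyTorch builds (e.g. NVIDIA's containers), not in stock c10d.

import inspect

import torch.distributed as dist


def c10d_supports_no_copy() -> bool:
    # Same check the optimizer's constructor asserts on in this commit:
    # an extended c10d build exposes a `no_copy` kwarg on reduce_scatter.
    return "no_copy" in inspect.getfullargspec(dist.reduce_scatter).args
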
@@ -87,7 +87,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  step_supports_amp_scaling=True, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False):
+                 e5m2_allgather=False, verbose=False):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
@@ -117,6 +117,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
+        self._verbose = verbose
         self._L2_grad_norm = None
 
         self._current_process_group = c10d._get_default_group()
@@ -134,8 +135,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         # Master weight, moment, gradient buffers
         self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
 
-        #import inspect
-        #assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+        import inspect
+        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
 
         self._num_rs_pg = dwu_num_rs_pg
         self._num_ar_pg = dwu_num_ar_pg
@@ -145,7 +146,16 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             for dev_i in range(self._group_size):
                 ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                 for i in range(self._num_ar_pg):
+                    if self._verbose:
+                        print(f"creating new group {i}: {ranks}")
                     grp = torch.distributed.new_group(ranks=ranks)
+                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
+                        if self._verbose:
+                            print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
+                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
+                    if self._verbose:
+                        print(f"created new group {i}")
                     if torch.distributed.get_rank() in ranks:
                         self._ar_pg.append(grp)
             self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
@@ -471,7 +481,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
                 rs_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(rs_stream):
-                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True)#,no_copy=True)
+                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
 
         # Reduction across nodes for each rank
         if self._num_groups > 1:
@@ -562,7 +572,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._contrib_weight_decay,
                 global_grad_norm,
                 self._use_nvlamb)
-        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0])#, no_copy=True)
+        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
 
     def _flatten_grad_mt(self, scale):
         if len(self._grads_fp16) > 0:
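
The barrier added after torch.distributed.new_group forces each new communicator to initialize eagerly on the device passed via device_ids, instead of lazily at its first collective, which is the behaviour the SHARP-enabled setup relies on here. Below is a minimal sketch of the same pattern in isolation, assuming the default process group is already initialized; the function name and the example group layout are illustrative, not part of the commit.

import torch
import torch.distributed as dist


def new_groups_with_eager_init(group_ranks, verbose=False):
    """Create one process group per rank list and barrier on each right away.

    `group_ranks` is e.g. [[0, 2], [1, 3]]; every rank must call new_group()
    for every group, in the same order, even for groups it does not belong to.
    """
    my_groups = []
    for i, ranks in enumerate(group_ranks):
        if verbose:
            print(f"creating new group {i}: {ranks}")
        grp = dist.new_group(ranks=ranks)
        # Ranks outside `ranks` get the NON_GROUP_MEMBER sentinel back.
        if grp != dist.GroupMember.NON_GROUP_MEMBER:
            # Eagerly build the communicator now rather than at first use.
            dist.barrier(group=grp, device_ids=[torch.cuda.current_device()])
        if dist.get_rank() in ranks:
            my_groups.append(grp)
    return my_groups

In the commit itself this pattern is applied to the all-reduce groups created per device slot, with the prints gated behind the new verbose flag (False by default).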