Unverified commit 17eec271, authored by Burc Eryilmaz, committed by GitHub

option to set param views to flat buffer (#1152)



* option to set param views to flat buffer

* remove redundant variables in init_stage1
Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
Co-authored-by: ptrblck <ptrblck@users.noreply.github.com>
parent 2e98baa7
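
For orientation, here is a hedged usage sketch (not part of this commit) showing how the two new constructor flags, `set_param_views_to_flat_buffer` and `skip_allgather`, would be passed. It assumes the usual `apex.contrib.optimizers.distributed_fused_lamb` import path, an already-initialized `torch.distributed` process group, and otherwise the defaults visible in the signature hunk below; the model and hyperparameter values are placeholders.

```python
import torch
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

# Placeholder model; in practice this is the full network on the GPU.
model = torch.nn.Linear(1024, 1024).cuda()

optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=4e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
    set_param_views_to_flat_buffer=True,  # new in this commit: param views into the flat buffer
    skip_allgather=True,                  # new in this commit: skip the post-step all_gather
)
```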
@@ -88,7 +88,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
                  e5m2_allgather=False, verbose=False, clip_after_ar=True,
-                 full_ar=False, fuse_scale=False):
+                 full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
+                 fuse_scale=False):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
@@ -123,6 +124,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._full_ar = full_ar
         self._fuse_scale = fuse_scale
         self._L2_grad_norm = None
+        self._set_flat_param_view = set_param_views_to_flat_buffer
+        self._skip_ag = skip_allgather
         self._fused_norm = fused_norm
         self._current_process_group = c10d._get_default_group()
         self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
@@ -238,6 +241,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 if self._verbose:
                     print(f"creating AG group : {ranks}")
         self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+        for ag_pg in self._ag_pg:
+            torch.distributed.barrier(group=ag_pg)
         self._l2_grad_norm_st = torch.cuda.Stream()
         self._completion_st = torch.cuda.Stream()
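
The two added lines issue a `torch.distributed.barrier` on every all-gather process group right after the per-group CUDA streams are created; a barrier is the usual way to make all member ranks rendezvous on a freshly created group (and, with NCCL, to trigger its communicator setup) before the group is first used from a side stream. Below is a hedged standalone sketch of that pattern, not apex code, with a hypothetical group size.

```python
# Hedged sketch: create communication subgroups and barrier on each one so its
# communicator exists before any later collective uses it.
# Assumes torch.distributed is initialized and world_size is divisible by group_size.
import torch.distributed as dist

group_size = 8                                # hypothetical all-gather group size
world_size = dist.get_world_size()
my_rank = dist.get_rank()

ag_groups = []
for start in range(0, world_size, group_size):
    ranks = list(range(start, start + group_size))
    grp = dist.new_group(ranks=ranks)         # must be called identically on every rank
    if my_rank in ranks:
        ag_groups.append(grp)                 # keep only the groups this rank belongs to

for grp in ag_groups:
    dist.barrier(group=grp)                   # rendezvous / eager communicator creation
```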
@@ -252,9 +257,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
         self._param_order = self.AtomicCounter()
 
-    def _lazy_init_stage1(self):
-        if self._lazy_init_stage1_done: return
-
         p_offset = 0
         p_i = 0
         self._model_params = []
@@ -281,19 +283,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     eps
                     ))
                 p_grads_size = p.numel()
-                def wrapper(param, param_i):
-                    param_tmp = param.expand_as(param)
-                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                    def allreduce_hook(*unused):
-                        if self._first_step:
-                            # first time
-                            self._param_order.add(param_i)
-                        else:
-                            idx = self._param_order.order.index(param_i)
-                            self._do_overlapped_reduction(idx, param)
-                    grad_acc.register_hook(allreduce_hook)
-                    self._grad_accs.append(grad_acc)
-                wrapper(p, p_i)
+                if self._set_flat_param_view:
+                    self._param_order.add(p_i)
                 p_offset += p_grads_size
                 # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                 # RNN is one example of consecutive parameters:
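
The offsets around this hunk use the standard round-up-to-multiple idiom, `((x + m - 1) // m) * m`, both for the 128-byte (64 x fp16) alignment of non-consecutive parameters mentioned in the comment above and for padding the total parameter size to a multiple of `dwu_min_page_size` in the next hunk. A small worked example:

```python
def round_up(x, multiple):
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(1000, 64) == 1024   # e.g. 1000 fp16 values padded to a 128-byte boundary
assert round_up(1024, 64) == 1024   # already aligned: unchanged
```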
@@ -311,13 +302,46 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._total_param_size = p_offset
         dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
         self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
+        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
+
+    def _lazy_init_stage1(self):
+        if self._lazy_init_stage1_done: return
+
+        p_i = 0
+        #self._model_params = []
+        #self._grad_accs = []
+        #self._group_properties = []
+        for group in self.param_groups:
+            for p in group['params']:
+                torch.distributed.broadcast(p, 0)
+                if not p.requires_grad:
+                    continue
+                def wrapper(param, param_i):
+                    param_tmp = param.expand_as(param)
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    def allreduce_hook(*unused):
+                        if not self._set_flat_param_view:
+                            if self._first_step:
+                                # first time
+                                self._param_order.add(param_i)
+                            else:
+                                idx = self._param_order.order.index(param_i)
+                                self._do_overlapped_reduction(idx, param)
+                        else:
+                            if not self._first_step:
+                                self._do_overlapped_reduction(param_i, param)
+                    grad_acc.register_hook(allreduce_hook)
+                    self._grad_accs.append(grad_acc)
+                wrapper(p, p_i)
+                p_i += 1
+
         self._block_size = self._total_param_size // self._num_blocks
         self._chunk_size = self._block_size // self._num_chunks
         self._shard_size = self._chunk_size // self._group_size
         #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
         self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
-        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
         # initialize master weights, moments buffers if not loaded from checkpoint
         if self._fp32_p is None:
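
The relocated `_lazy_init_stage1` keeps the same backward-hook trick: each parameter's `AccumulateGrad` node is reached via `expand_as(...).grad_fn.next_functions[0][0]` and given a hook that fires when that gradient is ready, so reductions can overlap with the rest of the backward pass. With `set_param_views_to_flat_buffer` enabled, the bucket index is simply the construction-order index `p_i` recorded in the earlier hunk, so the first-step order-discovery path is bypassed. A minimal standalone sketch of the hook mechanism (not apex code) is below.

```python
# Minimal sketch of the grad-accumulator hook pattern used above.
import torch

model = torch.nn.Linear(4, 4)
grad_accs = []  # keep references alive so the hooks are not garbage-collected

def attach(param, index):
    param_tmp = param.expand_as(param)            # view whose grad_fn leads to AccumulateGrad
    grad_acc = param_tmp.grad_fn.next_functions[0][0]
    def hook(*unused):
        print(f"grad ready for param {index}, shape {tuple(param.shape)}")
    grad_acc.register_hook(hook)                  # fires when this grad is accumulated
    grad_accs.append(grad_acc)

for i, p in enumerate(model.parameters()):
    attach(p, i)

model(torch.randn(2, 4)).sum().backward()         # hooks fire during backward
```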
@@ -392,8 +416,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def _lazy_init_stage2(self):
         if self._lazy_init_stage2_done: return
-        self._param_order.order.reverse()
+        if not self._set_flat_param_view:
+            self._param_order.order.reverse()
+
+            # re-order model_params, grad_accs, group_properties lists
+            self._model_params = [self._model_params[i] for i in self._param_order.order]
+            self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
+            self._group_properties = [self._group_properties[i] for i in self._param_order.order]
 
         def _get_flat_view(param):
             if param.is_contiguous(memory_format=torch.channels_last):
@@ -406,11 +435,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 pv = param
             return pv.view(-1)
 
-        # re-order model_params, grad_accs, group_properties lists
-        self._model_params = [self._model_params[i] for i in self._param_order.order]
-        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
-        self._group_properties = [self._group_properties[i] for i in self._param_order.order]
-
         # re-collect grads info (size, offset) after ordering
         prev = None
         p_offset = 0
@@ -749,7 +773,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._contrib_weight_decay,
                     global_grad_norm,
                     self._use_nvlamb)
-            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
+            if not self._skip_ag:
+                torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
 
     def _flatten_grad_mt(self, scale):
         if len(self._grads_fp16) > 0:
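
At this point in the step pipeline each rank has updated only its own shard (`self._fp16_p`), and the `all_gather` reassembles the updated shards into the flat new-parameter buffer; the new `skip_allgather` flag guards that call. Below is a hedged sketch of the shard-then-all_gather pattern in plain `torch.distributed` (without apex's `no_copy` extension); the function name is illustrative.

```python
# Hedged sketch: reassemble a flat parameter buffer from per-rank shards by
# all_gather-ing into equal-sized views of the destination buffer.
import torch
import torch.distributed as dist

def gather_flat_params(local_shard: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    flat = torch.empty(local_shard.numel() * world_size,
                       dtype=local_shard.dtype, device=local_shard.device)
    shard_views = list(flat.chunk(world_size))        # views into `flat`, one per rank
    dist.all_gather(shard_views, local_shard, group=group)
    return flat                                       # full, updated flat buffer
```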
@@ -851,21 +876,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             optimizer_state["found_inf_per_device"][current_device] = found_inf
 
         self._completion_st.wait_stream(torch.cuda.current_stream())
-        with torch.cuda.stream(self._completion_st):
-            # Copy self._new_params to model params
-            with torch.no_grad():
-                if self._packed_flat_to_model_params_fp16 is not None:
-                    multi_tensor_applier(
-                            fused_adam_cuda.maybe_cast_mt,
-                            self._overflow_buf,
-                            self._packed_flat_to_model_params_fp16)
-                if self._packed_flat_to_model_params_fp32 is not None:
-                    multi_tensor_applier(
-                            fused_adam_cuda.maybe_cast_mt,
-                            self._overflow_buf,
-                            self._packed_flat_to_model_params_fp32)
+        if not self._set_flat_param_view:
+            with torch.cuda.stream(self._completion_st):
+                # Copy self._new_params to model params
+                with torch.no_grad():
+                    if self._packed_flat_to_model_params_fp16 is not None:
+                        multi_tensor_applier(
+                                fused_adam_cuda.maybe_cast_mt,
+                                self._overflow_buf,
+                                self._packed_flat_to_model_params_fp16)
+                    if self._packed_flat_to_model_params_fp32 is not None:
+                        multi_tensor_applier(
+                                fused_adam_cuda.maybe_cast_mt,
+                                self._overflow_buf,
+                                self._packed_flat_to_model_params_fp32)
 
         torch.cuda.current_stream().wait_stream(self._completion_st)
         self._reductions_works = [None]*self._num_blocks
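
The copy from `self._new_params` back into the model parameters is now skipped when `set_param_views_to_flat_buffer` is on: if the parameters are themselves views into the optimizer's flat buffer, updating the buffer already updates the model in place. A hedged standalone sketch of that idea (not apex code) is below.

```python
# Hedged sketch: re-point each parameter's storage at a slice of one contiguous
# buffer, so in-place updates of the flat buffer are visible through the params.
import torch

model = torch.nn.Linear(8, 8)
numels = [p.numel() for p in model.parameters()]
flat = torch.zeros(sum(numels))

offset = 0
for p in model.parameters():
    n = p.numel()
    flat[offset:offset + n].copy_(p.detach().view(-1))   # seed buffer with current weights
    p.data = flat[offset:offset + n].view_as(p)          # param becomes a view of the buffer
    offset += n

flat.add_(0.0)  # any in-place update of `flat` now updates the model, no copy-back needed
```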