"vscode:/vscode.git/clone" did not exist on "49a198c99cdf61cf869ced2dc1e4e8b69926ceed"
Commit 75f1e9d7 authored by Thor Johnsen

Delayed init

parent 53cfd8c2
@@ -67,7 +67,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
                  adam_w_mode=True, use_nvlamb=False,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
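For orientation, here is a hedged sketch of how an optimizer with this constructor signature would typically be instantiated. The import path, the distributed setup, the model, and the hyperparameter values are assumptions for illustration, not taken from this commit.

```python
# Assumes apex built with CUDA extensions and a torch.distributed environment
# (MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE) already configured, e.g. via torchrun.
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB  # assumed path

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
    weight_decay=0.01, max_grad_norm=1.0,
    overlap_reductions=True,
    dwu_num_blocks=4, dwu_num_chunks=4,
    dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0)
```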
@@ -76,6 +76,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
         super(DistributedFusedLAMB, self).__init__(params, defaults)
         self._init_args = {
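For context on the `defaults` dict and the `super().__init__(params, defaults)` call: the stock `torch.optim.Optimizer` constructor copies every key of `defaults` into each entry of `self.param_groups`, which is why later hunks can read `self.param_groups[0]['lr']` directly. A small sketch of that behavior using a built-in optimizer (illustrative only):

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.01)

# Each param group carries the keys from `defaults` plus its parameter list.
group = opt.param_groups[0]
print(sorted(k for k in group if k != 'params'))  # includes 'lr', 'weight_decay', ...
print(group['lr'])                                # 0.001
```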
@@ -102,11 +103,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import inspect
         assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
 
-    def __init_on_demand(self,
+    def __first_step_init__(self,
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
                  adam_w_mode=True, use_nvlamb=False,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
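The assert above feature-detects the `no_copy` keyword on `torch.distributed.reduce_scatter` with `inspect.getfullargspec` before the optimizer relies on it (it later calls `all_gather(..., no_copy=True)`). A self-contained sketch of the same detection idiom, using a stand-in function rather than c10d:

```python
import inspect

def reduce_scatter(output, input_list, op=None, group=None, async_op=False, no_copy=False):
    """Stand-in with the signature shape the assert is checking for."""
    pass

# True only if the callable actually accepts a `no_copy` argument.
supports_no_copy = 'no_copy' in inspect.getfullargspec(reduce_scatter).args
assert supports_no_copy, "This version of c10d does not support no_copy option"
```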
@@ -144,10 +145,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_info = []
         self._grad_accs = []
         self._group_properties = []
-        self._param_state = None
-        self._param_group = None
         for group in self.param_groups:
-            if self._param_group is None: self._param_group = group
             prev = None
             beta1, beta2 = group['betas']
             for p in group['params']:
@@ -389,10 +387,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._reductions_works = [None]*self._num_blocks
         self._allgather_works = [None]*self._num_blocks
 
-    def _first_time_init(self):
+    def _init_everything(self):
         if not self._init_done:
-            self.__init_on_demand(**self._init_args)
-            self._init_done = False
+            self.__first_step_init__(**self._init_args)
+            self._init_done = True
 
     def set_is_accumulation_step(self, is_accumulation_step):
         self._is_accumulation_step = is_accumulation_step
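This hunk is the heart of the change: the expensive setup is deferred behind an `_init_done` guard, and the flag is now set to `True` after the first call, so the setup runs exactly once (leaving it `False`, as the removed line did, would re-run it on every entry). A minimal, generic sketch of that guard, using hypothetical names rather than the optimizer's real internals:

```python
class DelayedInit:
    """Toy example of the one-time, on-demand initialization guard."""

    def __init__(self, **kwargs):
        self._init_args = kwargs   # remember how to initialize, but don't do it yet
        self._init_done = False

    def _first_step_init(self, **kwargs):
        # Placeholder for the heavy work: buffer allocation, process groups, ...
        print("heavy setup running with", kwargs)

    def _init_everything(self):
        if not self._init_done:
            self._first_step_init(**self._init_args)
            self._init_done = True   # must be True, or setup repeats on every call

obj = DelayedInit(lr=1e-3)
obj._init_everything()   # runs the setup once
obj._init_everything()   # no-op afterwards
```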
@@ -488,20 +486,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._contrib_beta2,
                     self._contrib_beta3,
                     self._contrib_bias_correction,
-                    self._param_group['step'],
+                    self.param_groups[0]['step'],
                     self._contrib_epsilon,
                     self._adam_w_mode,
                     self._contrib_weight_decay,
                     self.L2_grad_norm,
                     max_grad_norm)
             upd_norm = self.__compute_contrib_update_norm()
-            print(self.L2_grad_norm,max_grad_norm,param_norm,upd_norm)
             multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                     self._overflow_buf,
                     self._contrib_update_weights_tensor_list, # u, p, p_copy
                     param_norm,
                     upd_norm,
-                    self._param_group['lr'],
+                    self.param_groups[0]['lr'],
                     self._contrib_weight_decay,
                     self._use_nvlamb)
             torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
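`multi_tensor_applier` comes from `apex.multi_tensor_apply`: it launches one fused CUDA kernel over lists of tensors, with a flag tensor that the kernel sets on overflow so the caller can skip the step. A hedged sketch of the call convention using the simpler `multi_tensor_scale` kernel (requires apex built with CUDA extensions; the tensors and scale value here are illustrative, not from this diff):

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
src = [torch.randn(256, device='cuda') for _ in range(3)]
dst = [torch.empty_like(t) for t in src]

# Scales every tensor in `src` into `dst` with a single fused kernel;
# `overflow_buf` becomes non-zero if a non-finite value is encountered.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 0.5)
```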
@@ -525,7 +522,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_fp32 = []
 
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        self._first_time_init()
+        self._init_everything()
         if not self._is_accumulation_step:
             # handle overlapped reductions
             if param.dtype == torch.float16:
@@ -549,7 +546,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._first_time_init()
+        self._init_everything()
         if self._last_step:
             # zero out gradients that have not been completed yet
             for param_i, grad_generated in enumerate(self._grads_generated):
@@ -587,16 +584,18 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Copy self._new_params to model params
-            if self._packed_flat_to_model_params_fp16 is not None:
-                multi_tensor_applier(
-                        fused_adam_cuda.maybe_cast_mt,
-                        self._overflow_buf,
-                        self._packed_flat_to_model_params_fp16)
-            if self._packed_flat_to_model_params_fp32 is not None:
-                multi_tensor_applier(
-                        fused_adam_cuda.maybe_cast_mt,
-                        self._overflow_buf,
-                        self._packed_flat_to_model_params_fp32)
+            self._overflow_buf.zero_()
+            with torch.no_grad():
+                if self._packed_flat_to_model_params_fp16 is not None:
+                    multi_tensor_applier(
+                            fused_adam_cuda.maybe_cast_mt,
+                            self._overflow_buf,
+                            self._packed_flat_to_model_params_fp16)
+                if self._packed_flat_to_model_params_fp32 is not None:
+                    multi_tensor_applier(
+                            fused_adam_cuda.maybe_cast_mt,
+                            self._overflow_buf,
+                            self._packed_flat_to_model_params_fp32)
 
         torch.cuda.current_stream().wait_stream(self._completion_st)
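The last hunk clears the overflow flag before the casts and wraps the parameter copies in `torch.no_grad()`, while keeping the work on the separate completion stream that the default stream later waits on. A small sketch of that side-stream pattern (CUDA only; the buffer names are illustrative, not the optimizer's):

```python
import torch

if torch.cuda.is_available():
    completion_stream = torch.cuda.Stream()
    new_params = torch.randn(1024, device='cuda')
    model_param = torch.empty(1024, device='cuda')

    with torch.cuda.stream(completion_stream):
        with torch.no_grad():              # plain data movement, no autograd tracking
            model_param.copy_(new_params)

    # Make subsequent work on the default stream wait for the copy to finish.
    torch.cuda.current_stream().wait_stream(completion_stream)
```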