Commit 75f1e9d7 authored by Thor Johnsen

Delayed init

parent 53cfd8c2
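The diff below reworks the optimizer's delayed initialization: __init__ only records its arguments in self._init_args, and a guarded _init_everything() (called from the gradient-reduction entry points shown further down) performs the actual setup once, at first use. The following is a minimal, self-contained sketch of that idiom; the class name LazyInitOptimizer, the reduced argument list, and the stand-in state it builds are illustrative only and are not the actual DistributedFusedLAMB code.

import torch

class LazyInitOptimizer(torch.optim.Optimizer):
    """Minimal sketch of the delayed-init idiom used in the diff below."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))
        # Record what the deferred setup will need; do no heavy work yet.
        self._init_args = {'lr': lr}
        self._init_done = False

    def _first_step_init(self, lr=1e-3):
        # Stand-in for the expensive state construction (flat buffers,
        # process groups, ...) that the real optimizer builds on first use.
        self._flat_grads = [torch.zeros_like(p)
                            for group in self.param_groups
                            for p in group['params']]

    def _init_everything(self):
        if not self._init_done:
            self._first_step_init(**self._init_args)
            self._init_done = True      # latch: the setup runs only once

    @torch.no_grad()
    def step(self, closure=None):
        self._init_everything()         # first call triggers the deferred setup
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])

Usage is the same as for any optimizer; the expensive setup simply happens inside the first step() (or, in the real class, the first gradient reduction) rather than in the constructor.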
@@ -76,6 +76,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
         super(DistributedFusedLAMB, self).__init__(params, defaults)
         self._init_args = {
@@ -102,7 +103,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import inspect
         assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
-    def __init_on_demand(self,
+    def __first_step_init__(self,
             lr=1e-3, bias_correction = True, grad_averaging=True,
             betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0., max_grad_norm=0.,
@@ -144,10 +145,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_info = []
         self._grad_accs = []
         self._group_properties = []
-        self._param_state = None
-        self._param_group = None
         for group in self.param_groups:
-            if self._param_group is None: self._param_group = group
             prev = None
             beta1, beta2 = group['betas']
             for p in group['params']:
@@ -389,10 +387,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._reductions_works = [None]*self._num_blocks
         self._allgather_works = [None]*self._num_blocks
-    def _first_time_init(self):
+    def _init_everything(self):
         if not self._init_done:
-            self.__init_on_demand(**self._init_args)
-            self._init_done = False
+            self.__first_step_init__(**self._init_args)
+            self._init_done = True
     def set_is_accumulation_step(self, is_accumulation_step):
         self._is_accumulation_step = is_accumulation_step
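Note that the rewritten guard now latches: __first_step_init__(**self._init_args) runs once and _init_done is set to True, whereas the previous _first_time_init reset the flag to False and would therefore repeat the on-demand initialization on every call. The latch matters because, as the later hunks show, _init_everything() is invoked from both _do_overlapped_reduction() and complete_reductions().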
@@ -488,20 +486,19 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._contrib_beta2,
                 self._contrib_beta3,
                 self._contrib_bias_correction,
-                self._param_group['step'],
+                self.param_groups[0]['step'],
                 self._contrib_epsilon,
                 self._adam_w_mode,
                 self._contrib_weight_decay,
                 self.L2_grad_norm,
                 max_grad_norm)
         upd_norm = self.__compute_contrib_update_norm()
-        print(self.L2_grad_norm,max_grad_norm,param_norm,upd_norm)
         multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                 self._overflow_buf,
                 self._contrib_update_weights_tensor_list, # u, p, p_copy
                 param_norm,
                 upd_norm,
-                self._param_group['lr'],
+                self.param_groups[0]['lr'],
                 self._contrib_weight_decay,
                 self._use_nvlamb)
         torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
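This hunk also replaces the cached self._param_group reference with direct lookups into self.param_groups[0]. torch.optim.Optimizer keeps param_groups as a list of plain dicts, so reading through it always sees the current hyperparameter values, e.g. a learning rate that an LR scheduler has modified in place. A small illustration of that behaviour, using plain SGD and StepLR rather than DistributedFusedLAMB:

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

print(opt.param_groups[0]['lr'])   # 0.1
opt.step()                         # no gradients yet, so this is a no-op
sched.step()
print(opt.param_groups[0]['lr'])   # 0.05 -- the group dict was updated in place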
@@ -525,7 +522,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._grads_fp32 = []
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        self._first_time_init()
+        self._init_everything()
         if not self._is_accumulation_step:
             # handle overlapped reductions
             if param.dtype == torch.float16:
@@ -549,7 +546,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._first_time_init()
+        self._init_everything()
         if self._last_step:
             # zero out gradients that have not been completed yet
             for param_i, grad_generated in enumerate(self._grads_generated):
@@ -587,6 +584,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         with torch.cuda.stream(self._completion_st):
             # Copy self._new_params to model params
             self._overflow_buf.zero_()
+            with torch.no_grad():
+                if self._packed_flat_to_model_params_fp16 is not None:
+                    multi_tensor_applier(
+                        fused_adam_cuda.maybe_cast_mt,
......
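The final hunk wraps the copy of the freshly all-gathered parameters back into the model parameters in torch.no_grad(). Autograd rejects in-place writes to leaf tensors that require grad unless they happen outside gradient tracking, which this reduced illustration shows with placeholder tensors rather than the apex buffers:

import torch

p = torch.nn.Parameter(torch.zeros(8))   # leaf tensor, requires_grad=True
new_p = torch.ones(8)                    # e.g. freshly all-gathered values

# p.copy_(new_p) outside no_grad would raise a RuntimeError
# (in-place operation on a leaf tensor that requires grad).
with torch.no_grad():
    p.copy_(new_p)                       # allowed: not recorded by autograd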