Commit 53cfd8c2 authored by Thor Johnsen

Delayed init

parent 4c54fd2b
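This commit defers the optimizer's expensive setup out of __init__: the constructor now only records its arguments in _init_args, while the CUDA-extension imports, buffer allocation, and process-group setup run once, on first use. A minimal sketch of the pattern, with illustrative names (LazyInitOptimizer and _heavy_setup are not from this commit):

    import torch

    class LazyInitOptimizer(torch.optim.Optimizer):
        def __init__(self, params, lr=1e-3):
            super().__init__(params, dict(lr=lr))
            # Record arguments only; defer all heavy work.
            self._init_args = {'lr': lr}
            self._init_done = False

        def _heavy_setup(self, lr):
            # Expensive work (extension imports, large buffers,
            # process groups) happens here instead of __init__.
            self._overflow_buf = torch.zeros(1, dtype=torch.int)

        def _first_time_init(self):
            if not self._init_done:
                self._heavy_setup(**self._init_args)
                self._init_done = True  # setup runs exactly once

        def step(self, closure=None):
            # Every public entry point guards on first-time init.
            self._first_time_init()

The benefit is that the optimizer can be constructed before devices and process groups are fully ready; nothing heavy happens until the first step or gradient reduction needs it.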
@@ -65,30 +65,57 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
    def __init__(self, params,
                 lr=1e-3, bias_correction = True, grad_averaging=True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False,
                 adam_w_mode=True, use_nvlamb=False, use_mt=False,
                 betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0., max_grad_norm=0.,
                 adam_w_mode=True, use_nvlamb=False,
                 amp_scale_adjustment=1.0, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                 e5m2_allgather=False):
        global fused_adam_cuda, distributed_lamb_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
        self._amp_scale_adjustment = amp_scale_adjustment
        if use_mt:
            raise RuntimeError('DistributedFusedLAMB does not support use_mt.')
        if amsgrad:
            raise RuntimeError('DistributedFusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        super(DistributedFusedLAMB, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self._init_args = {
            'lr': lr,
            'bias_correction': bias_correction,
            'grad_averaging': grad_averaging,
            'betas': betas,
            'eps': eps,
            'weight_decay': weight_decay,
            'max_grad_norm': max_grad_norm,
            'adam_w_mode': adam_w_mode,
            'use_nvlamb': use_nvlamb,
            'amp_scale_adjustment': amp_scale_adjustment,
            'overlap_reductions': overlap_reductions,
            'dwu_group_size': dwu_group_size,
            'dwu_num_blocks': dwu_num_blocks,
            'dwu_num_chunks': dwu_num_chunks,
            'dwu_num_rs_pg': dwu_num_rs_pg,
            'dwu_num_ar_pg': dwu_num_ar_pg,
            'dwu_num_ag_pg': dwu_num_ag_pg,
            'e5m2_allgather': e5m2_allgather}
        self._init_done = False
        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

    def __init_on_demand(self,
                         lr=1e-3, bias_correction = True, grad_averaging=True,
                         betas=(0.9, 0.999), eps=1e-8,
                         weight_decay=0., max_grad_norm=0.,
                         adam_w_mode=True, use_nvlamb=False,
                         amp_scale_adjustment=1.0, overlap_reductions=True,
                         dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                         dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                         e5m2_allgather=False):
        global fused_adam_cuda, distributed_lamb_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
        self._amp_scale_adjustment = amp_scale_adjustment
        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False
@@ -362,8 +389,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks
        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

    def _first_time_init(self):
        if not self._init_done:
            self.__init_on_demand(**self._init_args)
            self._init_done = True  # mark init as complete so it runs exactly once

    def set_is_accumulation_step(self, is_accumulation_step):
        self._is_accumulation_step = is_accumulation_step
@@ -466,6 +495,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            self.L2_grad_norm,
            max_grad_norm)
        upd_norm = self.__compute_contrib_update_norm()
        multi_tensor_applier(self.multi_tensor_lamb_update_weights,
            self._overflow_buf,
            self._contrib_update_weights_tensor_list, # u, p, p_copy
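For context, upd_norm computed above feeds LAMB's per-layer trust ratio: each block's Adam-style update is rescaled by the ratio of the parameter norm to the update norm before the weights are written. Schematically (the standard LAMB rule, not the fused kernel's exact code):

    def lamb_trust_ratio(param_norm, update_norm):
        # ratio = ||p|| / ||u||, guarded so zero norms fall back to 1.0;
        # the update is then applied as p <- p - lr * ratio * u
        if param_norm > 0.0 and update_norm > 0.0:
            return param_norm / update_norm
        return 1.0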
@@ -482,7 +512,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                self._overflow_buf,
                list(zip(self._grads_fp16)),
                list(zip(*self._grads_fp16)),
                scale)
            self._grads_fp16 = []
        if len(self._grads_fp32) > 0:
@@ -490,11 +520,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            multi_tensor_applier(
                amp_C.multi_tensor_scale,
                self._overflow_buf,
                list(zip(self._grads_fp32)),
                list(zip(*self._grads_fp32)),
                scale)
            self._grads_fp32 = []

    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
        self._first_time_init()
        if not self._is_accumulation_step:
            # handle overlapped reductions
            if param.dtype == torch.float16:
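The two multi_tensor_scale hunks above fix a transpose bug: _grads_fp16 and _grads_fp32 presumably hold (input, output) tensor pairs, while amp_C.multi_tensor_scale expects parallel lists of inputs and outputs. zip(xs) merely wraps each pair in a 1-tuple; zip(*xs) unpacks the list and transposes it. A quick illustration:

    pairs = [('in0', 'out0'), ('in1', 'out1')]

    # Old (buggy): each pair wrapped in a 1-tuple, not parallel lists.
    list(zip(pairs))    # [(('in0', 'out0'),), (('in1', 'out1'),)]

    # Fixed: unpacking transposes the pairs into (inputs, outputs).
    list(zip(*pairs))   # [('in0', 'in1'), ('out0', 'out1')]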
@@ -518,6 +549,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
self._first_time_init()
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
@@ -545,10 +577,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        # assume the same step across the group for now to simplify things
        # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
        if 'step' in self._param_group:
            self._param_group['step'] += 1
        else:
            self._param_group['step'] = 1
        for param_group in self.param_groups:
            if 'step' in param_group:
                param_group['step'] += 1
            else:
                param_group['step'] = 1
        self._pipeline_step()
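The last hunk moves the step counter from the single cached self._param_group onto every entry in self.param_groups, matching the stock torch.optim convention where each group carries its own per-group state. The if/else can be written more compactly (an equivalent sketch, not the commit's code):

    for group in self.param_groups:
        # Initialize lazily, then advance the per-group step counter.
        group['step'] = group.get('step', 0) + 1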