Commit 5caf95ca authored by Thor Johnsen

Make step global state variable

parent 7741808b
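The change below drops the per-parameter 'step' counter kept in self.state[p] and tracks a single 'step' entry on the parameter group instead. As an illustration only, here is a minimal sketch of the group-level bookkeeping pattern using a toy optimizer with hypothetical names; only param_groups / state / 'step' mirror the real code, everything else is made up for the example.

```python
# Minimal sketch of group-level step tracking, assuming a toy SGD-like optimizer.
# This is NOT the DistributedFusedLAMB implementation; names other than
# 'param_groups', 'state', and 'step' are illustrative.
import torch

class ToySGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            # One shared counter per group (the pattern this commit adopts),
            # instead of a separate self.state[p]['step'] for every parameter.
            group['step'] = group.get('step', 0) + 1
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])
        return loss

# Usage: every parameter in the group now sees the same step count.
w = torch.nn.Parameter(torch.randn(3))
opt = ToySGD([w])
w.grad = torch.ones_like(w)
opt.step()
print(opt.param_groups[0]['step'])  # 1
```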
@@ -135,11 +135,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     1.0 - beta1 if grad_averaging else 1.0,
                     group['eps']
                     ))
-                state = self.state[p]
-                if len(state) == 0:
-                    state['step'] = 0
-                if self._param_state is None:
-                    self._param_state = state
                 p_grads_size = p.numel()
                 def wrapper(param, param_i, param_grads_size, param_offset):
                     param_tmp = param.expand_as(param)
@@ -463,7 +458,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._contrib_beta2,
                     self._contrib_beta3,
                     self._contrib_bias_correction,
-                    self._param_state['step']+1,
+                    self._param_group['step'],
                     self._contrib_epsilon,
                     self._adam_w_mode,
                     self._contrib_weight_decay,
@@ -547,11 +542,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
+        # assume the same step across the group for now to simplify things;
+        # a per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
+        if 'step' in self._param_group:
+            self._param_group['step'] += 1
+        else:
+            self._param_group['step'] = 1
         self._pipeline_step()
         with torch.cuda.stream(self._completion_st):
             # Copy self._new_params to model params
-            for p in self._model_params: self.state[p]['step'] += 1
             if self._packed_flat_to_model_params_fp16 is not None:
                 multi_tensor_applier(
                     fused_adam_cuda.maybe_cast_mt,
...
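The added comment notes that a per-parameter step could still be supported by packing the counters into a tensor and handing the whole set to the kernel. A hedged sketch of that alternative is below; the variable names are placeholders and the kernel argument is a made-up stand-in, not the real fused_adam_cuda / multi_tensor_applier interface.

```python
# Hypothetical sketch of the alternative hinted at in the added comment:
# keep one step counter per parameter in a tensor so a fused kernel could take
# the whole set at once rather than a single Python int.
import torch

num_params = 4
per_param_steps = torch.zeros(num_params, dtype=torch.int64)

def advance_steps():
    per_param_steps.add_(1)   # bump every parameter's counter in one vectorized op
    return per_param_steps    # would replace the scalar group['step'] kernel argument

print(advance_steps())  # tensor([1, 1, 1, 1])
```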