Commit 5caf95ca authored by Thor Johnsen

Make step global state variable

parent 7741808b
@@ -135,11 +135,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                    1.0 - beta1 if grad_averaging else 1.0,
                    group['eps']
                ))
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if self._param_state is None:
                    self._param_state = state
                p_grads_size = p.numel()
                def wrapper(param, param_i, param_grads_size, param_offset):
                    param_tmp = param.expand_as(param)
@@ -463,7 +458,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
self._contrib_beta2,
self._contrib_beta3,
self._contrib_bias_correction,
self._param_state['step']+1,
self._param_group['step'],
self._contrib_epsilon,
self._adam_w_mode,
self._contrib_weight_decay,
@@ -547,11 +542,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()
        # assume the same step across the group for now to simplify things
        # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
        if 'step' in self._param_group:
            self._param_group['step'] += 1
        else:
            self._param_group['step'] = 1
        self._pipeline_step()
        with torch.cuda.stream(self._completion_st):
            # Copy self._new_params to model params
            for p in self._model_params: self.state[p]['step'] += 1
            if self._packed_flat_to_model_params_fp16 is not None:
                multi_tensor_applier(
                    fused_adam_cuda.maybe_cast_mt,
......
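Below is a minimal sketch of the bookkeeping pattern this commit moves to: a single 'step' counter kept on the parameter group (self._param_group['step']) instead of a per-parameter self.state[p]['step']. Everything beyond the names taken from the diff is an assumption for illustration: the class GlobalStepSGD, the single-group setup, and the plain SGD update that stands in for the fused LAMB kernel, which in the real optimizer receives the step value directly.

import torch

class GlobalStepSGD(torch.optim.Optimizer):
    """Toy optimizer illustrating a group-level step counter (not part of apex)."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, defaults={'lr': lr})
        # mirror the diff's assumption of a single parameter group
        self._param_group = self.param_groups[0]

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # one counter for the whole group, incremented once per step() call,
        # replacing the removed per-parameter self.state[p]['step'] bookkeeping
        if 'step' in self._param_group:
            self._param_group['step'] += 1
        else:
            self._param_group['step'] = 1
        for p in self._param_group['params']:
            if p.grad is None:
                continue
            # plain SGD update as a stand-in for the fused LAMB kernel,
            # which would be handed self._param_group['step'] as an argument
            p.add_(p.grad, alpha=-self._param_group['lr'])
        return loss

# usage: the counter lives on the group, not on per-parameter state
model = torch.nn.Linear(4, 2)
opt = GlobalStepSGD(model.parameters())
model(torch.randn(3, 4)).sum().backward()
opt.step()
print(opt._param_group['step'])  # 1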