Unverified Commit a6e2c1fe authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix ema decay (#1868)

* Fix ema decay and clarify nomenclature.

* Rename var.
parent b28ab302
...@@ -278,24 +278,19 @@ class EMAModel: ...@@ -278,24 +278,19 @@ class EMAModel:
self.decay = decay self.decay = decay
self.optimization_step = 0 self.optimization_step = 0
def get_decay(self, optimization_step):
    """Return the complement of the effective EMA decay at this step.

    A warm-up schedule is applied: early in training the effective decay
    is dominated by ``(1 + step) / (10 + step)`` and only later saturates
    at the configured ``self.decay``.

    NOTE(review): despite the name, this returns ``1 - decay``, not the
    decay itself, and the caller assigns the result back into
    ``self.decay`` — which corrupts the ``min`` on subsequent calls.
    Confirm intent before reusing this method elsewhere.
    """
    warmup = (1 + optimization_step) / (10 + optimization_step)
    effective_decay = min(self.decay, warmup)
    return 1 - effective_decay
@torch.no_grad()
def step(self, parameters):
    """Advance the EMA by one optimization step.

    Moves each shadow parameter toward the corresponding live parameter
    using an exponential moving average with a warm-up schedule.

    Parameters
    ----------
    parameters : iterable of torch.Tensor
        The live model parameters, in the same order as
        ``self.shadow_params``.
    """
    parameters = list(parameters)
    self.optimization_step += 1

    # Warm-up schedule: the effective decay ramps from ~1/10 toward the
    # configured self.decay as optimization_step grows.  self.decay itself
    # is deliberately left untouched (storing the result back into
    # self.decay was the bug this block fixes).
    warmup = (1 + self.optimization_step) / (10 + self.optimization_step)
    one_minus_decay = 1 - min(self.decay, warmup)

    for s_param, param in zip(self.shadow_params, parameters):
        if param.requires_grad:
            # EMA update: s <- s - (1 - decay) * (s - p)
            # equivalently s <- decay * s + (1 - decay) * p
            s_param.sub_(one_minus_decay * (s_param - param))
        else:
            # Non-trainable tensors are mirrored verbatim.
            s_param.copy_(param)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment