"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "da38e96abfacc93f6e2fb0d7b9a141ab03435b9c"
Commit dd06b323 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add barebones EMA

parent ff969b98
from collections import OrderedDict
import copy
import torch
import torch.nn as nn
class ExponentialMovingAverage:
    """
    Maintains moving averages of parameters with exponential decay.

    At each step, the stored copy ``copy`` of each parameter ``param`` is
    updated as follows:

        ``copy = decay * copy + (1 - decay) * param``

    where ``decay`` is an attribute of the ExponentialMovingAverage object.
    Non-floating-point entries of the state dict (e.g. integer buffers such
    as BatchNorm's ``num_batches_tracked``) cannot be averaged with a float
    factor, so they track the latest value instead.
    """

    def __init__(self, model: nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose parameters are to be tracked
            decay:
                A value (usually close to 1.) by which updates are
                weighted as part of the above formula

        Raises:
            ValueError: If ``decay`` is outside ``[0, 1]``.
        """
        super(ExponentialMovingAverage, self).__init__()

        # A decay outside [0, 1] silently produces diverging averages;
        # fail loudly instead.
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay}")

        # Deep-copy so the shadow weights do not alias the live model
        # tensors (state_dict() returns references to the same storage).
        self.params = copy.deepcopy(model.state_dict())
        self.decay = decay

    def _update_state_dict_(self, update, state_dict):
        """
        Recursively folds ``update`` into ``state_dict`` in place using the
        EMA formula above. Non-tensor values are assumed to be nested
        dict-like containers and are recursed into.
        """
        # no_grad so the in-place arithmetic never records autograd history,
        # regardless of how the incoming state dict was produced.
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if not isinstance(v, torch.Tensor):
                    # Nested mapping: recurse into the matching subtree.
                    self._update_state_dict_(v, stored)
                elif not torch.is_floating_point(stored):
                    # Integer buffers (e.g. BatchNorm's num_batches_tracked)
                    # can't take an in-place float update — the original
                    # `stored -= float_tensor` raised a dtype-cast
                    # RuntimeError here. Track the latest value instead.
                    stored.copy_(v)
                else:
                    # stored = decay * stored + (1 - decay) * v, in place.
                    diff = stored - v
                    diff *= 1 - self.decay
                    stored -= diff

    def update(self, model: torch.nn.Module) -> None:
        """
        Updates the stored parameters using the state dict of the provided
        module. The module should have the same structure as that used to
        initialize the ExponentialMovingAverage object.
        """
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Restores EMA state previously produced by ``state_dict()``."""
        self.params = state_dict["params"]
        self.decay = state_dict["decay"]

    def state_dict(self) -> OrderedDict:
        """Returns a serializable snapshot of the EMA state."""
        return OrderedDict({
            "params": self.params,
            "decay": self.decay,
        })
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment