Commit 1dcff6aa authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update EMA

parent 12aa565e
...@@ -3,6 +3,8 @@ import copy ...@@ -3,6 +3,8 @@ import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.utils.tensor_utils import tensor_tree_map
class ExponentialMovingAverage: class ExponentialMovingAverage:
""" """
...@@ -27,8 +29,14 @@ class ExponentialMovingAverage: ...@@ -27,8 +29,14 @@ class ExponentialMovingAverage:
""" """
super(ExponentialMovingAverage, self).__init__() super(ExponentialMovingAverage, self).__init__()
self.params = copy.deepcopy(model.state_dict()) clone_param = lambda t: t.clone().detach()
self.params = tensor_tree_map(clone_param, model.state_dict())
self.decay = decay self.decay = decay
self.device = next(model.parameters()).device
def to(self, device):
self.params = tensor_tree_map(lambda t: t.to(device), self.params)
self.device = device
def _update_state_dict_(self, update, state_dict): def _update_state_dict_(self, update, state_dict):
with torch.no_grad(): with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment