Unverified Commit c1f1e226 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Added update_parameters to EMA to fix calculation (#4406)

parent 9fa689b2
......@@ -172,6 +172,17 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
decay * avg_model_param + (1 - decay) * model_param)
super().__init__(model, device, ema_avg)
def update_parameters(self, model):
for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
device = p_swa.device
p_model_ = p_model.detach().to(device)
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
self.n_averaged.to(device)))
self.n_averaged += 1
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment