Unverified Commit f40c8df0 authored by Hu Ye, committed by GitHub

Simplify EMA to use PyTorch's update_parameters (#5469)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 5568744c
```diff
@@ -166,17 +166,7 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
         def ema_avg(avg_model_param, model_param, num_averaged):
             return decay * avg_model_param + (1 - decay) * model_param
 
-        super().__init__(model, device, ema_avg)
+        super().__init__(model, device, ema_avg, use_buffers=True)
-
-    def update_parameters(self, model):
-        for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
-            device = p_swa.device
-            p_model_ = p_model.detach().to(device)
-            if self.n_averaged == 0:
-                p_swa.detach().copy_(p_model_)
-            else:
-                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device)))
-            self.n_averaged += 1
 
 
 def accuracy(output, target, topk=(1,)):
```
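For context, here is a minimal sketch of how the simplified class can be used after this change. The constructor body matches the diff above; the `__init__` signature, the `resnet18` model, and the training-loop scaffolding are illustrative assumptions, not part of this commit. With the hand-rolled loop removed, `torch.optim.swa_utils.AveragedModel.update_parameters` performs the running average, and `use_buffers=True` makes it average buffers (e.g. BatchNorm running statistics) alongside parameters.

```python
import torch
from torchvision.models import resnet18


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    # Signature assumed from context; the body matches the diff above.
    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        # use_buffers=True averages buffers (e.g. BatchNorm running stats)
        # in addition to parameters, which the removed loop did via state_dict().
        super().__init__(model, device, ema_avg, use_buffers=True)


# Illustrative training step; model, optimizer, and data are placeholders.
model = resnet18()
ema_model = ExponentialMovingAverage(model, decay=0.999)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()

images = torch.randn(4, 3, 224, 224)
targets = torch.randint(0, 1000, (4,))

optimizer.zero_grad()
criterion(model(images), targets).backward()
optimizer.step()

# One call replaces the removed hand-rolled loop: the base class copies the
# model's state on the first call and applies ema_avg on subsequent calls.
ema_model.update_parameters(model)
```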