"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "391cfcd7d7e3df50ba30b3771c4347848ff0b2e1"
Commit eb184a78 authored by Francisc Bungiu's avatar Francisc Bungiu Committed by Facebook GitHub Bot
Browse files

Parallelize EMA optimizer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/451

Tracing d2go runners that use the AdamW optimizer showed many small operators being executed in the EMA code. These can be fused together by using the multi-tensor API.

Reviewed By: tglik

Differential Revision: D42098310

fbshipit-source-id: 544d7e214964530ec03674986827410b0f60951f
parent 554b6992
...@@ -6,6 +6,7 @@ import copy ...@@ -6,6 +6,7 @@ import copy
import itertools import itertools
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import List
import torch import torch
from detectron2.engine.train_loop import HookBase from detectron2.engine.train_loop import HookBase
...@@ -119,11 +120,33 @@ class EMAUpdater(object): ...@@ -119,11 +120,33 @@ class EMAUpdater(object):
def update(self, model): def update(self, model):
with torch.no_grad(): with torch.no_grad():
ema_param_list = []
param_list = []
for name, val in self.state.get_model_state_iterator(model): for name, val in self.state.get_model_state_iterator(model):
ema_val = self.state.state[name] ema_val = self.state.state[name]
if self.device: if self.device:
val = val.to(self.device) val = val.to(self.device)
ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay)) if val.dtype in [torch.float32, torch.float16]:
ema_param_list.append(ema_val)
param_list.append(val)
else:
ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))
self._ema_avg(ema_param_list, param_list, self.decay)
def _ema_avg(
self,
averaged_model_parameters: List[torch.Tensor],
model_parameters: List[torch.Tensor],
decay: float,
) -> None:
"""
Function to perform exponential moving average:
x_avg = alpha * x_avg + (1-alpha)* x_t
"""
torch._foreach_mul_(averaged_model_parameters, decay)
torch._foreach_add_(
averaged_model_parameters, model_parameters, alpha=1 - decay
)
def add_model_ema_configs(_C): def add_model_ema_configs(_C):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment