Commit 255313d8 authored by Fei Sun, committed by Facebook GitHub Bot

Use LERP to implement EMA

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/493

Currently the EMA implementation does the multiplication first and then the addition, which requires two round trips to HBM. With the lerp operator, a single kernel can do both. This change uses LERP to compute the EMA instead, reducing GPU EMA computation time by 40%.
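For illustration only (not part of this commit), a minimal per-tensor sketch of the equivalence the summary relies on; the tensor sizes, names, and tolerances below are arbitrary:

```python
import torch

decay = 0.999
x_avg = torch.randn(1024)  # running EMA buffer
x_t = torch.randn(1024)    # current model parameter

# Two-op update: x_avg = decay * x_avg + (1 - decay) * x_t
# (two elementwise kernels, two passes over x_avg in memory)
expected = x_avg.clone()
expected.mul_(decay).add_(x_t, alpha=1 - decay)

# Fused update: lerp(x_avg, x_t, 1 - decay) = decay * x_avg + (1 - decay) * x_t,
# computed by a single kernel in one pass.
fused = x_avg.clone()
fused.lerp_(x_t, 1 - decay)

assert torch.allclose(expected, fused, rtol=1e-4, atol=1e-3)
```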

Reviewed By: newstzpz

Differential Revision: D43525938

fbshipit-source-id: ca1e14453bdfda958d3c412a52ff48efa65b3dd4
parent fd0cbb8f
@@ -108,11 +108,19 @@ class EMAUpdater(object):
         Also, bn sync should be switched on for EMA.
     """
 
-    def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
+    def __init__(
+        self,
+        state: EMAState,
+        decay: float = 0.999,
+        device: str = "",
+        use_lerp: bool = False,
+    ):
         self.decay = decay
         self.device = device
 
         self.state = state
+        self.use_lerp = use_lerp
+        self.debug_lerp = False
 
     def init_state(self, model):
         self.state.clear()
@@ -143,10 +151,28 @@ class EMAUpdater(object):
         Function to perform exponential moving average:
         x_avg = alpha * x_avg + (1-alpha)* x_t
         """
-        torch._foreach_mul_(averaged_model_parameters, decay)
-        torch._foreach_add_(
-            averaged_model_parameters, model_parameters, alpha=1 - decay
-        )
+        if self.use_lerp:
+            if self.debug_lerp:
+                orig_averaged_model_parameters = torch._foreach_mul(
+                    averaged_model_parameters, decay
+                )
+                torch._foreach_add_(
+                    orig_averaged_model_parameters, model_parameters, alpha=1 - decay
+                )
+
+            torch._foreach_lerp_(
+                averaged_model_parameters, model_parameters, 1.0 - decay
+            )
+            if self.debug_lerp:
+                for (orig_val, lerp_val) in zip(
+                    orig_averaged_model_parameters, averaged_model_parameters
+                ):
+                    assert torch.allclose(orig_val, lerp_val, rtol=1e-4, atol=1e-3)
+        else:
+            torch._foreach_mul_(averaged_model_parameters, decay)
+            torch._foreach_add_(
+                averaged_model_parameters, model_parameters, alpha=1 - decay
+            )
 
 
 def add_model_ema_configs(_C):
@@ -157,6 +183,8 @@ def add_model_ema_configs(_C):
     _C.MODEL_EMA.DEVICE = ""
     # When True, loading the ema weight to the model when eval_only=True in build_model()
     _C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
+    # Whether to use LERP to compute EMA
+    _C.MODEL_EMA.USE_LERP = False
 
 
 def _remove_ddp(model):
@@ -239,7 +267,10 @@ class EMAHook(HookBase):
         self.ema = self.model.ema_state
         self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
         self.ema_updater = EMAUpdater(
-            self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device
+            self.model.ema_state,
+            decay=cfg.MODEL_EMA.DECAY,
+            device=self.device,
+            use_lerp=cfg.MODEL_EMA.USE_LERP,
         )
 
     def before_train(self):
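For context, a hedged usage sketch of the new flag; the `d2go.modeling.ema` module path and the no-argument `EMAState()` constructor are assumptions based on this diff, not verified against the repo:

```python
# Hedged usage sketch; module path and EMAState() signature are assumptions.
from d2go.modeling.ema import EMAState, EMAUpdater

# Construct the updater directly with the new flag:
updater = EMAUpdater(EMAState(), decay=0.999, device="cuda", use_lerp=True)

# Or, when training through EMAHook, enable it via the config option added above:
#   add_model_ema_configs(cfg)
#   cfg.MODEL_EMA.USE_LERP = True
```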