Commit 255313d8 authored by Fei Sun's avatar Fei Sun Committed by Facebook GitHub Bot
Browse files

Use LERP to implement EMA

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/493

Currently the EMA implementation first does the multiplication and then does the addition. It requires two round trips from HBM. With the lerp operator, one kernel can do both. This change uses LERP to compute EMA instead. It reduces the GPU EMA computation time by 40%.

Reviewed By: newstzpz

Differential Revision: D43525938

fbshipit-source-id: ca1e14453bdfda958d3c412a52ff48efa65b3dd4
parent fd0cbb8f
......@@ -108,11 +108,19 @@ class EMAUpdater(object):
Also, bn sync should be switched on for EMA.
"""
def __init__(self, state: EMAState, decay: float = 0.999, device: str = ""):
def __init__(
self,
state: EMAState,
decay: float = 0.999,
device: str = "",
use_lerp: bool = False,
):
self.decay = decay
self.device = device
self.state = state
self.use_lerp = use_lerp
self.debug_lerp = False
def init_state(self, model):
self.state.clear()
......@@ -143,6 +151,24 @@ class EMAUpdater(object):
Function to perform exponential moving average:
x_avg = alpha * x_avg + (1-alpha)* x_t
"""
if self.use_lerp:
if self.debug_lerp:
orig_averaged_model_parameters = torch._foreach_mul(
averaged_model_parameters, decay
)
torch._foreach_add_(
orig_averaged_model_parameters, model_parameters, alpha=1 - decay
)
torch._foreach_lerp_(
averaged_model_parameters, model_parameters, 1.0 - decay
)
if self.debug_lerp:
for (orig_val, lerp_val) in zip(
orig_averaged_model_parameters, averaged_model_parameters
):
assert torch.allclose(orig_val, lerp_val, rtol=1e-4, atol=1e-3)
else:
torch._foreach_mul_(averaged_model_parameters, decay)
torch._foreach_add_(
averaged_model_parameters, model_parameters, alpha=1 - decay
......@@ -157,6 +183,8 @@ def add_model_ema_configs(_C):
_C.MODEL_EMA.DEVICE = ""
# When True, loading the ema weight to the model when eval_only=True in build_model()
_C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
# Whether to use LERP to compute EMA
_C.MODEL_EMA.USE_LERP = False
def _remove_ddp(model):
......@@ -239,7 +267,10 @@ class EMAHook(HookBase):
self.ema = self.model.ema_state
self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
self.ema_updater = EMAUpdater(
self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device
self.model.ema_state,
decay=cfg.MODEL_EMA.DECAY,
device=self.device,
use_lerp=cfg.MODEL_EMA.USE_LERP,
)
def before_train(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment