Commit a7dc757c authored by Fei Sun's avatar Fei Sun Committed by Facebook GitHub Bot
Browse files

Move EMA to after backward.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/494

Currently EMA computation is in the after step hook. It is in the critical path where no other work is available. This increases the training iteration time. This diff moves the EMA computation to after the backward but before the optimizer step. This way, the majority of the EMA computation time on the CPU can be hidden since CPU at that time is waiting for the GPU to finish the backward anyway. This change may completely hide the EMA CPU time. It reduces the EMA time from 20ms to 4ms, where the 4ms is the GPU time.

However, with this change, the EMA gets its value from the previous iteration value (since it is before step). but since we do many epochs of training, one iteration difference may not be significant.

Reviewed By: tglik

Differential Revision: D43527552

fbshipit-source-id: 1faa9d910b20cae0fc77da541bc0ad176bce18a8
parent 5f1ef548
...@@ -185,6 +185,8 @@ def add_model_ema_configs(_C): ...@@ -185,6 +185,8 @@ def add_model_ema_configs(_C):
_C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False _C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False
# Whether to use LERP to compute EMA # Whether to use LERP to compute EMA
_C.MODEL_EMA.USE_LERP = False _C.MODEL_EMA.USE_LERP = False
# Whether to put EMA to the backward pass
_C.MODEL_EMA.AFTER_BACKWARD = False
def _remove_ddp(model): def _remove_ddp(model):
...@@ -266,6 +268,7 @@ class EMAHook(HookBase): ...@@ -266,6 +268,7 @@ class EMAHook(HookBase):
self.model = model self.model = model
self.ema = self.model.ema_state self.ema = self.model.ema_state
self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE
self.is_after_backward = cfg.MODEL_EMA.AFTER_BACKWARD
self.ema_updater = EMAUpdater( self.ema_updater = EMAUpdater(
self.model.ema_state, self.model.ema_state,
decay=cfg.MODEL_EMA.DECAY, decay=cfg.MODEL_EMA.DECAY,
...@@ -285,7 +288,17 @@ class EMAHook(HookBase): ...@@ -285,7 +288,17 @@ class EMAHook(HookBase):
def before_step(self): def before_step(self):
pass pass
def after_backward(self):
if not self.is_after_backward:
return
self._update()
def after_step(self): def after_step(self):
if self.is_after_backward:
return
self._update()
def _update(self):
if not self.model.train: if not self.model.train:
return return
self.ema_updater.update(self.model) self.ema_updater.update(self.model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment