Commit 1c9e0e83 authored by Ji Hou's avatar Ji Hou Committed by Facebook GitHub Bot
Browse files

add warm up stage for d2go ema (for fsdp)

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/602

per title

Reviewed By: wat3rBro

Differential Revision: D47740831

fbshipit-source-id: ecbe48a1085232a5cfb696e7f8e537d7e58e534a
parent 0940b814
...@@ -131,6 +131,7 @@ class EMAUpdater(object): ...@@ -131,6 +131,7 @@ class EMAUpdater(object):
decay: float = 0.999, decay: float = 0.999,
device: str = "", device: str = "",
use_lerp: bool = False, use_lerp: bool = False,
decay_warm_up_factor: int = -1,
): ):
self.decay = decay self.decay = decay
self.device = device self.device = device
...@@ -139,11 +140,27 @@ class EMAUpdater(object): ...@@ -139,11 +140,27 @@ class EMAUpdater(object):
self.use_lerp = use_lerp self.use_lerp = use_lerp
self.debug_lerp = False self.debug_lerp = False
self._num_updates: int = -1
self.decay_warm_up_factor = decay_warm_up_factor
if self.decay_warm_up_factor >= 0:
self._num_updates = 0
def init_state(self, model): def init_state(self, model):
self.state.clear() self.state.clear()
self.state.save_from(model, self.device) self.state.save_from(model, self.device)
def update(self, model): def update(self, model):
# compute decay
decay = self.decay
if self._num_updates >= 0:
self._num_updates += 1
decay = min(
self.decay,
(1 + self._num_updates)
/ (self.decay_warm_up_factor + self._num_updates),
)
# update moving average
with torch.no_grad(): with torch.no_grad():
ema_param_list = [] ema_param_list = []
param_list = [] param_list = []
...@@ -155,8 +172,8 @@ class EMAUpdater(object): ...@@ -155,8 +172,8 @@ class EMAUpdater(object):
ema_param_list.append(ema_val) ema_param_list.append(ema_val)
param_list.append(val) param_list.append(val)
else: else:
ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay)) ema_val.copy_(ema_val * decay + val * (1.0 - decay))
self._ema_avg(ema_param_list, param_list, self.decay) self._ema_avg(ema_param_list, param_list, decay)
def _ema_avg( def _ema_avg(
self, self,
...@@ -208,6 +225,8 @@ def add_model_ema_configs(_C): ...@@ -208,6 +225,8 @@ def add_model_ema_configs(_C):
_C.MODEL_EMA.USE_LERP = False _C.MODEL_EMA.USE_LERP = False
# Whether to put EMA to the backward pass # Whether to put EMA to the backward pass
_C.MODEL_EMA.AFTER_BACKWARD = False _C.MODEL_EMA.AFTER_BACKWARD = False
# Whether to warmup the EMA update process
_C.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1
def _remove_ddp(model): def _remove_ddp(model):
...@@ -308,6 +327,7 @@ class EMAHook(HookBase): ...@@ -308,6 +327,7 @@ class EMAHook(HookBase):
decay=cfg.MODEL_EMA.DECAY, decay=cfg.MODEL_EMA.DECAY,
device=self.device, device=self.device,
use_lerp=cfg.MODEL_EMA.USE_LERP, use_lerp=cfg.MODEL_EMA.USE_LERP,
decay_warm_up_factor=cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR,
) )
def before_train(self): def before_train(self):
......
...@@ -185,6 +185,7 @@ class TestModelingModelEMAHook(unittest.TestCase): ...@@ -185,6 +185,7 @@ class TestModelingModelEMAHook(unittest.TestCase):
cfg.MODEL_EMA.ENABLED = True cfg.MODEL_EMA.ENABLED = True
# use new model weights # use new model weights
cfg.MODEL_EMA.DECAY = 0.0 cfg.MODEL_EMA.DECAY = 0.0
cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1
model = TestArch() model = TestArch()
ema.may_build_model_ema(cfg, model) ema.may_build_model_ema(cfg, model)
......
...@@ -182,6 +182,7 @@ class TestDefaultRunner(unittest.TestCase): ...@@ -182,6 +182,7 @@ class TestDefaultRunner(unittest.TestCase):
cfg.MODEL.META_ARCHITECTURE = "MetaArchForTestSingleValue" cfg.MODEL.META_ARCHITECTURE = "MetaArchForTestSingleValue"
cfg.MODEL_EMA.ENABLED = True cfg.MODEL_EMA.ENABLED = True
cfg.MODEL_EMA.DECAY = 0.9 cfg.MODEL_EMA.DECAY = 0.9
cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1
def _run_train(cfg): def _run_train(cfg):
cfg = copy.deepcopy(cfg) cfg = copy.deepcopy(cfg)
......
...@@ -118,6 +118,7 @@ class TestActivationCheckpointing(unittest.TestCase): ...@@ -118,6 +118,7 @@ class TestActivationCheckpointing(unittest.TestCase):
cfg.MODEL.MODELING_HOOKS = ["ActivationCheckpointModelingHook"] cfg.MODEL.MODELING_HOOKS = ["ActivationCheckpointModelingHook"]
cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy" cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy"
cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Conv2d", "BatchNorm2d"] cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Conv2d", "BatchNorm2d"]
cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1
model = runner.build_model(cfg) model = runner.build_model(cfg)
runner.do_train(cfg, model, resume=False) runner.do_train(cfg, model, resume=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment