Commit 1c9e0e83 authored by Ji Hou, committed by Facebook GitHub Bot

add warm up stage for d2go ema (for fsdp)

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/602

Per title: add a warm-up stage for the EMA decay in d2go (for FSDP). The effective decay ramps up from near zero toward MODEL_EMA.DECAY over the first updates, controlled by the new MODEL_EMA.DECAY_WARM_UP_FACTOR config.

Reviewed By: wat3rBro

Differential Revision: D47740831

fbshipit-source-id: ecbe48a1085232a5cfb696e7f8e537d7e58e534a
parent 0940b814
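For context, the warm-up added below uses the classic min(decay, (1 + n) / (k + n)) schedule, where n is the number of EMA updates so far and k is the new MODEL_EMA.DECAY_WARM_UP_FACTOR (TensorFlow's ExponentialMovingAverage applies the same shape with k fixed at 10). A minimal standalone sketch of the effective decay; the function name and example values are illustrative, not part of the commit:

def effective_decay(decay, warm_up_factor, num_updates):
    # Mirrors the logic added to EMAUpdater.update() below:
    # warm-up is active only when warm_up_factor >= 0.
    if warm_up_factor < 0:
        return decay
    return min(decay, (1 + num_updates) / (warm_up_factor + num_updates))

# With decay=0.99 and warm_up_factor=100 (illustrative values):
#   step 1     -> 2 / 101   ~ 0.0198  (EMA mostly copies the model)
#   step 100   -> 101 / 200 = 0.505
#   step 10000 -> capped at 0.99      (plain EMA from here on)
for step in (1, 100, 10000):
    print(step, effective_decay(0.99, 100, step))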
@@ -131,6 +131,7 @@ class EMAUpdater(object):
         decay: float = 0.999,
         device: str = "",
         use_lerp: bool = False,
+        decay_warm_up_factor: int = -1,
     ):
         self.decay = decay
         self.device = device
@@ -139,11 +140,27 @@
         self.use_lerp = use_lerp
         self.debug_lerp = False
         self._num_updates: int = -1
+        self.decay_warm_up_factor = decay_warm_up_factor
+        if self.decay_warm_up_factor >= 0:
+            self._num_updates = 0

     def init_state(self, model):
         self.state.clear()
         self.state.save_from(model, self.device)

     def update(self, model):
+        # compute decay
+        decay = self.decay
+        if self._num_updates >= 0:
+            self._num_updates += 1
+            decay = min(
+                self.decay,
+                (1 + self._num_updates)
+                / (self.decay_warm_up_factor + self._num_updates),
+            )
+
         # update moving average
         with torch.no_grad():
             ema_param_list = []
             param_list = []
@@ -155,8 +172,8 @@ class EMAUpdater(object):
                     ema_param_list.append(ema_val)
                     param_list.append(val)
                 else:
-                    ema_val.copy_(ema_val * self.decay + val * (1.0 - self.decay))
-            self._ema_avg(ema_param_list, param_list, self.decay)
+                    ema_val.copy_(ema_val * decay + val * (1.0 - decay))
+            self._ema_avg(ema_param_list, param_list, decay)

     def _ema_avg(
         self,
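A side note on the changed update line: the element-wise formula ema_val * decay + val * (1.0 - decay) is algebraically val + decay * (ema_val - val), i.e. a linear interpolation, which is why the use_lerp fast path can produce the same result. A quick sanity check, illustrative and not part of the commit:

import torch

ema = torch.randn(4)
val = torch.randn(4)
decay = 0.9

# torch.lerp(start, end, weight) = start + weight * (end - start)
expected = ema * decay + val * (1.0 - decay)
assert torch.allclose(torch.lerp(val, ema, decay), expected)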
@@ -208,6 +225,8 @@ def add_model_ema_configs(_C):
     _C.MODEL_EMA.USE_LERP = False
     # Whether to put EMA to the backward pass
     _C.MODEL_EMA.AFTER_BACKWARD = False
+    # Warm-up factor for the EMA update process (-1 disables warm-up)
+    _C.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1


 def _remove_ddp(model):
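The new key defaults to -1, which disables warm-up and preserves the previous constant-decay behavior, as the updated tests below confirm. Enabling it from a d2go config might look like this; the value 2000 is an arbitrary illustration:

cfg.MODEL_EMA.ENABLED = True
cfg.MODEL_EMA.DECAY = 0.9999
# Any value >= 0 turns the warm-up on; 2000 is a hypothetical choice.
cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = 2000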
@@ -308,6 +327,7 @@ class EMAHook(HookBase):
             decay=cfg.MODEL_EMA.DECAY,
             device=self.device,
             use_lerp=cfg.MODEL_EMA.USE_LERP,
+            decay_warm_up_factor=cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR,
         )

     def before_train(self):
@@ -185,6 +185,7 @@ class TestModelingModelEMAHook(unittest.TestCase):
         cfg.MODEL_EMA.ENABLED = True
         # use new model weights
         cfg.MODEL_EMA.DECAY = 0.0
+        cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1
         model = TestArch()
         ema.may_build_model_ema(cfg, model)
@@ -182,6 +182,7 @@ class TestDefaultRunner(unittest.TestCase):
         cfg.MODEL.META_ARCHITECTURE = "MetaArchForTestSingleValue"
         cfg.MODEL_EMA.ENABLED = True
         cfg.MODEL_EMA.DECAY = 0.9
+        cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1

     def _run_train(cfg):
         cfg = copy.deepcopy(cfg)
@@ -118,6 +118,7 @@ class TestActivationCheckpointing(unittest.TestCase):
         cfg.MODEL.MODELING_HOOKS = ["ActivationCheckpointModelingHook"]
         cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY = "layer_based_auto_wrap_policy"
         cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_LAYER_CLS = ["Conv2d", "BatchNorm2d"]
+        cfg.MODEL_EMA.DECAY_WARM_UP_FACTOR = -1
         model = runner.build_model(cfg)
         runner.do_train(cfg, model, resume=False)