Commit d032c02c authored by Anthony Chen, committed by Facebook GitHub Bot

add options to exclude buffers and frozen parameters in EMA

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/530

Add options to include/exclude model buffers and frozen parameters in the EMA state via two new config keys, `MODEL_EMA.INCLUDE_FROZEN` and `MODEL_EMA.INCLUDE_BUFFER`.
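For context, a minimal usage sketch (the toy model and the `from d2go.modeling import ema` import path are assumptions for illustration; the `FromModel` keyword arguments and config keys are the ones added in this diff):

import torch
from d2go.modeling import ema  # assumed import path for the EMAState module

# Toy module: the conv contributes parameters, the bn contributes buffers
# (running_mean, running_var, num_batches_tracked).
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3),
    torch.nn.BatchNorm2d(4),
)
model[0].weight.requires_grad = False  # freeze one parameter

# Direct use: the defaults (include_frozen=True, include_buffer=True)
# preserve the previous behavior of tracking all parameters and buffers.
full = ema.EMAState.FromModel(model)
no_frozen = ema.EMAState.FromModel(model, include_frozen=False)
slim = ema.EMAState.FromModel(model, include_frozen=False, include_buffer=False)

# Config-driven use: may_build_model_ema(cfg, model) reads these keys.
# cfg.MODEL_EMA.ENABLED = True
# cfg.MODEL_EMA.INCLUDE_FROZEN = False
# cfg.MODEL_EMA.INCLUDE_BUFFER = False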

Reviewed By: tglik

Differential Revision: D45129625

fbshipit-source-id: 895ebe7e4f8e15566c3c3bddd852dd98c40a27b1
parent 3639b43c
@@ -16,12 +16,14 @@ logger = logging.getLogger(__name__)

 class EMAState(object):
-    def __init__(self):
+    def __init__(self, include_frozen=True, include_buffer=True):
+        self.include_frozen = include_frozen
+        self.include_buffer = include_buffer
         self.state = {}

     @classmethod
-    def FromModel(cls, model: torch.nn.Module, device: str = ""):
-        ret = cls()
+    def FromModel(cls, model: torch.nn.Module, device: str = "", **kwargs):
+        ret = cls(**kwargs)
         ret.save_from(model, device)
         return ret
@@ -72,8 +74,12 @@ class EMAState(object):
     def get_model_state_iterator(self, model):
         param_iter = model.named_parameters()
-        buffer_iter = model.named_buffers()
-        return itertools.chain(param_iter, buffer_iter)
+        if not self.include_frozen:
+            param_iter = iter((n, v) for n, v in param_iter if v.requires_grad)
+        if self.include_buffer:
+            buffer_iter = model.named_buffers()
+            return itertools.chain(param_iter, buffer_iter)
+        return param_iter

     def state_dict(self):
         return self.state
@@ -179,6 +185,10 @@ def add_model_ema_configs(_C):
     _C.MODEL_EMA = type(_C)()
     _C.MODEL_EMA.ENABLED = False
     _C.MODEL_EMA.DECAY = 0.999
+    # Whether to include frozen parameters in EMA
+    _C.MODEL_EMA.INCLUDE_FROZEN = True
+    # Whether to include model buffers in EMA
+    _C.MODEL_EMA.INCLUDE_BUFFER = True
     # use the same as MODEL.DEVICE when empty
     _C.MODEL_EMA.DEVICE = ""
     # When True, loading the ema weight to the model when eval_only=True in build_model()
@@ -204,7 +214,10 @@ def may_build_model_ema(cfg, model):
     assert not hasattr(
         model, "ema_state"
     ), "Name `ema_state` is reserved for model ema."
-    model.ema_state = EMAState()
+    model.ema_state = EMAState(
+        include_frozen=cfg.MODEL_EMA.INCLUDE_FROZEN,
+        include_buffer=cfg.MODEL_EMA.INCLUDE_BUFFER,
+    )
     logger.info("Using Model EMA.")
@@ -61,7 +61,17 @@ class TestModelingModelEMA(unittest.TestCase):
         state = ema.EMAState.FromModel(model)
         # two for conv (conv.weight, conv.bias),
         # five for bn (bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.num_batches_tracked)
+        full_state = {
+            "conv.weight",
+            "conv.bias",
+            "bn.weight",
+            "bn.bias",
+            "bn.running_mean",
+            "bn.running_var",
+            "bn.num_batches_tracked",
+        }
         self.assertEqual(len(state.state), 7)
+        self.assertTrue(set(state.state) == full_state)
         for _, val in state.state.items():
             self.assertFalse(val.requires_grad)
@@ -72,6 +82,25 @@ class TestModelingModelEMA(unittest.TestCase):
         state.apply_to(model1)
         self.assertTrue(_compare_state_dict(model, model1))

+        # test ema state that excludes buffers and frozen parameters
+        model.conv.weight.requires_grad = False
+        state1 = ema.EMAState.FromModel(model, include_frozen=False)
+        # should exclude frozen parameter: conv.weight
+        self.assertTrue(full_state - set(state1.state) == {"conv.weight"})
+        state2 = ema.EMAState.FromModel(model, include_buffer=False)
+        # should exclude buffers: bn.running_mean, bn.running_var, bn.num_batches_tracked
+        self.assertTrue(
+            full_state - set(state2.state)
+            == {"bn.running_mean", "bn.running_var", "bn.num_batches_tracked"}
+        )
+        state3 = ema.EMAState.FromModel(
+            model, include_frozen=False, include_buffer=False
+        )
+        # should exclude frozen param + buffers: conv.weight, bn.running_mean, bn.running_var, bn.num_batches_tracked
+        self.assertTrue(set(state3.state) == {"conv.bias", "bn.weight", "bn.bias"})
+
     def test_emastate_saveload(self):
         model = TestArch()
         state = ema.EMAState.FromModel(model)