".github/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0"
Commit d032c02c authored by Anthony Chen's avatar Anthony Chen Committed by Facebook GitHub Bot
Browse files

add options to exclude buffers and frozen parameters in EMA

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/530

Add options to include/exclude model buffers and frozen parameters in EMA state via two new config keys `MODEL_EMA.INCLUDE_FROZEN` and `MODEL_EMA.INCLUDE_BUFFER`

Reviewed By: tglik

Differential Revision: D45129625

fbshipit-source-id: 895ebe7e4f8e15566c3c3bddd852dd98c40a27b1
parent 3639b43c
...@@ -16,12 +16,14 @@ logger = logging.getLogger(__name__) ...@@ -16,12 +16,14 @@ logger = logging.getLogger(__name__)
class EMAState(object): class EMAState(object):
def __init__(self): def __init__(self, include_frozen=True, include_buffer=True):
self.include_frozen = include_frozen
self.include_buffer = include_buffer
self.state = {} self.state = {}
@classmethod
def FromModel(cls, model: torch.nn.Module, device: str = "", **kwargs):
    """Build an EMA state snapshotting the model's current weights.

    Extra keyword arguments (e.g. include_frozen / include_buffer) are
    forwarded to the constructor.
    """
    ema_state = cls(**kwargs)
    ema_state.save_from(model, device)
    return ema_state
...@@ -72,8 +74,12 @@ class EMAState(object): ...@@ -72,8 +74,12 @@ class EMAState(object):
def get_model_state_iterator(self, model):
    """Yield (name, tensor) pairs of the model state tracked by EMA.

    Named parameters are always iterated; frozen (requires_grad=False)
    parameters are filtered out when self.include_frozen is False, and
    named buffers are appended only when self.include_buffer is True.
    """
    param_iter = model.named_parameters()
    if not self.include_frozen:
        # A generator expression is already an iterator; the original's
        # extra iter(...) wrapper was redundant.
        param_iter = ((name, p) for name, p in param_iter if p.requires_grad)
    if self.include_buffer:
        return itertools.chain(param_iter, model.named_buffers())
    return param_iter
def state_dict(self):
    """Return the internal name -> tensor mapping of this EMA state."""
    # The dict itself is returned (no copy), so callers share the object.
    mapping = self.state
    return mapping
...@@ -179,6 +185,10 @@ def add_model_ema_configs(_C): ...@@ -179,6 +185,10 @@ def add_model_ema_configs(_C):
_C.MODEL_EMA = type(_C)() _C.MODEL_EMA = type(_C)()
_C.MODEL_EMA.ENABLED = False _C.MODEL_EMA.ENABLED = False
_C.MODEL_EMA.DECAY = 0.999 _C.MODEL_EMA.DECAY = 0.999
# Whether to include frozen parameters in EMA
_C.MODEL_EMA.INCLUDE_FROZEN = True
# Whether to include model buffers in EMA
_C.MODEL_EMA.INCLUDE_BUFFER = True
# use the same as MODEL.DEVICE when empty # use the same as MODEL.DEVICE when empty
_C.MODEL_EMA.DEVICE = "" _C.MODEL_EMA.DEVICE = ""
# When True, loading the ema weight to the model when eval_only=True in build_model() # When True, loading the ema weight to the model when eval_only=True in build_model()
...@@ -204,7 +214,10 @@ def may_build_model_ema(cfg, model): ...@@ -204,7 +214,10 @@ def may_build_model_ema(cfg, model):
assert not hasattr( assert not hasattr(
model, "ema_state" model, "ema_state"
), "Name `ema_state` is reserved for model ema." ), "Name `ema_state` is reserved for model ema."
model.ema_state = EMAState() model.ema_state = EMAState(
include_frozen=cfg.MODEL_EMA.INCLUDE_FROZEN,
include_buffer=cfg.MODEL_EMA.INCLUDE_BUFFER,
)
logger.info("Using Model EMA.") logger.info("Using Model EMA.")
......
...@@ -61,7 +61,17 @@ class TestModelingModelEMA(unittest.TestCase): ...@@ -61,7 +61,17 @@ class TestModelingModelEMA(unittest.TestCase):
state = ema.EMAState.FromModel(model) state = ema.EMAState.FromModel(model)
# two for conv (conv.weight, conv.bias), # two for conv (conv.weight, conv.bias),
# five for bn (bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.num_batches_tracked) # five for bn (bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.num_batches_tracked)
full_state = {
"conv.weight",
"conv.bias",
"bn.weight",
"bn.bias",
"bn.running_mean",
"bn.running_var",
"bn.num_batches_tracked",
}
self.assertEqual(len(state.state), 7) self.assertEqual(len(state.state), 7)
self.assertTrue(set(state.state) == full_state)
for _, val in state.state.items(): for _, val in state.state.items():
self.assertFalse(val.requires_grad) self.assertFalse(val.requires_grad)
...@@ -72,6 +82,25 @@ class TestModelingModelEMA(unittest.TestCase): ...@@ -72,6 +82,25 @@ class TestModelingModelEMA(unittest.TestCase):
state.apply_to(model1) state.apply_to(model1)
self.assertTrue(_compare_state_dict(model, model1)) self.assertTrue(_compare_state_dict(model, model1))
# test ema state that excludes buffers and frozen parameters
model.conv.weight.requires_grad = False
state1 = ema.EMAState.FromModel(model, include_frozen=False)
# should exclude frozen parameter: conv.weight
self.assertTrue(full_state - set(state1.state) == {"conv.weight"})
state2 = ema.EMAState.FromModel(model, include_buffer=False)
# should exclude buffers: bn.running_mean, bn.running_var, bn.num_batches_tracked
self.assertTrue(
full_state - set(state2.state)
== {"bn.running_mean", "bn.running_var", "bn.num_batches_tracked"}
)
state3 = ema.EMAState.FromModel(
model, include_frozen=False, include_buffer=False
)
# should exclude frozen param + buffers: conv.weight, bn.running_mean, bn.running_var, bn.num_batches_tracked
self.assertTrue(set(state3.state) == {"conv.bias", "bn.weight", "bn.bias"})
def test_emastate_saveload(self): def test_emastate_saveload(self):
model = TestArch() model = TestArch()
state = ema.EMAState.FromModel(model) state = ema.EMAState.FromModel(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment