Commit 25112f5e authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 376032565
parent e42ee600
......@@ -180,11 +180,15 @@ class EMAConfig(BaseOptimizerConfig):
Attributes:
name: 'str', name of the optimizer.
trainable_weights_only: 'bool', if True, only model trainable weights will
be updated. Otherwise, all model weights will be updated. This mainly
affects batch normalization parameters.
average_decay: 'float', average decay value.
start_step: 'int', start step to apply moving average.
dynamic_decay: 'bool', whether to apply dynamic decay or not.
"""
name: str = "ExponentialMovingAverage"
trainable_weights_only: bool = True
average_decay: float = 0.99
start_step: int = 0
dynamic_decay: bool = True
......
......@@ -48,6 +48,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
def __init__(self,
optimizer: tf.keras.optimizers.Optimizer,
trainable_weights_only: bool = True,
average_decay: float = 0.99,
start_step: int = 0,
dynamic_decay: bool = True,
......@@ -58,6 +59,9 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
Args:
optimizer: `tf.keras.optimizers.Optimizer` that will be
used to compute and apply gradients.
trainable_weights_only: 'bool', if True, only model trainable weights will
be updated. Otherwise, all model weights will be updated. This mainly
affects batch normalization parameters.
average_decay: float. Decay to use to maintain the moving averages
of trained variables.
start_step: int. What step to start the moving average.
......@@ -72,6 +76,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
"""
super().__init__(name, **kwargs)
self._average_decay = average_decay
self._trainable_weights_only = trainable_weights_only
self._start_step = tf.constant(start_step, tf.float32)
self._dynamic_decay = dynamic_decay
self._optimizer = optimizer
......@@ -81,12 +86,17 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
def shadow_copy(self, model: tf.keras.Model):
"""Creates shadow variables for the given model weights."""
for var in model.weights:
if self._trainable_weights_only:
self._model_weights = model.trainable_variables
else:
self._model_weights = model.variables
for var in self._model_weights:
self.add_slot(var, 'average', initializer='zeros')
self._average_weights = [
self.get_slot(var, 'average') for var in model.weights
self.get_slot(var, 'average') for var in self._model_weights
]
self._model_weights = model.weights
@property
def has_shadow_copy(self):
......
......@@ -80,6 +80,7 @@ trainer:
optimizer_config:
ema:
average_decay: 0.9999
trainable_weights_only: false
learning_rate:
cosine:
decay_steps: 73682
......
......@@ -227,7 +227,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
}
},
'ema': {
'average_decay': 0.9999
'average_decay': 0.9999,
'trainable_weights_only': False,
},
'learning_rate': {
'type': 'cosine',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment