Commit 25112f5e authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 376032565
parent e42ee600
...@@ -180,11 +180,15 @@ class EMAConfig(BaseOptimizerConfig): ...@@ -180,11 +180,15 @@ class EMAConfig(BaseOptimizerConfig):
  Attributes:
    name: 'str', name of the optimizer.
    trainable_weights_only: 'bool', if True, only model trainable weights will
      be updated. Otherwise, all model weights will be updated. This mainly
      affects batch normalization parameters.
    average_decay: 'float', average decay value.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to apply dynamic decay or not.
  """
name: str = "ExponentialMovingAverage" name: str = "ExponentialMovingAverage"
trainable_weights_only: bool = True
average_decay: float = 0.99 average_decay: float = 0.99
start_step: int = 0 start_step: int = 0
dynamic_decay: bool = True dynamic_decay: bool = True
......
...@@ -48,6 +48,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer): ...@@ -48,6 +48,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
def __init__(self, def __init__(self,
optimizer: tf.keras.optimizers.Optimizer, optimizer: tf.keras.optimizers.Optimizer,
trainable_weights_only: bool = True,
average_decay: float = 0.99, average_decay: float = 0.99,
start_step: int = 0, start_step: int = 0,
dynamic_decay: bool = True, dynamic_decay: bool = True,
...@@ -58,6 +59,9 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer): ...@@ -58,6 +59,9 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
    Args:
      optimizer: `tf.keras.optimizers.Optimizer` that will be
        used to compute and apply gradients.
      trainable_weights_only: 'bool', if True, only model trainable weights
        will be updated. Otherwise, all model weights will be updated. This
        mainly affects batch normalization parameters.
      average_decay: float. Decay to use to maintain the moving averages
        of trained variables.
      start_step: int. What step to start the moving average.
...@@ -72,6 +76,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer): ...@@ -72,6 +76,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
""" """
super().__init__(name, **kwargs) super().__init__(name, **kwargs)
self._average_decay = average_decay self._average_decay = average_decay
self._trainable_weights_only = trainable_weights_only
self._start_step = tf.constant(start_step, tf.float32) self._start_step = tf.constant(start_step, tf.float32)
self._dynamic_decay = dynamic_decay self._dynamic_decay = dynamic_decay
self._optimizer = optimizer self._optimizer = optimizer
...@@ -81,12 +86,17 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer): ...@@ -81,12 +86,17 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
def shadow_copy(self, model: tf.keras.Model):
  """Creates shadow variables for the given model weights.

  One 'average' slot variable (zero-initialized) is created per tracked
  model weight; these slots hold the exponential moving averages.

  Args:
    model: `tf.keras.Model` whose weights will be shadowed. Which weights
      are tracked depends on `trainable_weights_only` passed at
      construction time.
  """
  # Restrict EMA to trainable variables when configured; otherwise shadow
  # every model variable (this mainly affects batch normalization
  # parameters such as moving mean/variance).
  if self._trainable_weights_only:
    self._model_weights = model.trainable_variables
  else:
    self._model_weights = model.variables
  for var in self._model_weights:
    self.add_slot(var, 'average', initializer='zeros')
  self._average_weights = [
      self.get_slot(var, 'average') for var in self._model_weights
  ]
@property @property
def has_shadow_copy(self): def has_shadow_copy(self):
......
...@@ -80,6 +80,7 @@ trainer: ...@@ -80,6 +80,7 @@ trainer:
optimizer_config: optimizer_config:
ema: ema:
average_decay: 0.9999 average_decay: 0.9999
trainable_weights_only: false
learning_rate: learning_rate:
cosine: cosine:
decay_steps: 73682 decay_steps: 73682
......
...@@ -227,7 +227,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig: ...@@ -227,7 +227,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
} }
}, },
'ema': { 'ema': {
'average_decay': 0.9999 'average_decay': 0.9999,
'trainable_weights_only': False,
}, },
'learning_rate': { 'learning_rate': {
'type': 'cosine', 'type': 'cosine',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment