Commit 4a09d217 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 373725548
parent e04d2efc
...@@ -210,7 +210,8 @@ class Trainer(_AsyncTrainer): ...@@ -210,7 +210,8 @@ class Trainer(_AsyncTrainer):
self._runtime_options = get_runtime_options(config) self._runtime_options = get_runtime_options(config)
# Creates a shadow copy of the weights to store weights moving average. # Creates a shadow copy of the weights to store weights moving average.
if isinstance(self._optimizer, optimization.ExponentialMovingAverage): if isinstance(self._optimizer, optimization.ExponentialMovingAverage
) and not self._optimizer.has_shadow_copy:
self._optimizer.shadow_copy(self._model) self._optimizer.shadow_copy(self._model)
# global_step increases by 1 after each training iteration. # global_step increases by 1 after each training iteration.
......
...@@ -76,6 +76,8 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer): ...@@ -76,6 +76,8 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
self._dynamic_decay = dynamic_decay self._dynamic_decay = dynamic_decay
self._optimizer = optimizer self._optimizer = optimizer
self._track_trackable(self._optimizer, 'base_optimizer') self._track_trackable(self._optimizer, 'base_optimizer')
self._average_weights = None
self._model_weights = None
def shadow_copy(self, model: tf.keras.Model): def shadow_copy(self, model: tf.keras.Model):
"""Creates shadow variables for the given model weights.""" """Creates shadow variables for the given model weights."""
...@@ -89,7 +91,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer): ...@@ -89,7 +91,7 @@ class ExponentialMovingAverage(tf.keras.optimizers.Optimizer):
@property @property
def has_shadow_copy(self): def has_shadow_copy(self):
"""Whether this optimizer has created shadow variables.""" """Whether this optimizer has created shadow variables."""
return self._model_weights is not None return self._model_weights is not None and self._average_weights is not None
def _create_slots(self, var_list): def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment