Commit 98839bd2 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Support EMA

- Create shadow weights at the beginning of training.
- Swap in the averaged weights for evaluation, and swap them back afterwards.
- Save the best checkpoint with the averaged weights (the mechanics are sketched below).
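For intuition, here is a minimal, self-contained sketch of those three steps using plain TensorFlow variables. The decay value and helper names are illustrative only and are not the actual implementation in optimization.ExponentialMovingAverage.

import tensorflow as tf

decay = 0.99  # illustrative decay factor, not the configured value

# Stand-in model; any tf.keras model with weights works the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

# Step 1: create non-trainable shadow copies of the weights at the start of training.
shadow_weights = [tf.Variable(w, trainable=False) for w in model.weights]

# Step 2 (every train step): move the shadow copies toward the current weights.
def update_shadow_weights():
  for shadow, weight in zip(shadow_weights, model.weights):
    shadow.assign(decay * shadow + (1.0 - decay) * weight)

# Step 3 (around evaluation): exchange the model weights and the shadow copies,
# so metrics and the best-checkpoint export see the averaged weights; calling
# the same function again restores the regular training weights.
def swap_weights():
  for shadow, weight in zip(shadow_weights, model.weights):
    tmp = tf.identity(weight)
    weight.assign(shadow)
    shadow.assign(tmp)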

The following fields need to be set in order to activate the best checkpoint exporter:
- best_checkpoint_eval_metric
- best_checkpoint_export_subdir
- best_checkpoint_metric_comp

To serve the model, or to finetune the trained checkpoint on a target dataset, use the checkpoints under best_checkpoint_export_subdir.
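For the three best-checkpoint fields above, the override might look roughly like the following; the metric name, subdirectory, and comparison value are placeholder assumptions, not values from this change.

from official.core import config_definitions

# Placeholder values for illustration; pick the metric your task reports.
trainer_config = config_definitions.TrainerConfig(
    best_checkpoint_eval_metric='accuracy',     # eval metric to track
    best_checkpoint_export_subdir='best_ckpt',  # subdir for the exported best checkpoint
    best_checkpoint_metric_comp='higher')       # whether a higher or lower metric is better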

PiperOrigin-RevId: 356093831
parent 99b8390c
@@ -26,6 +26,8 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import config_definitions
+from official.modeling import optimization
 
 ExperimentConfig = config_definitions.ExperimentConfig
 TrainerConfig = config_definitions.TrainerConfig
@@ -119,6 +121,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     self._checkpoint_exporter = checkpoint_exporter
     self._recovery = None
 
+    # Creates a shadow copy of the weights to store weights moving average.
+    if isinstance(self._optimizer, optimization.ExponentialMovingAverage):
+      self._optimizer.shadow_copy(self._model)
+
     # global_step increases by 1 after each training iteration.
     # We should have global_step.numpy() == self.optimizer.iterations.numpy()
     # when there is only 1 optimizer.
@@ -209,7 +215,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
   @property
   def optimizer(self):
-    return self._optimizer
+    if hasattr(self, "_optimizer"):
+      return self._optimizer
+    else:
+      return None
 
   @property
   def global_step(self):
@@ -294,6 +303,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     """Sets up metrics."""
     for metric in self.validation_metrics + [self.validation_loss]:
       metric.reset_states()
+    # Swaps weights to test on weights moving average.
+    if self.optimizer and isinstance(
+        self.optimizer, optimization.ExponentialMovingAverage):
+      self.optimizer.swap_weights()
 
   def eval_step(self, iterator):
     """See base class."""
@@ -331,6 +344,12 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         logs["best_" +
              metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
 
+    # Swaps back weights after testing when EMA is used.
+    # This happens after best checkpoint export so that average weights used for
+    # eval are exported instead of regular weights.
+    if self.optimizer and isinstance(
+        self.optimizer, optimization.ExponentialMovingAverage):
+      self.optimizer.swap_weights()
     return logs
 
   def eval_reduce(self, state=None, step_outputs=None):
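For context, the isinstance checks above only take effect when the task builds its optimizer as an optimization.ExponentialMovingAverage wrapper. A rough sketch of that wiring is below; the constructor arguments are an assumption and may differ from the actual ema_optimizer signature.

import tensorflow as tf
from official.modeling import optimization

# Assumed wiring: wrap a regular optimizer so that
# isinstance(optimizer, optimization.ExponentialMovingAverage) is True
# and the trainer activates shadow_copy()/swap_weights().
base_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
ema_optimizer = optimization.ExponentialMovingAverage(
    base_optimizer, average_decay=0.99)  # average_decay is an assumed argument name

# The trainer then calls, per the diff above:
#   ema_optimizer.shadow_copy(model)   # once, when the trainer is constructed
#   ema_optimizer.swap_weights()       # before and after evaluation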