Commit 47849274 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 327459481
parent 4bf01a43
......@@ -42,7 +42,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
train: bool = True,
evaluate: bool = True,
model=None,
optimizer=None):
optimizer=None,
checkpoint_exporter=None):
"""Initialize common trainer for TensorFlow models.
Args:
......@@ -56,6 +57,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
building the model using task.build_model(). Defaults to None.
optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
be used instead of the optimizer from config. Defaults to None.
checkpoint_exporter: an object that implements the `maybe_export_checkpoint`
interface. Defaults to None.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
......@@ -73,6 +76,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
else:
self._optimizer = optimizer
self._checkpoint_exporter = checkpoint_exporter
# Configuring optimizer when loss_scale is set in runtime config. This helps
# avoid overflow/underflow for float16 computations.
if config.runtime.loss_scale:
......@@ -235,6 +240,14 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
if aggregated_logs:
metrics = self.task.reduce_aggregated_logs(aggregated_logs)
logs.update(metrics)
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, logs, self.global_step.numpy())
metric_name = self.config.trainer.best_checkpoint_eval_metric
logs['best_' + metric_name] = self._checkpoint_exporter.best_ckpt_logs[
metric_name]
return logs
def eval_reduce(self, state=None, step_outputs=None):
......
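Note: the `checkpoint_exporter` passed to `Trainer` only needs to expose `maybe_export_checkpoint(checkpoint, eval_logs, global_step)` and, because `evaluate()` reads it back when `best_checkpoint_eval_metric` is set, a `best_ckpt_logs` mapping. A minimal sketch of such an object (the class name is hypothetical and not part of this change; `BestCheckpointExporter` further down is the real implementation):

import os

class AlwaysExportCheckpointExporter:
  """Hypothetical exporter that saves a checkpoint after every evaluation."""

  def __init__(self, export_dir):
    self._export_dir = export_dir
    self._best_ckpt_logs = None

  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
    # `checkpoint` is the trainer's tf.train.Checkpoint, `eval_logs` is the
    # dict of evaluation metrics, and `global_step` is a Python int.
    self._best_ckpt_logs = eval_logs
    checkpoint.save(os.path.join(self._export_dir, 'ckpt'))

  @property
  def best_ckpt_logs(self):
    return self._best_ckpt_logs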
......@@ -16,12 +16,14 @@
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_trainer as trainer_lib
from official.core import train_lib
from official.modeling.hyperparams import config_definitions as cfg
from official.utils.testing import mock_task
......@@ -105,6 +107,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
@combinations.generate(all_strategy_combinations())
def test_export_best_ckpt(self, distribution):
config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='acc',
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
})))
model_dir = self.get_temp_dir()
task = mock_task.MockTask(config.task, logging_dir=model_dir)
ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
trainer = trainer_lib.Trainer(
config, task, checkpoint_exporter=ckpt_exporter)
trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertTrue(tf.io.gfile.exists(
os.path.join(model_dir, 'best_ckpt', 'info.json')))
if __name__ == '__main__':
tf.test.main()
......@@ -15,6 +15,8 @@
# ==============================================================================
"""TFM common training driver library."""
import copy
import json
import os
from typing import Any, Mapping, Tuple
......@@ -28,6 +30,121 @@ from official.core import base_task
from official.modeling.hyperparams import config_definitions
class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint.
Orbit will support an API for checkpoint exporters. This class will be used
together with Orbit once that functionality is ready.
"""
def __init__(self,
export_dir: str,
metric_name: str,
metric_comp: str):
"""Initialization.
Args:
export_dir: The directory that will contain exported checkpoints.
metric_name: Indicates which metric to look at when determining which
result is better.
metric_comp: Indicates how to compare results. Either `lower` or `higher`.
"""
self._export_dir = export_dir
self._metric_name = metric_name
self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'):
raise ValueError(
'best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
self._export_best_eval_metric(
checkpoint, self._best_ckpt_logs, global_step)
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
return None
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
return json.loads(reader.read())
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
if self._metric_name not in old_logs or self._metric_name not in new_logs:
raise KeyError(
'best checkpoint eval metric name {} is not valid. '
'old_logs: {}, new_logs: {}'.format(
self._metric_name, old_logs, new_logs))
old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value)
if self._metric_comp == 'higher':
if new_value > old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is higher.')
return True
else: # self._metric_comp == 'lower':
if new_value < old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is lower.')
return True
return False
def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step
for name, value in eval_logs_ext.items():
eval_logs_ext[name] = str(orbit.utils.get_value(value))
# Saving the JSON file is very fast.
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
# Saving the best checkpoint might be interrupted if the job gets killed.
for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
tf.io.gfile.rmtree(file_to_remove)
checkpoint.save(self.best_ckpt_path)
@property
def best_ckpt_logs(self):
return self._best_ckpt_logs
@property
def best_ckpt_logs_path(self):
return os.path.join(self._export_dir, 'info.json')
@property
def best_ckpt_path(self):
return os.path.join(self._export_dir, 'best_ckpt')
def maybe_create_best_ckpt_exporter(
params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(
best_ckpt_dir, metric_name, metric_comp)
else:
best_ckpt_exporter = None
logging.info('Not exporting the best checkpoint. '
'data_dir: %s, export_subdir: %s, metric_name: %s',
data_dir, export_subdir, metric_name)
return best_ckpt_exporter
def run_experiment(distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
......@@ -62,7 +179,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
task,
model_dir,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval)
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager(
......
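Outside of `run_experiment`, the exporter added above can also be driven directly. A usage sketch (paths, metric name, and metric values are illustrative):

import tensorflow as tf
from official.core import train_lib

exporter = train_lib.BestCheckpointExporter(
    export_dir='/tmp/model_dir/best_ckpt',
    metric_name='acc',
    metric_comp='higher')

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
checkpoint = tf.train.Checkpoint(model=model)

# Only improvements are exported; the second call is a no-op because 0.6 < 0.7.
exporter.maybe_export_checkpoint(checkpoint, {'acc': 0.7}, global_step=100)
exporter.maybe_export_checkpoint(checkpoint, {'acc': 0.6}, global_step=200)

print(exporter.best_ckpt_logs)  # {'acc': 0.7}
# /tmp/model_dir/best_ckpt/info.json now holds the stringified best metrics
# plus 'best_ckpt_global_step'.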
......@@ -18,20 +18,32 @@
import json
import os
import pprint
from typing import Any
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import config_definitions
def create_trainer(params, task, model_dir, train, evaluate):
def create_trainer(
params: config_definitions.ExperimentConfig,
task: base_task.Task,
model_dir: str,
train: bool,
evaluate: bool,
checkpoint_exporter: Any = None):
"""Create trainer."""
del model_dir
logging.info('Running default trainer.')
trainer = base_trainer.Trainer(params, task, train=train, evaluate=evaluate)
trainer = base_trainer.Trainer(
params, task, train=train, evaluate=evaluate,
checkpoint_exporter=checkpoint_exporter)
return trainer
......@@ -122,10 +134,7 @@ def write_summary(summary_writer, global_step, eval_metrics):
"""Write evaluation metrics to TF summary."""
numeric_dict = {}
for name, value in eval_metrics.items():
if hasattr(value, 'numpy'):
numeric_dict[name] = value.numpy().astype(float)
else:
numeric_dict[name] = value
numeric_dict[name] = float(orbit.utils.get_value(value))
with summary_writer.as_default():
for name, value in numeric_dict.items():
tf.summary.scalar(name, value, step=global_step)
......
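The `hasattr(value, 'numpy')` branching is replaced by `orbit.utils.get_value`, which (as it is used here and in the exporter above) resolves tensors and variables to NumPy values and passes plain Python numbers through, so mixed metric dicts keep working. A small sketch of the assumed behavior:

import orbit
import tensorflow as tf

print(float(orbit.utils.get_value(tf.constant(0.25))))  # 0.25, tensor resolved
print(float(orbit.utils.get_value(0.25)))               # 0.25, non-tensor passed through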
......@@ -183,6 +183,17 @@ class TrainerConfig(base_config.Config):
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
best_checkpoint_export_subdir: if set, the trainer will keep track of the
best evaluation metric, and export the corresponding best checkpoint under
`model_dir/best_checkpoint_export_subdir`. Note that this only works if
the mode contains eval (such as `train_and_eval`, `continuous_eval`, and
`continuous_train_and_eval`).
best_checkpoint_eval_metric: for exporting the best checkpoint, which
evaluation metric the trainer should monitor. This can be any evaluation
metric that appears on TensorBoard.
best_checkpoint_metric_comp: for exporting the best checkpoint, how the
trainer should compare the evaluation metrics. This can be either `higher`
(the higher the better) or `lower` (the lower the better).
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
# Orbit settings.
......@@ -201,6 +212,10 @@ class TrainerConfig(base_config.Config):
train_steps: int = 0
validation_steps: Optional[int] = None
validation_interval: int = 1000
# Best checkpoint export.
best_checkpoint_export_subdir: str = ""
best_checkpoint_eval_metric: str = ""
best_checkpoint_metric_comp: str = "higher"
@dataclasses.dataclass
......
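Putting the new fields together, a config that enables best-checkpoint export could look like this (the subdir and metric name are illustrative, mirroring the test above):

from official.modeling.hyperparams import config_definitions as cfg

config = cfg.ExperimentConfig(
    trainer=cfg.TrainerConfig(
        best_checkpoint_export_subdir='best_ckpt',
        best_checkpoint_eval_metric='acc',
        best_checkpoint_metric_comp='higher'))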