"vscode:/vscode.git/clone" did not exist on "6db661300f472ee8852882af2c3b8b182a403ed1"
Commit 40e12432 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 327459481
parent 30821184
@@ -42,7 +42,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
                train: bool = True,
                evaluate: bool = True,
                model=None,
-               optimizer=None):
+               optimizer=None,
+               checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.

     Args:
@@ -56,6 +57,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         building model using task.build_model(). Default to None.
       optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
         used instead of the optimizer from config. Default to None.
+      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
+        interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
@@ -73,6 +76,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     else:
       self._optimizer = optimizer

+    self._checkpoint_exporter = checkpoint_exporter
+
     # Configuring optimizer when loss_scale is set in runtime config. This helps
     # avoiding overflow/underflow for float16 computations.
     if config.runtime.loss_scale:
@@ -235,6 +240,14 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     if aggregated_logs:
       metrics = self.task.reduce_aggregated_logs(aggregated_logs)
       logs.update(metrics)
+
+    if self._checkpoint_exporter:
+      self._checkpoint_exporter.maybe_export_checkpoint(
+          self.checkpoint, logs, self.global_step.numpy())
+      metric_name = self.config.trainer.best_checkpoint_eval_metric
+      logs['best_' + metric_name] = self._checkpoint_exporter.best_ckpt_logs[
+          metric_name]
+
     return logs

   def eval_reduce(self, state=None, step_outputs=None):
...
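For context on the new constructor argument: the block added before `return logs` above only relies on the exporter exposing `maybe_export_checkpoint(checkpoint, logs, global_step)` and a `best_ckpt_logs` mapping. A minimal duck-typed stand-in (the class name `RecordingExporter` is made up for illustration and is not part of this commit) could look like:

# Hypothetical stand-in for the `maybe_export_checkpoint` interface; it only
# records the latest eval logs and never writes a checkpoint.
class RecordingExporter:

  def __init__(self):
    self.best_ckpt_logs = {}

  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
    # A real exporter would compare metrics and call checkpoint.save() here.
    self.best_ckpt_logs = dict(eval_logs)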
@@ -16,12 +16,14 @@
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
+import os
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from official.core import base_trainer as trainer_lib
+from official.core import train_lib
 from official.modeling.hyperparams import config_definitions as cfg
 from official.utils.testing import mock_task
@@ -105,6 +107,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', metrics)

+  @combinations.generate(all_strategy_combinations())
+  def test_export_best_ckpt(self, distribution):
+    config = cfg.ExperimentConfig(
+        trainer=cfg.TrainerConfig(
+            best_checkpoint_export_subdir='best_ckpt',
+            best_checkpoint_eval_metric='acc',
+            optimizer_config=cfg.OptimizationConfig({
+                'optimizer': {
+                    'type': 'sgd'
+                },
+                'learning_rate': {
+                    'type': 'constant'
+                }
+            })))
+    model_dir = self.get_temp_dir()
+    task = mock_task.MockTask(config.task, logging_dir=model_dir)
+    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
+    trainer = trainer_lib.Trainer(
+        config, task, checkpoint_exporter=ckpt_exporter)
+    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
+    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
+    self.assertTrue(tf.io.gfile.exists(
+        os.path.join(model_dir, 'best_ckpt', 'info.json')))
+

 if __name__ == '__main__':
   tf.test.main()
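The assertion at the end of the new test checks for the `info.json` file that `BestCheckpointExporter` (added in `train_lib` below) writes next to the exported checkpoint. A hedged sketch of inspecting that file after a run, assuming an illustrative `model_dir` rather than the test's temp dir:

import json
import os

import tensorflow as tf

model_dir = '/tmp/model_dir'  # Illustrative; the test uses self.get_temp_dir().
info_path = os.path.join(model_dir, 'best_ckpt', 'info.json')
if tf.io.gfile.exists(info_path):
  with tf.io.gfile.GFile(info_path, 'r') as f:
    best_logs = json.loads(f.read())
  # Values are serialized as strings, plus a 'best_ckpt_global_step' entry.
  print(best_logs.get('acc'), best_logs.get('best_ckpt_global_step'))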
@@ -15,6 +15,8 @@
 # ==============================================================================
 """TFM common training driver library."""

+import copy
+import json
 import os
 from typing import Any, Mapping, Tuple
@@ -28,6 +30,121 @@
 from official.modeling.hyperparams import config_definitions


+class BestCheckpointExporter:
+  """Keeps track of the best result, and saves its checkpoint.
+
+  Orbit will support an API for checkpoint exporter. This class will be used
+  together with orbit once this functionality is ready.
+  """
+
+  def __init__(self,
+               export_dir: str,
+               metric_name: str,
+               metric_comp: str):
+    """Initialization.
+
+    Arguments:
+      export_dir: The directory that will contain exported checkpoints.
+      metric_name: Indicates which metric to look at, when determining which
+        result is better.
+      metric_comp: Indicates how to compare results. Either `lower` or `higher`.
+    """
+    self._export_dir = export_dir
+    self._metric_name = metric_name
+    self._metric_comp = metric_comp
+    if self._metric_comp not in ('lower', 'higher'):
+      raise ValueError(
+          'best checkpoint metric comp must be one of '
+          'higher, lower. Got: {}'.format(self._metric_comp))
+    tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
+    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
+
+  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
+                 eval_logs, global_step)
+    if self._best_ckpt_logs is None or self._new_metric_is_better(
+        self._best_ckpt_logs, eval_logs):
+      self._best_ckpt_logs = eval_logs
+      self._export_best_eval_metric(
+          checkpoint, self._best_ckpt_logs, global_step)
+
+  def _maybe_load_best_eval_metric(self):
+    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
+      return None
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
+      return json.loads(reader.read())
+
+  def _new_metric_is_better(self, old_logs, new_logs):
+    """Check if the metric in new_logs is better than the metric in old_logs."""
+    if self._metric_name not in old_logs or self._metric_name not in new_logs:
+      raise KeyError(
+          'best checkpoint eval metric name {} is not valid. '
+          'old_logs: {}, new_logs: {}'.format(
+              self._metric_name, old_logs, new_logs))
+    old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
+    new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
+
+    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
+                 old_value, new_value)
+    if self._metric_comp == 'higher':
+      if new_value > old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is higher.')
+        return True
+    else:  # self._metric_comp == 'lower':
+      if new_value < old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is lower.')
+        return True
+    return False
+
+  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+    """Export evaluation results of the best checkpoint into a json file."""
+    eval_logs_ext = copy.copy(eval_logs)
+    eval_logs_ext['best_ckpt_global_step'] = global_step
+    for name, value in eval_logs_ext.items():
+      eval_logs_ext[name] = str(orbit.utils.get_value(value))
+    # Saving json file is very fast.
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
+      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
+
+    # Saving the best checkpoint might be interrupted if the job got killed.
+    for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
+      tf.io.gfile.rmtree(file_to_remove)
+    checkpoint.save(self.best_ckpt_path)
+
+  @property
+  def best_ckpt_logs(self):
+    return self._best_ckpt_logs
+
+  @property
+  def best_ckpt_logs_path(self):
+    return os.path.join(self._export_dir, 'info.json')
+
+  @property
+  def best_ckpt_path(self):
+    return os.path.join(self._export_dir, 'best_ckpt')
+
+
+def maybe_create_best_ckpt_exporter(
+    params: config_definitions.ExperimentConfig,
+    data_dir: str) -> Any:
+  """Maybe create a BestCheckpointExporter object, according to the config."""
+  export_subdir = params.trainer.best_checkpoint_export_subdir
+  metric_name = params.trainer.best_checkpoint_eval_metric
+  metric_comp = params.trainer.best_checkpoint_metric_comp
+  if data_dir and export_subdir and metric_name:
+    best_ckpt_dir = os.path.join(data_dir, export_subdir)
+    best_ckpt_exporter = BestCheckpointExporter(
+        best_ckpt_dir, metric_name, metric_comp)
+  else:
+    best_ckpt_exporter = None
+    logging.info('Not exporting the best checkpoint. '
+                 'data_dir: %s, export_subdir: %s, metric_name: %s',
+                 data_dir, export_subdir, metric_name)
+  return best_ckpt_exporter
+
+
 def run_experiment(distribution_strategy: tf.distribute.Strategy,
                    task: base_task.Task,
                    mode: str,
@@ -62,7 +179,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
       task,
       model_dir,
       train='train' in mode,
-      evaluate=('eval' in mode) or run_post_eval)
+      evaluate=('eval' in mode) or run_post_eval,
+      checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))

   if trainer.checkpoint:
     checkpoint_manager = tf.train.CheckpointManager(
...
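`run_experiment` wires the exporter in automatically via `maybe_create_best_ckpt_exporter`, but the class can also be exercised on its own. A minimal sketch, assuming an illustrative export directory and metric name ('accuracy'); only the first call exports, since the second result is not higher:

import tensorflow as tf

from official.core import train_lib

exporter = train_lib.BestCheckpointExporter(
    export_dir='/tmp/best_ckpt_demo',  # Illustrative path, not a default.
    metric_name='accuracy',
    metric_comp='higher')

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
checkpoint = tf.train.Checkpoint(model=model)

exporter.maybe_export_checkpoint(checkpoint, {'accuracy': 0.85}, global_step=100)
exporter.maybe_export_checkpoint(checkpoint, {'accuracy': 0.80}, global_step=200)
print(exporter.best_ckpt_logs)  # Still the step-100 logs.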
@@ -18,20 +18,32 @@
 import json
 import os
 import pprint
+from typing import Any

 from absl import logging
+import orbit
 import tensorflow as tf

+from official.core import base_task
 from official.core import base_trainer
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling.hyperparams import config_definitions


-def create_trainer(params, task, model_dir, train, evaluate):
+def create_trainer(
+    params: config_definitions.ExperimentConfig,
+    task: base_task.Task,
+    model_dir: str,
+    train: bool,
+    evaluate: bool,
+    checkpoint_exporter: Any = None):
+  """Create trainer."""
   del model_dir
   logging.info('Running default trainer.')
-  trainer = base_trainer.Trainer(params, task, train=train, evaluate=evaluate)
+  trainer = base_trainer.Trainer(
+      params, task, train=train, evaluate=evaluate,
+      checkpoint_exporter=checkpoint_exporter)
   return trainer
@@ -122,10 +134,7 @@ def write_summary(summary_writer, global_step, eval_metrics):
   """Write evaluation metrics to TF summary."""
   numeric_dict = {}
   for name, value in eval_metrics.items():
-    if hasattr(value, 'numpy'):
-      numeric_dict[name] = value.numpy().astype(float)
-    else:
-      numeric_dict[name] = value
+    numeric_dict[name] = float(orbit.utils.get_value(value))
   with summary_writer.as_default():
     for name, value in numeric_dict.items():
       tf.summary.scalar(name, value, step=global_step)
...
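The `write_summary` simplification leans on `orbit.utils.get_value`, which already handles both tensors/variables and plain Python numbers, so the explicit `hasattr(value, 'numpy')` branch is no longer needed. A small illustration (the metric values are made up):

import orbit
import tensorflow as tf

metric = tf.keras.metrics.Mean()
metric.update_state([1.0, 3.0])

# Both a metric result tensor and a plain float reduce cleanly to a float.
print(float(orbit.utils.get_value(metric.result())))  # 2.0
print(float(orbit.utils.get_value(0.5)))              # 0.5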
@@ -183,6 +183,17 @@ class TrainerConfig(base_config.Config):
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.
     validation_interval: number of training steps to run between evaluations.
+    best_checkpoint_export_subdir: if set, the trainer will keep track of the
+      best evaluation metric, and export the corresponding best checkpoint
+      under `model_dir/best_checkpoint_export_subdir`. Note that this only
+      works if mode contains eval (such as `train_and_eval`, `continuous_eval`,
+      and `continuous_train_and_eval`).
+    best_checkpoint_eval_metric: for exporting the best checkpoint, which
+      evaluation metric the trainer should monitor. This can be any evaluation
+      metric that appears on TensorBoard.
+    best_checkpoint_metric_comp: for exporting the best checkpoint, how the
+      trainer should compare the evaluation metrics. This can be either
+      `higher` (the higher the better) or `lower` (the lower the better).
   """
   optimizer_config: OptimizationConfig = OptimizationConfig()
   # Orbit settings.
@@ -201,6 +212,10 @@ class TrainerConfig(base_config.Config):
   train_steps: int = 0
   validation_steps: Optional[int] = None
   validation_interval: int = 1000
+  # Best checkpoint export.
+  best_checkpoint_export_subdir: str = ""
+  best_checkpoint_eval_metric: str = ""
+  best_checkpoint_metric_comp: str = "higher"


 @dataclasses.dataclass
...
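With the defaults above, best-checkpoint export stays disabled until both `best_checkpoint_export_subdir` and `best_checkpoint_eval_metric` are set (see `maybe_create_best_ckpt_exporter`). Mirroring the new test, enabling it from Python looks roughly like this (the metric name 'accuracy' is illustrative):

from official.modeling.hyperparams import config_definitions as cfg

config = cfg.ExperimentConfig(
    trainer=cfg.TrainerConfig(
        best_checkpoint_export_subdir='best_ckpt',
        best_checkpoint_eval_metric='accuracy',
        best_checkpoint_metric_comp='higher'))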