Commit 84a561ca authored by Frederick Liu, committed by A. Unique TensorFlower

[core] Update BestCheckpointExporter to support dictionaries of arbitrary depth (possibly coming from multitask eval).
Also added BestCheckpointExporter support to multitask.

PiperOrigin-RevId: 357584596
parent 9da3a081
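
With this change, the best-checkpoint metric can point into nested eval logs (such as the per-task results produced by multitask evaluation) by joining the nested keys with `|`. A minimal sketch of the new usage, where the task name and metric are hypothetical:

# Sketch only: 'task_a' and 'accuracy' are hypothetical names; the evaluator is
# assumed to produce nested logs such as {'task_a': {'accuracy': ...}, ...}.
from official.core import train_utils

exporter = train_utils.BestCheckpointExporter(
    export_dir='/tmp/model_dir/best_ckpt',  # where the best checkpoint is copied
    metric_name='task_a|accuracy',          # nested keys joined with '|'
    metric_comp='higher')                   # keep the checkpoint with the higher value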
@@ -27,26 +27,7 @@ from official.core import config_definitions
from official.core import train_utils

BestCheckpointExporter = train_utils.BestCheckpointExporter
+maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
-def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
-                                    data_dir: str) -> Any:
-  """Maybe create a BestCheckpointExporter object, according to the config."""
-  export_subdir = params.trainer.best_checkpoint_export_subdir
-  metric_name = params.trainer.best_checkpoint_eval_metric
-  metric_comp = params.trainer.best_checkpoint_metric_comp
-  if data_dir and export_subdir and metric_name:
-    best_ckpt_dir = os.path.join(data_dir, export_subdir)
-    best_ckpt_exporter = BestCheckpointExporter(
-        best_ckpt_dir, metric_name, metric_comp)
-    logging.info(
-        'Created the best checkpoint exporter. '
-        'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
-        export_subdir, metric_name)
-  else:
-    best_ckpt_exporter = None
-  return best_ckpt_exporter


def run_experiment(distribution_strategy: tf.distribute.Strategy,
@@ -83,7 +64,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
      task,
      train='train' in mode,
      evaluate=('eval' in mode) or run_post_eval,
-      checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
+      checkpoint_exporter=maybe_create_best_ckpt_exporter(
+          params, model_dir))

  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
......
@@ -17,7 +17,7 @@ import copy
import json
import os
import pprint
-from typing import List, Optional
+from typing import Any, Callable, Dict, List, Optional

from absl import logging
import dataclasses
@@ -32,6 +32,75 @@ from official.core import exp_factory
from official.modeling import hyperparams


+def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Any:
+  """Gets the leaf value of a nested dictionary by following a list of keys.
+
+  Args:
+    d: The dictionary to extract the value from.
+    keys: The list of keys used to traverse the nested dictionary.
+
+  Returns:
+    The value of the leaf.
+
+  Raises:
+    KeyError: If the key path does not exist, or if the extracted value is
+      itself a dictionary rather than a leaf.
+  """
+  leaf = d
+  for k in keys:
+    if not isinstance(leaf, dict) or k not in leaf:
+      raise KeyError(
+          'Path not exist while traversing the dictionary: d with keys'
+          ': %s.' % keys)
+    leaf = leaf[k]
+
+  if isinstance(leaf, dict):
+    raise KeyError('The value extracted with keys: %s is not a leaf of the '
+                   'dictionary: %s.' % (keys, d))
+  return leaf
+
+
+def cast_leaf_nested_dict(d: Dict[str, Any],
+                          cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
+  """Casts the leaves of a dictionary with arbitrary depth in place.
+
+  Args:
+    d: The dictionary whose leaf values are cast in place.
+    cast_fn: The casting function applied to every leaf.
+
+  Returns:
+    A dictionary with the same structure as d.
+  """
+  for key, value in d.items():
+    if isinstance(value, dict):
+      d[key] = cast_leaf_nested_dict(value, cast_fn)
+    else:
+      d[key] = cast_fn(value)
+  return d
+
+
+def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
+                                    data_dir: str) -> Any:
+  """Maybe create a BestCheckpointExporter object, according to the config."""
+  export_subdir = params.trainer.best_checkpoint_export_subdir
+  metric_name = params.trainer.best_checkpoint_eval_metric
+  metric_comp = params.trainer.best_checkpoint_metric_comp
+  if data_dir and export_subdir and metric_name:
+    best_ckpt_dir = os.path.join(data_dir, export_subdir)
+    best_ckpt_exporter = BestCheckpointExporter(
+        best_ckpt_dir, metric_name, metric_comp)
+    logging.info(
+        'Created the best checkpoint exporter. '
+        'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
+        export_subdir, metric_name)
+  else:
+    best_ckpt_exporter = None
+  return best_ckpt_exporter
+
+
+# TODO(b/180147589): Add tests for this module.
class BestCheckpointExporter:
  """Keeps track of the best result, and saves its checkpoint.
@@ -45,11 +114,12 @@ class BestCheckpointExporter:
    Args:
      export_dir: The directory that will contain exported checkpoints.
      metric_name: Indicates which metric to look at, when determining which
-        result is better.
+        result is better. If the eval_logs passed to maybe_export_checkpoint
+        is a nested dictionary, use `|` as a separator for the nesting levels.
      metric_comp: Indicates how to compare results. Either `lower` or `higher`.
    """
    self._export_dir = export_dir
-    self._metric_name = metric_name
+    self._metric_name = metric_name.split('|')
    self._metric_comp = metric_comp
    if self._metric_comp not in ('lower', 'higher'):
      raise ValueError('best checkpoint metric comp must be one of '
@@ -88,12 +158,10 @@ class BestCheckpointExporter:
  def _new_metric_is_better(self, old_logs, new_logs):
    """Check if the metric in new_logs is better than the metric in old_logs."""
-    if self._metric_name not in old_logs or self._metric_name not in new_logs:
-      raise KeyError('best checkpoint eval metric name {} is not valid. '
-                     'old_logs: {}, new_logs: {}'.format(
-                         self._metric_name, old_logs, new_logs))
-    old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
-    new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
+    old_value = float(orbit.utils.get_value(
+        get_leaf_nested_dict(old_logs, self._metric_name)))
+    new_value = float(orbit.utils.get_value(
+        get_leaf_nested_dict(new_logs, self._metric_name)))
    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
                 old_value, new_value)
@@ -113,8 +181,8 @@ class BestCheckpointExporter:
    """Export evaluation results of the best checkpoint into a json file."""
    eval_logs_ext = copy.copy(eval_logs)
    eval_logs_ext['best_ckpt_global_step'] = global_step
-    for name, value in eval_logs_ext.items():
-      eval_logs_ext[name] = float(orbit.utils.get_value(value))
+    eval_logs_ext = cast_leaf_nested_dict(
+        eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
    # Saving json file is very fast.
    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
......
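
For orientation before the new unit tests below, a quick sketch of how the two helpers added above behave on a nested result dictionary. The dictionary shape mimics multitask eval logs; the task names and values are made up:

from official.core import train_utils

# Hypothetical multitask eval logs: one nested dictionary per task.
eval_logs = {'task_a': {'accuracy': 0.91}, 'task_b': {'f1': 0.87}}

# Follow the '|'-split key path down to a single leaf value.
train_utils.get_leaf_nested_dict(eval_logs, ['task_a', 'accuracy'])  # -> 0.91

# Apply a casting function to every leaf while preserving the structure.
train_utils.cast_leaf_nested_dict(eval_logs, float)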
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.train_utils."""

import tensorflow as tf

from official.core import train_utils


class TrainUtilsTest(tf.test.TestCase):

  def test_get_leaf_nested_dict(self):
    d = {'a': {'i': {'x': 5}}}
    self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)

  def test_get_leaf_nested_dict_not_leaf(self):
    with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
      d = {'a': {'i': {'x': 5}}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i'])

  def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      d = {'a': {'i': {'x': 5}}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])

  def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      d = {'a': {'i': {'x': 5}}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])

  def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      d = {'a': {'i': 5}}
      train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])

  def test_cast_leaf_nested_dict(self):
    d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
    d = train_utils.cast_leaf_nested_dict(d, int)
    self.assertEqual(d['a']['i']['x'], 123)
    self.assertEqual(d['b'], 456)


if __name__ == '__main__':
  tf.test.main()
@@ -37,16 +37,10 @@ class MultiTaskConfig(hyperparams.Config):
@dataclasses.dataclass
-class MultiEvalExperimentConfig(hyperparams.Config):
+class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Attributes:
-    task: the single-stream training task.
    eval_tasks: individual evaluation tasks.
-    trainer: the trainer configuration.
-    runtime: the runtime configuration.
  """
-  task: cfg.TaskConfig = cfg.TaskConfig()
  eval_tasks: MultiTaskConfig = MultiTaskConfig()
-  trainer: cfg.TrainerConfig = cfg.TrainerConfig()
-  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@@ -21,6 +21,7 @@ import gin
import orbit
import tensorflow as tf

+from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@@ -29,16 +30,20 @@ from official.modeling.multitask import multitask
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Implements the common trainer shared for TensorFlow models."""

-  def __init__(self,
-               task: multitask.MultiTask,
-               model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
-               global_step: Optional[tf.Variable] = None):
+  def __init__(
+      self,
+      task: multitask.MultiTask,
+      model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
+      global_step: Optional[tf.Variable] = None,
+      checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      task: A multitask.MultiTask instance.
      model: tf.keras.Model instance.
      global_step: the global step variable.
+      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
+        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
@@ -46,6 +51,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
    self._task = task
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
+    self._checkpoint_exporter = checkpoint_exporter
    # TODO(hongkuny): Define a more robust way to handle the training/eval
    # checkpoint loading.
    if hasattr(self.model, "checkpoint_items"):
@@ -168,4 +174,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
      metrics = task.reduce_aggregated_logs(outputs)
      logs.update(metrics)
      results[name] = logs
+
+    if self._checkpoint_exporter:
+      self._checkpoint_exporter.maybe_export_checkpoint(
+          self.checkpoint, results, self.global_step.numpy())
    return results
@@ -20,6 +20,7 @@ import orbit
import tensorflow as tf

from official.core import base_task
from official.core import base_trainer as core_lib
+from official.core import train_utils
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask
@@ -73,7 +74,9 @@ def run_experiment_with_multitask_eval(
    evaluator = evaluator_lib.MultiTaskEvaluator(
        task=eval_tasks,
        model=model,
-        global_step=trainer.global_step if is_training else None)
+        global_step=trainer.global_step if is_training else None,
+        checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
+            params, model_dir))
  else:
    evaluator = None
......
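
Taken together, the multitask path now mirrors the single-task one: maybe_create_best_ckpt_exporter returns None unless the trainer config sets best_checkpoint_export_subdir and best_checkpoint_eval_metric, and the evaluator calls maybe_export_checkpoint on the aggregated per-task results after each evaluation. A rough usage sketch, assuming params, eval_tasks, model, and model_dir are built by the surrounding experiment setup:

from official.core import train_utils
from official.modeling.multitask import evaluator as evaluator_lib

# Sketch only: params, eval_tasks, model, and model_dir are assumed to come from
# the experiment setup above; the exporter is None when best-checkpoint export is
# not configured in the trainer config.
exporter = train_utils.maybe_create_best_ckpt_exporter(params, model_dir)
evaluator = evaluator_lib.MultiTaskEvaluator(
    task=eval_tasks,
    model=model,
    checkpoint_exporter=exporter)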