"src/vscode:/vscode.git/clone" did not exist on "1fcf279d74f461401b86093c99afd94059a8cf3c"
Commit 1ef32d1a authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[core] Update BestCheckpointExporter to support dictionary with arbitrary...

[core] Update BestCheckpointExporter to support dictionary with arbitrary depth (possibly coming from multitask eval).
Also added BestCheckpointExporter support to multitask.

PiperOrigin-RevId: 357584596
parent 5486430e
......@@ -27,26 +27,7 @@ from official.core import config_definitions
from official.core import train_utils
BestCheckpointExporter = train_utils.BestCheckpointExporter
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
                                    data_dir: str) -> Any:
  """Builds a BestCheckpointExporter when the experiment config requests one.

  Args:
    params: The experiment configuration; the relevant settings live under
      `params.trainer` (export subdir, eval metric name, metric comparison).
    data_dir: Base directory under which the export subdirectory is created.

  Returns:
    A `BestCheckpointExporter` instance if `data_dir`, the export subdir and
    the eval metric name are all set; otherwise None.
  """
  subdir = params.trainer.best_checkpoint_export_subdir
  metric = params.trainer.best_checkpoint_eval_metric
  comparator = params.trainer.best_checkpoint_metric_comp
  # All three settings are required; otherwise best-checkpoint export is off.
  if not (data_dir and subdir and metric):
    return None
  exporter = BestCheckpointExporter(
      os.path.join(data_dir, subdir), metric, comparator)
  logging.info(
      'Created the best checkpoint exporter. '
      'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
      subdir, metric)
  return exporter
# Backwards-compatible alias: the implementation now lives in
# official.core.train_utils; existing callers of this module keep working.
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(distribution_strategy: tf.distribute.Strategy,
......@@ -83,7 +64,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))
checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager(
......
......@@ -17,7 +17,7 @@ import copy
import json
import os
import pprint
from typing import List, Optional
from typing import Any, Callable, Dict, List, Optional
from absl import logging
import dataclasses
......@@ -32,6 +32,75 @@ from official.core import exp_factory
from official.modeling import hyperparams
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Any:
  """Gets the leaf value of a nested dictionary following a key path.

  Args:
    d: The (possibly nested) dictionary to extract the value from.
    keys: The list of keys traversed in order, one per nesting level.

  Returns:
    The leaf value found at the end of the key path. By construction this is
    never a dictionary (the annotation is `Any`, not `Dict`, for that reason).

  Raises:
    KeyError: If the key path does not exist in `d`, or if the value it
      reaches is itself a dictionary (i.e. not a leaf).
  """
  leaf = d
  for k in keys:
    # Stop as soon as the path leaves the nested-dict structure: either the
    # current node is already a leaf, or the key is simply absent.
    if not isinstance(leaf, dict) or k not in leaf:
      raise KeyError(
          'Path not exist while traversing the dictionary: d with keys'
          ': %s.' % keys)
    leaf = leaf[k]

  if isinstance(leaf, dict):
    raise KeyError('The value extracted with keys: %s is not a leaf of the '
                   'dictionary: %s.' % (keys, d))
  return leaf
def cast_leaf_nested_dict(
    d: Dict[str, Any],
    cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
  """Casts the leaves of a dictionary with arbitrary depth in place.

  Every non-dict value reachable from `d` is replaced by `cast_fn(value)`;
  nested dictionaries are recursed into and mutated as well. The input
  dictionary is modified in place and also returned for convenience.

  Args:
    d: The (possibly nested) dictionary whose leaves are cast.
    cast_fn: The casting function applied to each leaf value.

  Returns:
    `d` itself, with every leaf replaced by `cast_fn` of the original leaf.
  """
  for key, value in d.items():
    if isinstance(value, dict):
      d[key] = cast_leaf_nested_dict(value, cast_fn)
    else:
      # Reassigning an existing key during .items() iteration is safe; no
      # keys are added or removed.
      d[key] = cast_fn(value)
  return d
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
                                    data_dir: str) -> Any:
  """Maybe create a BestCheckpointExporter object, according to the config.

  Args:
    params: Experiment configuration; reads the best-checkpoint settings from
      `params.trainer`.
    data_dir: Directory that will contain the export subdirectory.

  Returns:
    A configured `BestCheckpointExporter`, or None when the config does not
    enable best-checkpoint export.
  """
  trainer_cfg = params.trainer
  export_subdir = trainer_cfg.best_checkpoint_export_subdir
  metric_name = trainer_cfg.best_checkpoint_eval_metric
  if data_dir and export_subdir and metric_name:
    exporter = BestCheckpointExporter(
        os.path.join(data_dir, export_subdir),
        metric_name,
        trainer_cfg.best_checkpoint_metric_comp)
    logging.info(
        'Created the best checkpoint exporter. '
        'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
        export_subdir, metric_name)
    return exporter
  return None
# TODO(b/180147589): Add tests for this module.
class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint.
......@@ -45,11 +114,12 @@ class BestCheckpointExporter:
Args:
export_dir: The directory that will contain exported checkpoints.
metric_name: Indicates which metric to look at, when determining which
result is better.
result is better. If eval_logs being passed to maybe_export_checkpoint
is a nested dictionary, use `|` as a separator for different layers.
metric_comp: Indicates how to compare results. Either `lower` or `higher`.
"""
self._export_dir = export_dir
self._metric_name = metric_name
self._metric_name = metric_name.split('|')
self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'):
raise ValueError('best checkpoint metric comp must be one of '
......@@ -88,12 +158,10 @@ class BestCheckpointExporter:
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
if self._metric_name not in old_logs or self._metric_name not in new_logs:
raise KeyError('best checkpoint eval metric name {} is not valid. '
'old_logs: {}, new_logs: {}'.format(
self._metric_name, old_logs, new_logs))
old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
old_value = float(orbit.utils.get_value(
get_leaf_nested_dict(old_logs, self._metric_name)))
new_value = float(orbit.utils.get_value(
get_leaf_nested_dict(new_logs, self._metric_name)))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value)
......@@ -113,8 +181,8 @@ class BestCheckpointExporter:
"""Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step
for name, value in eval_logs_ext.items():
eval_logs_ext[name] = float(orbit.utils.get_value(value))
eval_logs_ext = cast_leaf_nested_dict(
eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
# Saving json file is very fast.
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.train_utils."""
import tensorflow as tf
from official.core import train_utils
class TrainUtilsTest(tf.test.TestCase):
  """Tests for the nested-dictionary helpers in train_utils."""

  def test_get_leaf_nested_dict(self):
    nested = {'a': {'i': {'x': 5}}}
    self.assertEqual(
        train_utils.get_leaf_nested_dict(nested, ['a', 'i', 'x']), 5)

  def test_get_leaf_nested_dict_not_leaf(self):
    nested = {'a': {'i': {'x': 5}}}
    with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
      train_utils.get_leaf_nested_dict(nested, ['a', 'i'])

  def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
    nested = {'a': {'i': {'x': 5}}}
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      train_utils.get_leaf_nested_dict(nested, ['a', 'i', 'y'])

  def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
    nested = {'a': {'i': {'x': 5}}}
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      train_utils.get_leaf_nested_dict(nested, ['a', 'i', 'z'])

  def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
    nested = {'a': {'i': 5}}
    with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
      train_utils.get_leaf_nested_dict(nested, ['a', 'i', 'z'])

  def test_cast_leaf_nested_dict(self):
    nested = {'a': {'i': {'x': '123'}}, 'b': 456.5}
    nested = train_utils.cast_leaf_nested_dict(nested, int)
    self.assertEqual(nested['a']['i']['x'], 123)
    self.assertEqual(nested['b'], 456)
# Standard TensorFlow test entry point; discovers and runs all test cases.
if __name__ == '__main__':
  tf.test.main()
......@@ -37,16 +37,10 @@ class MultiTaskConfig(hyperparams.Config):
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  The single training task, trainer and runtime settings are inherited from
  `cfg.ExperimentConfig`; only the evaluation task collection is added here.

  Attributes:
    eval_tasks: individual evaluation tasks.
  """
  eval_tasks: MultiTaskConfig = MultiTaskConfig()
......@@ -21,6 +21,7 @@ import gin
import orbit
import tensorflow as tf
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
......@@ -29,16 +30,20 @@ from official.modeling.multitask import multitask
class MultiTaskEvaluator(orbit.AbstractEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
def __init__(self,
task: multitask.MultiTask,
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None):
def __init__(
self,
task: multitask.MultiTask,
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models.
Args:
task: A multitask.MultiTask instance.
model: tf.keras.Model instance.
global_step: the global step variable.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
......@@ -46,6 +51,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
self._task = task
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
# TODO(hongkuny): Define a more robust way to handle the training/eval
# checkpoint loading.
if hasattr(self.model, "checkpoint_items"):
......@@ -168,4 +174,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
metrics = task.reduce_aggregated_logs(outputs)
logs.update(metrics)
results[name] = logs
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, results, self.global_step.numpy())
return results
......@@ -20,6 +20,7 @@ import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as core_lib
from official.core import train_utils
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask
......@@ -73,7 +74,9 @@ def run_experiment_with_multitask_eval(
evaluator = evaluator_lib.MultiTaskEvaluator(
task=eval_tasks,
model=model,
global_step=trainer.global_step if is_training else None)
global_step=trainer.global_step if is_training else None,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
else:
evaluator = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment