Commit dbf629c4 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 368122127
parent 0d8f9807
@@ -32,8 +32,7 @@ from official.core import exp_factory
 from official.modeling import hyperparams
 
 
-def get_leaf_nested_dict(
-    d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
+def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
   """Get leaf from a dictionary with arbitrary depth with a list of keys.
 
   Args:
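This hunk only reflows the signature onto one line; the helper's behavior is unchanged. For context while reviewing, a minimal sketch of what a leaf lookup like `get_leaf_nested_dict` does (the body below is illustrative, not copied from the file):

```python
from typing import Any, Dict, List


def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
  """Follows `keys` one level at a time and returns the leaf value."""
  leaf = d
  for k in keys:
    # Illustrative lookup; the real helper adds more descriptive error handling.
    leaf = leaf[k]
  return leaf


# Example: get_leaf_nested_dict({'eval': {'accuracy': 0.9}}, ['eval', 'accuracy']) -> 0.9
```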
@@ -60,9 +59,8 @@ def get_leaf_nested_dict(
   return leaf
 
 
-def cast_leaf_nested_dict(
-    d: Dict[str, Any],
-    cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
+def cast_leaf_nested_dict(d: Dict[str, Any],
+                          cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
   """Cast the leaves of a dictionary with arbitrary depth in place.
 
   Args:
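The `cast_leaf_nested_dict` hunk likewise only changes argument wrapping. A sketch of the recursive, in-place cast the docstring describes (illustrative, assuming plain `dict` nesting rather than the file's exact implementation):

```python
from typing import Any, Callable, Dict


def cast_leaf_nested_dict(d: Dict[str, Any],
                          cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
  """Applies `cast_fn` to every non-dict leaf of `d`, mutating it in place."""
  for key, value in d.items():
    if isinstance(value, dict):
      cast_leaf_nested_dict(value, cast_fn)
    else:
      d[key] = cast_fn(value)
  return d


# Example: cast_leaf_nested_dict({'eval': {'accuracy': '0.9'}}, float)
# -> {'eval': {'accuracy': 0.9}}
```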
@@ -88,8 +86,8 @@ def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
   metric_comp = params.trainer.best_checkpoint_metric_comp
   if data_dir and export_subdir and metric_name:
     best_ckpt_dir = os.path.join(data_dir, export_subdir)
-    best_ckpt_exporter = BestCheckpointExporter(
-        best_ckpt_dir, metric_name, metric_comp)
+    best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
+                                                metric_comp)
     logging.info(
         'Created the best checkpoint exporter. '
         'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
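The exporter is only created when the export subdir, metric name, and comparison direction are all set on the trainer config. A hedged example of enabling it through a params override; only `best_checkpoint_metric_comp` appears in this hunk, the other two field names and the metric value are inferred placeholders:

```python
# `params` is the ExperimentConfig produced by parse_configuration below.
params.override({
    'trainer': {
        'best_checkpoint_export_subdir': 'best_ckpt',  # subdir under the model dir (assumed field name)
        'best_checkpoint_eval_metric': 'accuracy',     # placeholder metric name (assumed field name)
        'best_checkpoint_metric_comp': 'higher',       # or 'lower'
    }
})
```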
@@ -130,8 +128,8 @@ class BestCheckpointExporter:
   def _get_checkpoint_manager(self, checkpoint):
     """Gets an existing checkpoint manager or creates a new one."""
-    if self._checkpoint_manager is None or (
-        self._checkpoint_manager.checkpoint != checkpoint):
+    if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
+                                            != checkpoint):
       logging.info('Creates a new checkpoint manager.')
       self._checkpoint_manager = tf.train.CheckpointManager(
           checkpoint,
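This hunk is also formatting-only: the manager is still created lazily and reused as long as it tracks the same checkpoint object. For reference, a minimal standalone use of `tf.train.CheckpointManager` (the directory and `max_to_keep` below are placeholder values, not what the exporter passes):

```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64))
# Placeholder directory/max_to_keep; the exporter supplies its own best-ckpt dir.
manager = tf.train.CheckpointManager(ckpt, directory='/tmp/best_ckpt', max_to_keep=1)
manager.save()  # writes a new checkpoint and updates the checkpoint state file
```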
@@ -158,10 +156,12 @@ class BestCheckpointExporter:
   def _new_metric_is_better(self, old_logs, new_logs):
     """Check if the metric in new_logs is better than the metric in old_logs."""
-    old_value = float(orbit.utils.get_value(
-        get_leaf_nested_dict(old_logs, self._metric_name)))
-    new_value = float(orbit.utils.get_value(
-        get_leaf_nested_dict(new_logs, self._metric_name)))
+    old_value = float(
+        orbit.utils.get_value(
+            get_leaf_nested_dict(old_logs, self._metric_name)))
+    new_value = float(
+        orbit.utils.get_value(
+            get_leaf_nested_dict(new_logs, self._metric_name)))
 
     logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
                  old_value, new_value)
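`_new_metric_is_better` resolves the metric from the (possibly nested) logs dicts before comparing. The actual comparison is outside this hunk; a plausible sketch of it, assuming `self._metric_comp` takes the values 'higher' or 'lower' as `trainer.best_checkpoint_metric_comp` suggests:

```python
def _is_better(new_value: float, old_value: float, metric_comp: str) -> bool:
  # Illustrative comparison only; assumes metric_comp is 'higher' or 'lower'.
  if metric_comp == 'higher':
    return new_value > old_value
  return new_value < old_value
```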
@@ -254,8 +254,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
           'tpu': flags_obj.tpu,
       },
   })
-  if flags_obj.tf_data_service and isinstance(params.task,
-                                              config_definitions.TaskConfig):
+  if ('tf_data_service' in flags_obj and flags_obj.tf_data_service and
+      isinstance(params.task, config_definitions.TaskConfig)):
     params.override({
         'task': {
             'train_data': {
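The last hunk is the only behavioral change in the commit: `flags_obj.tf_data_service` is now read only when the flag is actually defined, so binaries that never register that flag no longer fail while parsing their configuration. The guard relies on absl's `FlagValues` supporting membership tests; a small sketch of the same pattern (the helper name is hypothetical):

```python
from absl import flags

FLAGS = flags.FLAGS


def get_optional_flag_value(name: str, default=None):
  # `in` checks whether the flag was ever defined, so reading an
  # undefined flag can be skipped instead of raising an error.
  if name in FLAGS:
    return FLAGS[name].value
  return default
```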
...