"torchvision/models/vscode:/vscode.git/clone" did not exist on "f16b67234260f0a32f6106313ad42e102ada6fa0"
Commit dbf629c4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 368122127
parent 0d8f9807
......@@ -32,8 +32,7 @@ from official.core import exp_factory
from official.modeling import hyperparams
def get_leaf_nested_dict(
d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Get leaf from a dictionary with arbitrary depth with a list of keys.
Args:
......@@ -60,9 +59,8 @@ def get_leaf_nested_dict(
return leaf
def cast_leaf_nested_dict(
d: Dict[str, Any],
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
def cast_leaf_nested_dict(d: Dict[str, Any],
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
"""Cast the leaves of a dictionary with arbitrary depth in place.
Args:
......@@ -88,8 +86,8 @@ def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(
best_ckpt_dir, metric_name, metric_comp)
best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
metric_comp)
logging.info(
'Created the best checkpoint exporter. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
......@@ -130,8 +128,8 @@ class BestCheckpointExporter:
def _get_checkpoint_manager(self, checkpoint):
"""Gets an existing checkpoint manager or creates a new one."""
if self._checkpoint_manager is None or (
self._checkpoint_manager.checkpoint != checkpoint):
if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
!= checkpoint):
logging.info('Creates a new checkpoint manager.')
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
......@@ -158,10 +156,12 @@ class BestCheckpointExporter:
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
old_value = float(orbit.utils.get_value(
get_leaf_nested_dict(old_logs, self._metric_name)))
new_value = float(orbit.utils.get_value(
get_leaf_nested_dict(new_logs, self._metric_name)))
old_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(old_logs, self._metric_name)))
new_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(new_logs, self._metric_name)))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value)
......@@ -254,8 +254,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
'tpu': flags_obj.tpu,
},
})
if flags_obj.tf_data_service and isinstance(params.task,
config_definitions.TaskConfig):
if ('tf_data_service' in flags_obj and flags_obj.tf_data_service and
isinstance(params.task, config_definitions.TaskConfig)):
params.override({
'task': {
'train_data': {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment