Commit 30f93777 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Adding option to save the pipeline config file during evaluation.

PiperOrigin-RevId: 362353183
parent af2567bd
...@@ -506,6 +506,8 @@ def train_loop( ...@@ -506,6 +506,8 @@ def train_loop(
# Write the as-run pipeline config to disk. # Write the as-run pipeline config to disk.
if save_final_config: if save_final_config:
tf.logging.info('Saving pipeline config file to directory {}'.format(
model_dir))
pipeline_config_final = create_pipeline_proto_from_configs(configs) pipeline_config_final = create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_config_final, model_dir) config_util.save_pipeline_config(pipeline_config_final, model_dir)
...@@ -991,6 +993,7 @@ def eval_continuously( ...@@ -991,6 +993,7 @@ def eval_continuously(
wait_interval=180, wait_interval=180,
timeout=3600, timeout=3600,
eval_index=0, eval_index=0,
save_final_config=False,
**kwargs): **kwargs):
"""Run continuous evaluation of a detection model eagerly. """Run continuous evaluation of a detection model eagerly.
...@@ -1022,7 +1025,8 @@ def eval_continuously( ...@@ -1022,7 +1025,8 @@ def eval_continuously(
will terminate if no new checkpoints are found after these many seconds. will terminate if no new checkpoints are found after these many seconds.
eval_index: int, If given, only evaluate the dataset at the given eval_index: int, If given, only evaluate the dataset at the given
index. By default, evaluates dataset at 0'th index. index. By default, evaluates dataset at 0'th index.
save_final_config: Whether to save the pipeline config file to the model
directory.
**kwargs: Additional keyword arguments for configuration override. **kwargs: Additional keyword arguments for configuration override.
""" """
get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[ get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
...@@ -1046,7 +1050,9 @@ def eval_continuously( ...@@ -1046,7 +1050,9 @@ def eval_continuously(
'Forced number of epochs for all eval validations to be 1.') 'Forced number of epochs for all eval validations to be 1.')
configs = merge_external_params_with_configs( configs = merge_external_params_with_configs(
configs, None, kwargs_dict=kwargs) configs, None, kwargs_dict=kwargs)
if model_dir: if model_dir and save_final_config:
tf.logging.info('Saving pipeline config file to directory {}'.format(
model_dir))
pipeline_config_final = create_pipeline_proto_from_configs(configs) pipeline_config_final = create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_config_final, model_dir) config_util.save_pipeline_config(pipeline_config_final, model_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment