Commit 0a8036ce authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Internal change.

PiperOrigin-RevId: 360482921
parent 8c169ea4
...@@ -795,7 +795,8 @@ def eager_eval_loop( ...@@ -795,7 +795,8 @@ def eager_eval_loop(
eval_dataset, eval_dataset,
use_tpu=False, use_tpu=False,
postprocess_on_cpu=False, postprocess_on_cpu=False,
global_step=None): global_step=None,
):
"""Evaluate the model eagerly on the evaluation dataset. """Evaluate the model eagerly on the evaluation dataset.
This method will compute the evaluation metrics specified in the configs on This method will compute the evaluation metrics specified in the configs on
...@@ -968,11 +969,10 @@ def eager_eval_loop( ...@@ -968,11 +969,10 @@ def eager_eval_loop(
eval_metrics[loss_key] = tf.reduce_mean(loss_metrics[loss_key]) eval_metrics[loss_key] = tf.reduce_mean(loss_metrics[loss_key])
eval_metrics = {str(k): v for k, v in eval_metrics.items()} eval_metrics = {str(k): v for k, v in eval_metrics.items()}
tf.logging.info('Eval metrics at step %d', global_step) tf.logging.info('Eval metrics at step %d', global_step.numpy())
for k in eval_metrics: for k in eval_metrics:
tf.compat.v2.summary.scalar(k, eval_metrics[k], step=global_step) tf.compat.v2.summary.scalar(k, eval_metrics[k], step=global_step)
tf.logging.info('\t+ %s: %f', k, eval_metrics[k]) tf.logging.info('\t+ %s: %f', k, eval_metrics[k])
return eval_metrics return eval_metrics
...@@ -1026,6 +1026,8 @@ def eval_continuously( ...@@ -1026,6 +1026,8 @@ def eval_continuously(
""" """
get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[ get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
'get_configs_from_pipeline_file'] 'get_configs_from_pipeline_file']
create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
'create_pipeline_proto_from_configs']
merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[ merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
'merge_external_params_with_configs'] 'merge_external_params_with_configs']
...@@ -1043,6 +1045,10 @@ def eval_continuously( ...@@ -1043,6 +1045,10 @@ def eval_continuously(
'Forced number of epochs for all eval validations to be 1.') 'Forced number of epochs for all eval validations to be 1.')
configs = merge_external_params_with_configs( configs = merge_external_params_with_configs(
configs, None, kwargs_dict=kwargs) configs, None, kwargs_dict=kwargs)
if model_dir:
pipeline_config_final = create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(pipeline_config_final, model_dir)
model_config = configs['model'] model_config = configs['model']
train_input_config = configs['train_input_config'] train_input_config = configs['train_input_config']
eval_config = configs['eval_config'] eval_config = configs['eval_config']
...@@ -1109,4 +1115,5 @@ def eval_continuously( ...@@ -1109,4 +1115,5 @@ def eval_continuously(
eval_input, eval_input,
use_tpu=use_tpu, use_tpu=use_tpu,
postprocess_on_cpu=postprocess_on_cpu, postprocess_on_cpu=postprocess_on_cpu,
global_step=global_step) global_step=global_step,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment