Commit fa15ed1e authored by Soroosh Yazdani's avatar Soroosh Yazdani Committed by TF Object Detection Team
Browse files

Adding the option of continuous eval with yield, to allow metrics to be updated and logged.

PiperOrigin-RevId: 352851428
parent c787baad
...@@ -971,12 +971,12 @@ def _evaluate_checkpoint(estimator, ...@@ -971,12 +971,12 @@ def _evaluate_checkpoint(estimator,
raise e raise e
def continuous_eval(estimator, def continuous_eval_generator(estimator,
model_dir, model_dir,
input_fn, input_fn,
train_steps, train_steps,
name, name,
max_retries=0): max_retries=0):
"""Perform continuous evaluation on checkpoints written to a model directory. """Perform continuous evaluation on checkpoints written to a model directory.
Args: Args:
...@@ -989,6 +989,9 @@ def continuous_eval(estimator, ...@@ -989,6 +989,9 @@ def continuous_eval(estimator,
max_retries: Maximum number of times to retry the evaluation on encountering max_retries: Maximum number of times to retry the evaluation on encountering
a tf.errors.InvalidArgumentError. If negative, will always retry the a tf.errors.InvalidArgumentError. If negative, will always retry the
evaluation. evaluation.
Yields:
Pair of current step and eval_results.
""" """
def terminate_eval(): def terminate_eval():
...@@ -1011,6 +1014,7 @@ def continuous_eval(estimator, ...@@ -1011,6 +1014,7 @@ def continuous_eval(estimator,
# Terminate eval job when final checkpoint is reached # Terminate eval job when final checkpoint is reached
current_step = int(os.path.basename(ckpt).split('-')[1]) current_step = int(os.path.basename(ckpt).split('-')[1])
yield (current_step, eval_results)
if current_step >= train_steps: if current_step >= train_steps:
tf.logging.info( tf.logging.info(
'Evaluation finished after training step %d' % current_step) 'Evaluation finished after training step %d' % current_step)
...@@ -1021,6 +1025,30 @@ def continuous_eval(estimator, ...@@ -1021,6 +1025,30 @@ def continuous_eval(estimator,
'Checkpoint %s no longer exists, skipping checkpoint' % ckpt) 'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
def continuous_eval(estimator,
model_dir,
input_fn,
train_steps,
name,
max_retries=0):
"""Performs continuous evaluation on checkpoints written to a model directory.
Args:
estimator: Estimator object to use for evaluation.
model_dir: Model directory to read checkpoints for continuous evaluation.
input_fn: Input function to use for evaluation.
train_steps: Number of training steps. This is used to infer the last
checkpoint and stop evaluation loop.
name: Namescope for eval summary.
max_retries: Maximum number of times to retry the evaluation on encountering
a tf.errors.InvalidArgumentError. If negative, will always retry the
evaluation.
"""
for current_step, eval_results in continuous_eval_generator(
estimator, model_dir, input_fn, train_steps, name, max_retries):
tf.logging.info('Step %s, Eval results: %s', current_step, eval_results)
def populate_experiment(run_config, def populate_experiment(run_config,
hparams, hparams,
pipeline_config_path, pipeline_config_path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment