Commit 80a6318b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 325352935
parent b9785623
......@@ -524,13 +524,14 @@ def run_customized_training_loop(
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running evaluation after step: %s.', current_step)
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
eval_loss_metric.reset_states()
for metric in eval_metrics + model.metrics:
metric.reset_states()
logging.info('Running evaluation after step: %s.', current_step)
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# We add train_loss here rather than call on_batch_end twice to make
# sure that no duplicated values are generated.
logs['loss'] = train_loss
......@@ -548,6 +549,11 @@ def run_customized_training_loop(
_save_checkpoint(strategy, checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
# Re-initialize evaluation metric.
eval_loss_metric.reset_states()
for metric in eval_metrics + model.metrics:
metric.reset_states()
logging.info('Running final evaluation after training is complete.')
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment