Commit 8ca78e39 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds an option to control evaluation in training loop.

PiperOrigin-RevId: 313708781
parent 8ea058b9
...@@ -343,6 +343,7 @@ class DistributedExecutor(object): ...@@ -343,6 +343,7 @@ class DistributedExecutor(object):
SummaryWriter] = SummaryWriter, SummaryWriter] = SummaryWriter,
init_checkpoint: Callable[[tf.keras.Model], Any] = None, init_checkpoint: Callable[[tf.keras.Model], Any] = None,
custom_callbacks: List[tf.keras.callbacks.Callback] = None, custom_callbacks: List[tf.keras.callbacks.Callback] = None,
continuous_eval: bool = False,
save_config: bool = True): save_config: bool = True):
"""Runs distributed training. """Runs distributed training.
...@@ -362,8 +363,10 @@ class DistributedExecutor(object): ...@@ -362,8 +363,10 @@ class DistributedExecutor(object):
custom_callbacks: A list of Keras Callbacks objects to run during custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`, training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training. methods are invoked during training.
continuous_eval: If `True`, will continuously run evaluation on every
available checkpoint. If `False`, will do the evaluation once after the
final step.
save_config: bool. Whether to save params to model_dir. save_config: bool. Whether to save params to model_dir.
Returns: Returns:
The training loss and eval metrics. The training loss and eval metrics.
""" """
...@@ -414,6 +417,7 @@ class DistributedExecutor(object): ...@@ -414,6 +417,7 @@ class DistributedExecutor(object):
# input pipeline ops in worker task. # input pipeline ops in worker task.
train_iterator = self._get_input_iterator(train_input_fn, strategy) train_iterator = self._get_input_iterator(train_input_fn, strategy)
train_loss = None train_loss = None
train_metric_result = None
eval_metric_result = None eval_metric_result = None
tf.keras.backend.set_learning_phase(1) tf.keras.backend.set_learning_phase(1)
with strategy.scope(): with strategy.scope():
...@@ -530,7 +534,7 @@ class DistributedExecutor(object): ...@@ -530,7 +534,7 @@ class DistributedExecutor(object):
checkpoint_name.format(step=current_step)) checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step last_save_checkpoint_step = current_step
if test_step: if continuous_eval and current_step < total_steps and test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy) eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step, eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator) eval_metric, eval_iterator)
...@@ -562,7 +566,7 @@ class DistributedExecutor(object): ...@@ -562,7 +566,7 @@ class DistributedExecutor(object):
self.train_summary_writer.close() self.train_summary_writer.close()
self.eval_summary_writer.close() self.eval_summary_writer.close()
return train_loss, eval_metric_result return train_metric_result, eval_metric_result
def _run_evaluation(self, test_step, current_training_step, metric, def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator): test_iterator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment