Commit 90dedf26 authored by Ruoxin Sang's avatar Ruoxin Sang Committed by A. Unique TensorFlower
Browse files

Allow `steps_per_loop` in Controller to be passed as a callable.

PiperOrigin-RevId: 466412169
parent db19ab9b
...@@ -94,7 +94,7 @@ class Controller: ...@@ -94,7 +94,7 @@ class Controller:
train_actions: Optional[Iterable[Action]] = None, train_actions: Optional[Iterable[Action]] = None,
eval_actions: Optional[Iterable[Action]] = None, eval_actions: Optional[Iterable[Action]] = None,
# Train related # Train related
steps_per_loop: Optional[int] = None, steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None, checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
# Summary related # Summary related
summary_interval: Optional[int] = None, summary_interval: Optional[int] = None,
...@@ -130,8 +130,11 @@ class Controller: ...@@ -130,8 +130,11 @@ class Controller:
output of `trainer.train`. output of `trainer.train`.
eval_actions: Optional `orbit.Action`s to call after each evaluation. eval_actions: Optional `orbit.Action`s to call after each evaluation.
These will be called with the output of `evaluator.evaluate`. These will be called with the output of `evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training steps_per_loop: Optional integer to indicate the number of steps to run in
(passed as the `num_steps` parameter of `trainer.train`). each inner loop of training (passed as the `num_steps` parameter of
`trainer.train`). It can be also a callable which takes the current
global step value as input and returns the number of steps to run as
output.
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
provided and there are checkpoints in the associated model directory, provided and there are checkpoints in the associated model directory,
the model will be restored from the most recent checkpoint inside this the model will be restored from the most recent checkpoint inside this
...@@ -152,7 +155,7 @@ class Controller: ...@@ -152,7 +155,7 @@ class Controller:
Raises: Raises:
ValueError: If both `trainer` and `evaluator` are `None`. ValueError: If both `trainer` and `evaluator` are `None`.
ValueError: If `steps_per_loop` is not a positive integer. ValueError: If `steps_per_loop` is not a positive integer or a callable.
ValueError: If `summary_interval` is not a positive integer or is not ValueError: If `summary_interval` is not a positive integer or is not
divisible by `steps_per_loop`. divisible by `steps_per_loop`.
""" """
...@@ -163,15 +166,18 @@ class Controller: ...@@ -163,15 +166,18 @@ class Controller:
if steps_per_loop is None: if steps_per_loop is None:
raise ValueError( raise ValueError(
"`steps_per_loop` is required when `trainer` is provided.") "`steps_per_loop` is required when `trainer` is provided.")
elif not isinstance(steps_per_loop, int) or steps_per_loop < 1: elif not callable(steps_per_loop) and (
not isinstance(steps_per_loop, int) or steps_per_loop < 1):
raise ValueError( raise ValueError(
f"`steps_per_loop` ({steps_per_loop}) must be a positive integer.") f"`steps_per_loop` ({steps_per_loop}) must be a positive integer "
"or a callable.")
if summary_interval is not None: if summary_interval is not None:
if summary_interval <= 0: if summary_interval <= 0:
raise ValueError( raise ValueError(
f"`summary_interval` ({summary_interval}) must be larger than 0.") f"`summary_interval` ({summary_interval}) must be larger than 0.")
elif summary_interval % steps_per_loop != 0: elif not callable(steps_per_loop) and (summary_interval % steps_per_loop
!= 0):
raise ValueError( raise ValueError(
f"`summary interval` ({summary_interval}) must be a multiple " f"`summary interval` ({summary_interval}) must be a multiple "
f"of `steps_per_loop` ({steps_per_loop}).") f"of `steps_per_loop` ({steps_per_loop}).")
...@@ -192,10 +198,10 @@ class Controller: ...@@ -192,10 +198,10 @@ class Controller:
if self.trainer is not None: if self.trainer is not None:
self.step_timer = None self.step_timer = None
self.steps_per_loop = steps_per_loop
self.summary_interval = summary_interval self.summary_interval = summary_interval
self.summary_manager = utils.SummaryManager( self.summary_manager = utils.SummaryManager(
summary_dir, tf.summary.scalar, global_step=self.global_step) summary_dir, tf.summary.scalar, global_step=self.global_step)
self._steps_per_loop = steps_per_loop
if self.evaluator is not None: if self.evaluator is not None:
eval_summary_dir = eval_summary_dir or summary_dir eval_summary_dir = eval_summary_dir or summary_dir
...@@ -316,9 +322,6 @@ class Controller: ...@@ -316,9 +322,6 @@ class Controller:
results in a shorter inner loop than specified by `steps_per_loop` results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is setting. If None, evaluation will only be performed after training is
complete. complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
""" """
self._require("trainer", for_method="train_and_evaluate") self._require("trainer", for_method="train_and_evaluate")
self._require("evaluator", for_method="train_and_evaluate") self._require("evaluator", for_method="train_and_evaluate")
...@@ -410,6 +413,13 @@ class Controller: ...@@ -410,6 +413,13 @@ class Controller:
self._require("checkpoint_manager", for_method="save_checkpoint") self._require("checkpoint_manager", for_method="save_checkpoint")
self._maybe_save_checkpoint(check_interval=False) self._maybe_save_checkpoint(check_interval=False)
@property
def steps_per_loop(self):
"""Returns current steps_per_loop value in a training loop."""
if callable(self._steps_per_loop):
return self._steps_per_loop(self.global_step.numpy())
return self._steps_per_loop
def _train_n_steps(self, num_steps: int): def _train_n_steps(self, num_steps: int):
"""Runs training for `num_steps` steps. """Runs training for `num_steps` steps.
......
...@@ -770,6 +770,32 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -770,6 +770,32 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self.assertIn("eval_loss", output) self.assertIn("eval_loss", output)
self.assertGreaterEqual(output["eval_loss"], 0) self.assertGreaterEqual(output["eval_loss"], 0)
def test_step_per_loop_callable(self):
test_runner = TestRunner()
checkpoint = tf.train.Checkpoint(
model=test_runner.model, optimizer=test_runner.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
self.model_dir,
max_to_keep=None,
step_counter=test_runner.global_step,
checkpoint_interval=10)
def steps_per_loop_fn(global_step):
if global_step > 4:
return 4
return 2
test_controller = controller.Controller(
trainer=test_runner,
global_step=test_runner.global_step,
steps_per_loop=steps_per_loop_fn,
checkpoint_manager=checkpoint_manager,
)
test_controller.train(steps=10)
self.assertEqual(test_runner.global_step, 10)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment