Commit 8048df8c authored by Ron Shapiro's avatar Ron Shapiro Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 455390116
parent edfe5df3
......@@ -17,7 +17,7 @@
import pprint
import time
from typing import Callable, List, Optional, Union
from typing import Callable, Iterable, Optional, Union
from absl import logging
......@@ -74,13 +74,13 @@ class Controller:
other custom outer loop implementations easy to achieve.
Some additional customization can be achieved by supplying `train_actions` or
`eval_actions` when constructing the `Controller`. These are just lists of
arbitrary callables that are applied by the `Controller` to the output of
train steps (after each inner loop of `steps_per_loop` steps) or an
evaluation. This provides a hook mechanism, enabling things like reporting
metrics to Vizier, model exporting, additional logging, etc. See the
`orbit.actions` package for a small handful of predefined actions and some
utility classes that may be useful in defining your own.
`eval_actions` when constructing the `Controller`. Actions arbitrary callables
that are applied by the `Controller` to the output of train steps (after each
inner loop of `steps_per_loop` steps) or an evaluation. This provides a hook
mechanism, enabling things like reporting metrics to Vizier, model exporting,
additional logging, etc. See the `orbit.actions` package for a small handful
of predefined actions and some utility classes that may be useful in defining
your own.
"""
def __init__(
......@@ -91,8 +91,8 @@ class Controller:
evaluator: Optional[runner.AbstractEvaluator] = None,
strategy: Optional[tf.distribute.Strategy] = None,
# Actions
train_actions: Optional[List[Action]] = None,
eval_actions: Optional[List[Action]] = None,
train_actions: Optional[Iterable[Action]] = None,
eval_actions: Optional[Iterable[Action]] = None,
# Train related
steps_per_loop: Optional[int] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
......@@ -125,12 +125,11 @@ class Controller:
strategy: An instance of `tf.distribute.Strategy`. If not provided, the
strategy will be initialized from the current in-scope strategy using
`tf.distribute.get_strategy()`.
train_actions: An optional list of `orbit.Action`s to call after each
block of `steps_per_loop` training steps are run. These will be called
with the output of `trainer.train`.
eval_actions: An optional list of `orbit.Action`s to call after each
evaluation. These will be called with the output of
`evaluator.evaluate`.
train_actions: Optional `orbit.Action`s to call after each block of
`steps_per_loop` training steps are run. These will be called with the
output of `trainer.train`.
eval_actions: Optional `orbit.Action`s to call after each evaluation.
These will be called with the output of `evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training
(passed as the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
......@@ -185,8 +184,8 @@ class Controller:
self.strategy = strategy or tf.distribute.get_strategy()
self.train_actions = train_actions or []
self.eval_actions = eval_actions or []
self.train_actions = () if train_actions is None else tuple(train_actions)
self.eval_actions = () if eval_actions is None else tuple(eval_actions)
self.global_step = global_step
self.checkpoint_manager = checkpoint_manager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment