"examples/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "48a0c17a3d233dd835d8cdd9508f6f4e3d03dcc6"
Commit 8048df8c authored by Ron Shapiro's avatar Ron Shapiro Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 455390116
parent edfe5df3
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import pprint import pprint
import time import time
from typing import Callable, List, Optional, Union from typing import Callable, Iterable, Optional, Union
from absl import logging from absl import logging
...@@ -74,13 +74,13 @@ class Controller: ...@@ -74,13 +74,13 @@ class Controller:
other custom outer loop implementations easy to achieve. other custom outer loop implementations easy to achieve.
Some additional customization can be achieved by supplying `train_actions` or Some additional customization can be achieved by supplying `train_actions` or
`eval_actions` when constructing the `Controller`. These are just lists of `eval_actions` when constructing the `Controller`. Actions arbitrary callables
arbitrary callables that are applied by the `Controller` to the output of that are applied by the `Controller` to the output of train steps (after each
train steps (after each inner loop of `steps_per_loop` steps) or an inner loop of `steps_per_loop` steps) or an evaluation. This provides a hook
evaluation. This provides a hook mechanism, enabling things like reporting mechanism, enabling things like reporting metrics to Vizier, model exporting,
metrics to Vizier, model exporting, additional logging, etc. See the additional logging, etc. See the `orbit.actions` package for a small handful
`orbit.actions` package for a small handful of predefined actions and some of predefined actions and some utility classes that may be useful in defining
utility classes that may be useful in defining your own. your own.
""" """
def __init__( def __init__(
...@@ -91,8 +91,8 @@ class Controller: ...@@ -91,8 +91,8 @@ class Controller:
evaluator: Optional[runner.AbstractEvaluator] = None, evaluator: Optional[runner.AbstractEvaluator] = None,
strategy: Optional[tf.distribute.Strategy] = None, strategy: Optional[tf.distribute.Strategy] = None,
# Actions # Actions
train_actions: Optional[List[Action]] = None, train_actions: Optional[Iterable[Action]] = None,
eval_actions: Optional[List[Action]] = None, eval_actions: Optional[Iterable[Action]] = None,
# Train related # Train related
steps_per_loop: Optional[int] = None, steps_per_loop: Optional[int] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None, checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
...@@ -125,12 +125,11 @@ class Controller: ...@@ -125,12 +125,11 @@ class Controller:
strategy: An instance of `tf.distribute.Strategy`. If not provided, the strategy: An instance of `tf.distribute.Strategy`. If not provided, the
strategy will be initialized from the current in-scope strategy using strategy will be initialized from the current in-scope strategy using
`tf.distribute.get_strategy()`. `tf.distribute.get_strategy()`.
train_actions: An optional list of `orbit.Action`s to call after each train_actions: Optional `orbit.Action`s to call after each block of
block of `steps_per_loop` training steps are run. These will be called `steps_per_loop` training steps are run. These will be called with the
with the output of `trainer.train`. output of `trainer.train`.
eval_actions: An optional list of `orbit.Action`s to call after each eval_actions: Optional `orbit.Action`s to call after each evaluation.
evaluation. These will be called with the output of These will be called with the output of `evaluator.evaluate`.
`evaluator.evaluate`.
steps_per_loop: The number of steps to run in each inner loop of training steps_per_loop: The number of steps to run in each inner loop of training
(passed as the `num_steps` parameter of `trainer.train`). (passed as the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
...@@ -185,8 +184,8 @@ class Controller: ...@@ -185,8 +184,8 @@ class Controller:
self.strategy = strategy or tf.distribute.get_strategy() self.strategy = strategy or tf.distribute.get_strategy()
self.train_actions = train_actions or [] self.train_actions = () if train_actions is None else tuple(train_actions)
self.eval_actions = eval_actions or [] self.eval_actions = () if eval_actions is None else tuple(eval_actions)
self.global_step = global_step self.global_step = global_step
self.checkpoint_manager = checkpoint_manager self.checkpoint_manager = checkpoint_manager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment