Commit 55331222 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Orbit] The global_step variable is required and cannot be None.

We use kwargs for all arguments and enforce them should be kwargs.
Reorder args.

PiperOrigin-RevId: 346908658
parent 8de71c4b
......@@ -100,7 +100,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
checkpoint_manager = None
controller = orbit.Controller(
distribution_strategy,
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
......
......@@ -161,9 +161,9 @@ def run(flags_obj):
checkpoint_interval=checkpoint_interval)
resnet_controller = orbit.Controller(
strategy,
runnable,
runnable if not flags_obj.skip_eval else None,
strategy=strategy,
trainer=runnable,
evaluator=runnable if not flags_obj.skip_eval else None,
global_step=runnable.global_step,
steps_per_loop=steps_per_loop,
checkpoint_manager=checkpoint_manager,
......
......@@ -74,10 +74,11 @@ class Controller:
def __init__(
self,
strategy: Optional[tf.distribute.Strategy] = None,
*, # Makes all args keyword only.
global_step: tf.Variable,
trainer: Optional[runner.AbstractTrainer] = None,
evaluator: Optional[runner.AbstractEvaluator] = None,
global_step: Optional[tf.Variable] = None,
strategy: Optional[tf.distribute.Strategy] = None,
# Train related
steps_per_loop: Optional[int] = None,
checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
......@@ -93,13 +94,6 @@ class Controller:
recent checkpoint during this `__init__` method.
Args:
strategy: An instance of `tf.distribute.Strategy`. If not provided, the
strategy will be initialized from the current in-scope strategy using
`tf.distribute.get_strategy()`.
trainer: An instance of `orbit.AbstractTrainer`, which implements the
inner training loop.
evaluator: An instance of `orbit.AbstractEvaluator`, which implements
evaluation.
global_step: An integer `tf.Variable` storing the global training step
number. Usually this can be obtained from the `iterations` property of
the model's optimizer (e.g. `trainer.optimizer.iterations`). In cases
......@@ -109,6 +103,13 @@ class Controller:
recommended to create the `tf.Variable` inside the distribution strategy
scope, with `aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA` (see
also `orbit.utils.create_global_step()`).
trainer: An instance of `orbit.AbstractTrainer`, which implements the
inner training loop.
evaluator: An instance of `orbit.AbstractEvaluator`, which implements
evaluation.
strategy: An instance of `tf.distribute.Strategy`. If not provided, the
strategy will be initialized from the current in-scope strategy using
`tf.distribute.get_strategy()`.
steps_per_loop: The number of steps to run in each inner loop of training
(passed as the `num_steps` parameter of `trainer.train`).
checkpoint_manager: An instance of `tf.train.CheckpointManager`. If
......@@ -137,7 +138,6 @@ class Controller:
"""
if trainer is None and evaluator is None:
raise ValueError("`trainer` and `evaluator` should not both be `None`.")
if trainer is not None:
if steps_per_loop is None:
raise ValueError(
......@@ -155,9 +155,7 @@ class Controller:
f"`summary interval` ({summary_interval}) must be a multiple "
f"of `steps_per_loop` ({steps_per_loop}).")
if global_step is None:
raise ValueError("`global_step` is required.")
elif not isinstance(global_step, tf.Variable):
if not isinstance(global_step, tf.Variable):
raise ValueError("`global_step` must be a `tf.Variable`.")
self.trainer = trainer
......@@ -185,8 +183,7 @@ class Controller:
self.eval_summary_manager = utils.SummaryManager(
eval_summary_dir, tf.summary.scalar, global_step=self.global_step)
if self.global_step is not None:
tf.summary.experimental.set_step(self.global_step)
tf.summary.experimental.set_step(self.global_step)
# Restores the model if needed.
if self.checkpoint_manager is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment