"torchvision/vscode:/vscode.git/clone" did not exist on "4028d9f882a7af55436800b7a7167f547f7debfc"
Commit a964b89a authored by Yeqing Li, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 380858205
parent 662127aa
...@@ -40,7 +40,8 @@ def run_experiment( ...@@ -40,7 +40,8 @@ def run_experiment(
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
save_summary: bool = True, save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]: ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
...@@ -56,6 +57,8 @@ def run_experiment( ...@@ -56,6 +57,8 @@ def run_experiment(
save_summary: Whether to save train and validation summary. save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope(). strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be an orbit.Controller subclass.
Returns: Returns:
A 2-tuple of (model, eval_logs). A 2-tuple of (model, eval_logs).
...@@ -87,7 +90,7 @@ def run_experiment( ...@@ -87,7 +90,7 @@ def run_experiment(
else: else:
checkpoint_manager = None checkpoint_manager = None
controller = orbit.Controller( controller = controller_cls(
strategy=distribution_strategy, strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None, trainer=trainer if 'train' in mode else None,
evaluator=trainer, evaluator=trainer,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment