Commit a964b89a authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 380858205
parent 662127aa
......@@ -40,7 +40,8 @@ def run_experiment(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
......@@ -56,6 +57,8 @@ def run_experiment(
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
......@@ -87,7 +90,7 @@ def run_experiment(
else:
checkpoint_manager = None
controller = orbit.Controller(
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment