"vscode:/vscode.git/clone" did not exist on "b28cbf7125fa3646a1aefb15dcb5c9359c9e664d"
Commit d1ccfbd2 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 477004431
parent b681a1b8
......@@ -45,9 +45,11 @@ def run_experiment(
params: configs.MultiTaskExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
trainer: base_trainer.MultiTaskBaseTrainer = None
) -> Union[base_model.MultiTaskBaseModel,
Tuple[base_model.MultiTaskBaseModel, Mapping[Any, Any]]]:
trainer: base_trainer.MultiTaskBaseTrainer = None,
best_ckpt_exporter_creator: Optional[Any] = train_utils
.maybe_create_best_ckpt_exporter
) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel,
Mapping[Any, Any]]]:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -62,6 +64,7 @@ def run_experiment(
are returned.
trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`.
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
Returns:
model: `base_model.MultiTaskBaseModel` instance.
......@@ -86,8 +89,7 @@ def run_experiment(
model=model,
eval_steps=eval_steps,
global_step=trainer.global_step if is_training else None,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
else:
evaluator = None
......@@ -159,7 +161,10 @@ def run_experiment_with_multitask_eval(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
trainer: Optional[core_lib.Trainer] = None,
best_ckpt_exporter_creator: Optional[Any] = train_utils
.maybe_create_best_ckpt_exporter,
) -> Tuple[Any, Any]:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -176,6 +181,7 @@ def run_experiment_with_multitask_eval(
trainer: the core_lib.Trainer instance. It should be created within the
strategy.scope(). If not provided, an instance will be created by default
if `mode` contains 'train'.
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
Returns:
model: `tf.keras.Model` instance.
......@@ -205,8 +211,7 @@ def run_experiment_with_multitask_eval(
model=model,
global_step=trainer.global_step if is_training else None,
eval_steps=eval_steps,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
else:
evaluator = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment