"...text-generation-inference.git" did not exist on "b6bb1d5160083a011d69c1a32547346a3b4d7d94"
Commit d1ccfbd2 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 477004431
parent b681a1b8
...@@ -45,9 +45,11 @@ def run_experiment( ...@@ -45,9 +45,11 @@ def run_experiment(
params: configs.MultiTaskExperimentConfig, params: configs.MultiTaskExperimentConfig,
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
trainer: base_trainer.MultiTaskBaseTrainer = None trainer: base_trainer.MultiTaskBaseTrainer = None,
) -> Union[base_model.MultiTaskBaseModel, best_ckpt_exporter_creator: Optional[Any] = train_utils
Tuple[base_model.MultiTaskBaseModel, Mapping[Any, Any]]]: .maybe_create_best_ckpt_exporter
) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel,
Mapping[Any, Any]]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -62,6 +64,7 @@ def run_experiment( ...@@ -62,6 +64,7 @@ def run_experiment(
are returned. are returned.
trainer: (optional) A multi-task trainer to use. If none is provided, a trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`. default one will be created based on `params`.
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
Returns: Returns:
model: `base_model.MultiTaskBaseModel` instance. model: `base_model.MultiTaskBaseModel` instance.
...@@ -86,8 +89,7 @@ def run_experiment( ...@@ -86,8 +89,7 @@ def run_experiment(
model=model, model=model,
eval_steps=eval_steps, eval_steps=eval_steps,
global_step=trainer.global_step if is_training else None, global_step=trainer.global_step if is_training else None,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter( checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
params, model_dir))
else: else:
evaluator = None evaluator = None
...@@ -159,7 +161,10 @@ def run_experiment_with_multitask_eval( ...@@ -159,7 +161,10 @@ def run_experiment_with_multitask_eval(
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
save_summary: bool = True, save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]: trainer: Optional[core_lib.Trainer] = None,
best_ckpt_exporter_creator: Optional[Any] = train_utils
.maybe_create_best_ckpt_exporter,
) -> Tuple[Any, Any]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -176,6 +181,7 @@ def run_experiment_with_multitask_eval( ...@@ -176,6 +181,7 @@ def run_experiment_with_multitask_eval(
trainer: the core_lib.Trainer instance. It should be created within the trainer: the core_lib.Trainer instance. It should be created within the
strategy.scope(). If not provided, an instance will be created by default strategy.scope(). If not provided, an instance will be created by default
if `mode` contains 'train'. if `mode` contains 'train'.
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
Returns: Returns:
model: `tf.keras.Model` instance. model: `tf.keras.Model` instance.
...@@ -205,8 +211,7 @@ def run_experiment_with_multitask_eval( ...@@ -205,8 +211,7 @@ def run_experiment_with_multitask_eval(
model=model, model=model,
global_step=trainer.global_step if is_training else None, global_step=trainer.global_step if is_training else None,
eval_steps=eval_steps, eval_steps=eval_steps,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter( checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
params, model_dir))
else: else:
evaluator = None evaluator = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment