Commit 2886fa0a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 463812468
parent 81fb5b06
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Multitask training driver library.""" """Multitask training driver library."""
# pytype: disable=attribute-error # pytype: disable=attribute-error
import os import os
from typing import Any, List, Optional, Tuple from typing import Any, List, Mapping, Optional, Tuple, Union
from absl import logging from absl import logging
import orbit import orbit
import tensorflow as tf import tensorflow as tf
...@@ -44,8 +44,10 @@ def run_experiment( ...@@ -44,8 +44,10 @@ def run_experiment(
mode: str, mode: str,
params: configs.MultiTaskExperimentConfig, params: configs.MultiTaskExperimentConfig,
model_dir: str, model_dir: str,
run_post_eval: bool = False,
trainer: base_trainer.MultiTaskBaseTrainer = None trainer: base_trainer.MultiTaskBaseTrainer = None
) -> base_model.MultiTaskBaseModel: ) -> Union[base_model.MultiTaskBaseModel,
Tuple[base_model.MultiTaskBaseModel, Mapping[Any, Any]]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -56,6 +58,8 @@ def run_experiment( ...@@ -56,6 +58,8 @@ def run_experiment(
or 'continuous_eval'. or 'continuous_eval'.
params: ExperimentConfig instance. params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries. model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
trainer: (optional) A multi-task trainer to use. If none is provided, a trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`. default one will be created based on `params`.
...@@ -139,7 +143,11 @@ def run_experiment( ...@@ -139,7 +143,11 @@ def run_experiment(
else: else:
raise NotImplementedError('The mode is not implemented: %s' % mode) raise NotImplementedError('The mode is not implemented: %s' % mode)
return model if run_post_eval:
return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
else:
return model
def run_experiment_with_multitask_eval( def run_experiment_with_multitask_eval(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment