Commit bf6a29e4 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Return the model instance from train_lib.run_experiment() for convenience.

PiperOrigin-RevId: 327324127
parent c31ae6da
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
"""TFM common training driver library.""" """TFM common training driver library."""
import os import os
from typing import Any, Mapping from typing import Any, Mapping, Tuple
# Import libraries # Import libraries
from absl import logging from absl import logging
...@@ -34,7 +34,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -34,7 +34,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
params: config_definitions.ExperimentConfig, params: config_definitions.ExperimentConfig,
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
save_summary: bool = True) -> Mapping[str, Any]: save_summary: bool = True) \
-> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -49,8 +50,10 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -49,8 +50,10 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
save_summary: Whether to save train and validation summary. save_summary: Whether to save train and validation summary.
Returns: Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True, A 2-tuple of (model, eval_logs).
othewise, returns {}. model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
""" """
with distribution_strategy.scope(): with distribution_strategy.scope():
...@@ -106,7 +109,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -106,7 +109,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
if run_post_eval: if run_post_eval:
with distribution_strategy.scope(): with distribution_strategy.scope():
return trainer.evaluate( return trainer.model, trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps)) tf.convert_to_tensor(params.trainer.validation_steps))
else: else:
return {} return trainer.model, {}
...@@ -83,7 +83,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase): ...@@ -83,7 +83,7 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
with distribution_strategy.scope(): with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir) task = task_factory.get_task(params.task, logging_dir=model_dir)
logs = train_lib.run_experiment( _, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
task=task, task=task,
mode=flag_mode, mode=flag_mode,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment