Commit 9cfbc813 authored by Yeqing Li, committed by A. Unique TensorFlower

Adds trainer and checkpoint exporter as arguments to the run_experiment functions.

PiperOrigin-RevId: 368778443
parent b9599c26
@@ -15,7 +15,7 @@
 """TFM common training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Any, Mapping, Tuple
+from typing import Any, Mapping, Tuple, Optional
 
 # Import libraries
 from absl import logging
@@ -23,21 +23,23 @@
 import orbit
 import tensorflow as tf
 
 from official.core import base_task
+from official.core import base_trainer
 from official.core import config_definitions
 from official.core import train_utils
 
+BestCheckpointExporter = train_utils.BestCheckpointExporter
 maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
 
 
-def run_experiment(distribution_strategy: tf.distribute.Strategy,
-                   task: base_task.Task,
-                   mode: str,
-                   params: config_definitions.ExperimentConfig,
-                   model_dir: str,
-                   run_post_eval: bool = False,
-                   save_summary: bool = True) \
--> Tuple[tf.keras.Model, Mapping[str, Any]]:
+def run_experiment(
+    distribution_strategy: tf.distribute.Strategy,
+    task: base_task.Task,
+    mode: str,
+    params: config_definitions.ExperimentConfig,
+    model_dir: str,
+    run_post_eval: bool = False,
+    save_summary: bool = True,
+    trainer: Optional[base_trainer.Trainer] = None
+) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
   """Runs train/eval configured by the experiment params.
 
   Args:
@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
     run_post_eval: Whether to run post eval once after training, metrics logs
       are returned.
     save_summary: Whether to save train and validation summary.
+    trainer: the base_trainer.Trainer instance. It should be created within the
+      strategy.scope().
 
   Returns:
     A 2-tuple of (model, eval_logs).
@@ -59,13 +63,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
   """
 
   with distribution_strategy.scope():
-    trainer = train_utils.create_trainer(
-        params,
-        task,
-        train='train' in mode,
-        evaluate=('eval' in mode) or run_post_eval,
-        checkpoint_exporter=maybe_create_best_ckpt_exporter(
-            params, model_dir))
+    if not trainer:
+      trainer = train_utils.create_trainer(
+          params,
+          task,
+          train='train' in mode,
+          evaluate=('eval' in mode) or run_post_eval,
+          checkpoint_exporter=maybe_create_best_ckpt_exporter(
+              params, model_dir))
 
   if trainer.checkpoint:
     checkpoint_manager = tf.train.CheckpointManager(
...
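For callers of this driver, the new optional trainer argument makes it possible to inject a pre-built (and possibly customized) Trainer instead of letting run_experiment create one internally. The sketch below illustrates that flow; it assumes the driver module is official.core.train_lib and that the config, task, and strategy come from the Model Garden factory helpers. Those helper names and the experiment name are assumptions for illustration, not part of this commit.

# Illustrative sketch only: module paths, factory helpers, and the experiment
# name are assumptions about the surrounding Model Garden code, not part of
# this commit.
from official.common import distribute_utils
from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils

model_dir = '/tmp/my_experiment'
params = exp_factory.get_exp_config('resnet_imagenet')  # any registered experiment
distribution_strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy='mirrored')
task = task_factory.get_task(params.task, logging_dir=model_dir)

# Per the new docstring, a caller-supplied trainer must be created inside
# strategy.scope(); a best-checkpoint exporter can be attached to it here.
with distribution_strategy.scope():
  trainer = train_utils.create_trainer(
      params,
      task,
      train=True,
      evaluate=True,
      checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
          params, model_dir))

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=distribution_strategy,
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir=model_dir,
    trainer=trainer)

If trainer is left as None, run_experiment falls back to the previous behavior and builds the trainer itself inside the strategy scope, so existing callers are unaffected.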