Commit 9cfbc813 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds trainer and checkpoint exporter as the arguments of the run_experiment functions.

PiperOrigin-RevId: 368778443
parent b9599c26
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""TFM common training driver library.""" """TFM common training driver library."""
# pytype: disable=attribute-error # pytype: disable=attribute-error
import os import os
from typing import Any, Mapping, Tuple from typing import Any, Mapping, Tuple, Optional
# Import libraries # Import libraries
from absl import logging from absl import logging
...@@ -23,21 +23,23 @@ import orbit ...@@ -23,21 +23,23 @@ import orbit
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions from official.core import config_definitions
from official.core import train_utils from official.core import train_utils
BestCheckpointExporter = train_utils.BestCheckpointExporter
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(distribution_strategy: tf.distribute.Strategy, def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task, task: base_task.Task,
mode: str, mode: str,
params: config_definitions.ExperimentConfig, params: config_definitions.ExperimentConfig,
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
save_summary: bool = True) \ save_summary: bool = True,
-> Tuple[tf.keras.Model, Mapping[str, Any]]: trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
run_post_eval: Whether to run post eval once after training, metrics logs run_post_eval: Whether to run post eval once after training, metrics logs
are returned. are returned.
save_summary: Whether to save train and validation summary. save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
Returns: Returns:
A 2-tuple of (model, eval_logs). A 2-tuple of (model, eval_logs).
...@@ -59,6 +63,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -59,6 +63,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
""" """
with distribution_strategy.scope(): with distribution_strategy.scope():
if not trainer:
trainer = train_utils.create_trainer( trainer = train_utils.create_trainer(
params, params,
task, task,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment