Commit e3704ce2 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds trainer and checkpoint exporter as the arguments of the run_experiment functions.

PiperOrigin-RevId: 368778443
parent 85a6db17
......@@ -15,7 +15,7 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple
from typing import Any, Mapping, Tuple, Optional
# Import libraries
from absl import logging
......@@ -23,21 +23,23 @@ import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import train_utils
BestCheckpointExporter = train_utils.BestCheckpointExporter
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(distribution_strategy: tf.distribute.Strategy,
def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) \
-> Tuple[tf.keras.Model, Mapping[str, Any]]:
save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
Returns:
A 2-tuple of (model, eval_logs).
......@@ -59,6 +63,7 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
"""
with distribution_strategy.scope():
if not trainer:
trainer = train_utils.create_trainer(
params,
task,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment