"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "e515f026361ff36f0ffac8ce01edac206e27505c"
Commit e3704ce2 authored by Yeqing Li, committed by A. Unique TensorFlower
Browse files

Adds trainer and checkpoint exporter as arguments of the run_experiment function.

PiperOrigin-RevId: 368778443
parent 85a6db17
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""TFM common training driver library.""" """TFM common training driver library."""
# pytype: disable=attribute-error # pytype: disable=attribute-error
import os import os
from typing import Any, Mapping, Tuple from typing import Any, Mapping, Tuple, Optional
# Import libraries # Import libraries
from absl import logging from absl import logging
...@@ -23,21 +23,23 @@ import orbit ...@@ -23,21 +23,23 @@ import orbit
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions from official.core import config_definitions
from official.core import train_utils from official.core import train_utils
BestCheckpointExporter = train_utils.BestCheckpointExporter
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(distribution_strategy: tf.distribute.Strategy, def run_experiment(
task: base_task.Task, distribution_strategy: tf.distribute.Strategy,
mode: str, task: base_task.Task,
params: config_definitions.ExperimentConfig, mode: str,
model_dir: str, params: config_definitions.ExperimentConfig,
run_post_eval: bool = False, model_dir: str,
save_summary: bool = True) \ run_post_eval: bool = False,
-> Tuple[tf.keras.Model, Mapping[str, Any]]: save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
Args: Args:
...@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
run_post_eval: Whether to run post eval once after training, metrics logs run_post_eval: Whether to run post eval once after training, metrics logs
are returned. are returned.
save_summary: Whether to save train and validation summary. save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
Returns: Returns:
A 2-tuple of (model, eval_logs). A 2-tuple of (model, eval_logs).
...@@ -59,13 +63,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -59,13 +63,14 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
""" """
with distribution_strategy.scope(): with distribution_strategy.scope():
trainer = train_utils.create_trainer( if not trainer:
params, trainer = train_utils.create_trainer(
task, params,
train='train' in mode, task,
evaluate=('eval' in mode) or run_post_eval, train='train' in mode,
checkpoint_exporter=maybe_create_best_ckpt_exporter( evaluate=('eval' in mode) or run_post_eval,
params, model_dir)) checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint: if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment