"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "bb71654ebe846d97df306b163c086167239431e5"
Commit ab800e79 authored by Hongkun Yu, committed by A. Unique TensorFlower


Factor out optimizer creation logic from the trainer, so as to avoid accessing the runtime config inside the Trainer.
Add a validation method on the ExperimentConfig, as clients may use other containers instead of the model garden config dataclass.

PiperOrigin-RevId: 336132223
parent 5d92fa8d
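
In practice, the change means callers now build the optimizer themselves and inject it into the Trainer. A minimal sketch of the new call pattern, based on the diffs below; the module alias base_trainer is the one used in the train library diff, while `params` (an ExperimentConfig) and `task` (a base_task.Task) are assumed to already exist:

  model = task.build_model()
  # create_optimizer reads the optimizer/learning-rate settings from
  # params.trainer and, when given, the loss-scale settings from params.runtime.
  optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
  trainer = base_trainer.Trainer(
      params,
      task,
      model=model,
      optimizer=optimizer,
      train=True,
      evaluate=True)

This mirrors what create_trainer in the last diff below now does internally.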
@@ -19,7 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangable and independent on model architectures and tasks.
 """
+from typing import Optional
 import gin
 import orbit
 import tensorflow as tf
@@ -30,6 +30,31 @@ from official.modeling import performance
 from official.modeling.hyperparams import config_definitions
 
 ExperimentConfig = config_definitions.ExperimentConfig
+TrainerConfig = config_definitions.TrainerConfig
+RuntimeConfig = config_definitions.RuntimeConfig
+
+
+def create_optimizer(trainer_config: TrainerConfig,
+                     runtime_config: Optional[RuntimeConfig] = None):
+  """Creates a TF optimizer from configurations.
+
+  Args:
+    trainer_config: the parameters of the trainer.
+    runtime_config: the parameters of the runtime.
+
+  Returns:
+    A tf.optimizers.Optimizer object.
+  """
+  opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
+  optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
+  # Configuring the optimizer when loss_scale is set in the runtime config.
+  # This helps avoid overflow/underflow for float16 computations.
+  if runtime_config and runtime_config.loss_scale:
+    optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=runtime_config.mixed_precision_dtype == "float16",
+        loss_scale=runtime_config.loss_scale)
+  return optimizer
 
 
 @gin.configurable
@@ -40,9 +65,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
                config: ExperimentConfig,
                task: base_task.Task,
                model: tf.keras.Model,
+               optimizer: tf.optimizers.Optimizer,
                train: bool = True,
                evaluate: bool = True,
-               optimizer=None,
                checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.
@@ -51,45 +76,29 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
       task: A base_task.Task instance.
       model: tf.keras.Model instance. If provided, it will be used instead of
         building model using task.build_model(). Default to None.
+      optimizer: tf.optimizers.Optimizer instance.
       train: bool, whether or not this trainer will be used for training.
         default to True.
       evaluate: bool, whether or not this trainer will be used for evaluation.
         default to True.
-      optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
-        used instead of the optimizer from config. Default to None.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
+    self._validate_params(config)
     self._config = config
     self._task = task
     self._model = model
-    if optimizer is None:
-      opt_factory = optimization.OptimizerFactory(
-          config.trainer.optimizer_config)
-      self._optimizer = opt_factory.build_optimizer(
-          opt_factory.build_learning_rate())
-    else:
-      self._optimizer = optimizer
     self._checkpoint_exporter = checkpoint_exporter
+    self._optimizer = optimizer
-    # Configuring optimizer when loss_scale is set in runtime config. This helps
-    # avoiding overflow/underflow for float16 computations.
-    if config.runtime.loss_scale:
-      self._optimizer = performance.configure_optimizer(
-          self._optimizer,
-          use_float16=config.runtime.mixed_precision_dtype == 'float16',
-          loss_scale=config.runtime.loss_scale)
 
     # global_step increases by 1 after each training iteration.
     # We should have global_step.numpy() == self.optimizer.iterations.numpy()
     # when there is only 1 optimizer.
     self._global_step = orbit.utils.create_global_step()
-    if hasattr(self.model, 'checkpoint_items'):
+    if hasattr(self.model, "checkpoint_items"):
       checkpoint_items = self.model.checkpoint_items
     else:
       checkpoint_items = {}
@@ -99,9 +108,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         optimizer=self.optimizer,
         **checkpoint_items)
-    self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
+    self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
     self._validation_loss = tf.keras.metrics.Mean(
-        'validation_loss', dtype=tf.float32)
+        "validation_loss", dtype=tf.float32)
     self._train_metrics = self.task.build_metrics(
         training=True) + self.model.metrics
     self._validation_metrics = self.task.build_metrics(
@@ -128,6 +137,34 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         options=orbit.StandardEvaluatorOptions(
             use_tf_function=config.trainer.eval_tf_function))
 
+  def _validate_params(self, config):
+    r"""Validates the configuration object passed to the Trainer.
+
+    The experiment configuration should be structured as:
+    \trainer
+    \task
+      \train_data
+      \validation_data
+
+    Args:
+      config: a namedtuple, dataclass, ConfigDict, etc.
+    """
+    if not hasattr(config, "trainer"):
+      raise AttributeError("The trainer requires the configuration to contain"
+                           " an attribute `trainer`.")
+    if not hasattr(config, "task"):
+      raise AttributeError("The trainer requires the configuration to contain"
+                           " an attribute `task`.")
+    if not hasattr(config.task, "train_data"):
+      raise AttributeError("The trainer requires the configuration to contain"
+                           " an attribute `task.train_data`.")
+    if not hasattr(config.task, "validation_data"):
+      raise AttributeError("The trainer requires the configuration to contain"
+                           " an attribute `task.validation_data`.")
+
   @property
   def strategy(self):
     return self._strategy
@@ -194,9 +231,9 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
       logs[metric.name] = metric.result()
       metric.reset_states()
     if callable(self.optimizer.learning_rate):
-      logs['learning_rate'] = self.optimizer.learning_rate(self.global_step)
+      logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
     else:
-      logs['learning_rate'] = self.optimizer.learning_rate
+      logs["learning_rate"] = self.optimizer.learning_rate
     return logs
 
   def train_step(self, iterator):
@@ -244,8 +281,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
       self._checkpoint_exporter.maybe_export_checkpoint(
           self.checkpoint, logs, self.global_step.numpy())
       metric_name = self.config.trainer.best_checkpoint_eval_metric
-      logs['best_' + metric_name] = self._checkpoint_exporter.best_ckpt_logs[
-          metric_name]
+      logs["best_" +
+           metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
     return logs
...
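
The new _validate_params only checks attribute presence, so any container that duck-types the expected structure is accepted, as the commit message notes. Below is a small, self-contained sketch (not part of the commit) of the contract it enforces, using namedtuples as hypothetical stand-ins for the model garden dataclasses:

  import collections

  # Hypothetical duck-typed stand-ins; only attribute presence matters here.
  TaskLike = collections.namedtuple('TaskLike', ['train_data', 'validation_data'])
  ConfigLike = collections.namedtuple('ConfigLike', ['trainer', 'task'])

  config = ConfigLike(
      trainer=object(),  # placeholder contents
      task=TaskLike(train_data=object(), validation_data=object()))

  # These are the same checks _validate_params performs; a missing attribute
  # makes the Trainer raise AttributeError at construction time.
  assert hasattr(config, 'trainer')
  assert hasattr(config, 'task')
  assert hasattr(config.task, 'train_data')
  assert hasattr(config.task, 'validation_data')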
@@ -54,9 +54,15 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
           }
       })))
 
-  def create_test_trainer(self, config):
-    task = mock_task.MockTask()
-    trainer = trainer_lib.Trainer(config, task, model=task.build_model())
+  def create_test_trainer(self, config, model_dir=None):
+    task = mock_task.MockTask(config.task, logging_dir=model_dir)
+    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
+    trainer = trainer_lib.Trainer(
+        config,
+        task,
+        model=task.build_model(),
+        optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime),
+        checkpoint_exporter=ckpt_exporter)
     return trainer
 
   @combinations.generate(all_strategy_combinations())
@@ -121,13 +127,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
           }
       })))
     model_dir = self.get_temp_dir()
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
-    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
-    trainer = trainer_lib.Trainer(
-        config,
-        task,
-        model=task.build_model(),
-        checkpoint_exporter=ckpt_exporter)
+    trainer = self.create_test_trainer(config, model_dir=model_dir)
     trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
     trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
     self.assertTrue(
...
@@ -18,7 +18,7 @@
 import json
 import os
 import pprint
-from typing import Any, List
+from typing import Any, List, Optional
 
 from absl import logging
 import dataclasses
@@ -34,20 +34,22 @@ from official.modeling.hyperparams import config_definitions
 def create_trainer(params: config_definitions.ExperimentConfig,
                    task: base_task.Task,
-                   model_dir: str,
                    train: bool,
                    evaluate: bool,
-                   checkpoint_exporter: Any = None) -> base_trainer.Trainer:
+                   checkpoint_exporter: Any = None,
+                   model_dir: Optional[str] = None) -> base_trainer.Trainer:
   """Create trainer."""
+  del model_dir
   logging.info('Running default trainer.')
   model = task.build_model()
+  optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
   trainer = base_trainer.Trainer(
       params,
       task,
+      model=model,
+      optimizer=optimizer,
       train=train,
       evaluate=evaluate,
-      model=model,
       checkpoint_exporter=checkpoint_exporter)
   return trainer
...
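
Note that create_trainer's signature also changes: model_dir moves from a required positional argument to an optional keyword argument that the function now ignores (`del model_dir`). A hedged sketch of an updated call site, assuming the function is imported as train_lib (the diff does not show file names) and that `params` and `task` already exist:

  trainer = train_lib.create_trainer(
      params,
      task,
      train=True,
      evaluate=True,
      checkpoint_exporter=None,
      model_dir=None)  # now optional and unused; kept for call-site compatibility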