"...text-generation-inference.git" did not exist on "46aeb0860dae0c5a1e5990dff50f8d381fddce61"
Commit 7a8cbff1 authored by Hongkun Yu, committed by A. Unique TensorFlower

[Cleanup] Change the create_optimizer interface to take an optimization config and a runtime config.

PiperOrigin-RevId: 351524712
parent 538a809f
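In practice, call sites stop passing the whole TrainerConfig and instead pass its optimizer_config field, keeping the RuntimeConfig as a separate optional argument. A minimal sketch of the before/after call pattern, assuming task is a base_task.Task instance and config is an ExperimentConfig, with names taken from the hunks below:

# Old interface: the task dug the optimizer settings out of the trainer config.
optimizer = task.create_optimizer(config.trainer, config.runtime)

# New interface: the caller passes the OptimizationConfig directly; the
# optional RuntimeConfig still drives loss-scale / mixed-precision handling.
optimizer = task.create_optimizer(config.trainer.optimizer_config,
                                  config.runtime)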
@@ -24,7 +24,7 @@ from official.core import config_definitions
 from official.modeling import optimization
 from official.modeling import performance

-TrainerConfig = config_definitions.TrainerConfig
+OptimizationConfig = optimization.OptimizationConfig
 RuntimeConfig = config_definitions.RuntimeConfig
@@ -62,18 +62,18 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     return self._logging_dir

   @classmethod
-  def create_optimizer(cls, trainer_config: TrainerConfig,
+  def create_optimizer(cls, optimizer_config: OptimizationConfig,
                        runtime_config: Optional[RuntimeConfig] = None):
     """Creates an TF optimizer from configurations.

     Args:
-      trainer_config: the parameters of the trainer.
+      optimizer_config: the parameters of the Optimization settings.
       runtime_config: the parameters of the runtime.

     Returns:
       A tf.optimizers.Optimizer object.
     """
-    opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
+    opt_factory = optimization.OptimizerFactory(optimizer_config)
     optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
     # Configuring optimizer when loss_scale is set in runtime config. This helps
     # avoiding overflow/underflow for float16 computations.
...
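For reference, OptimizerFactory is driven entirely by the OptimizationConfig it receives, so the two factory calls kept in create_optimizer above are all that is needed once a config exists. A minimal sketch of building an optimizer straight from a config dict, assuming the TF Model Garden official package is importable; the sgd/constant values are illustrative and not part of this commit:

from official.modeling import optimization

# Illustrative settings; any optimizer / learning-rate pair supported by
# OptimizationConfig could be used here.
optimizer_config = optimization.OptimizationConfig({
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
})

opt_factory = optimization.OptimizerFactory(optimizer_config)
# Same two calls create_optimizer() makes: build the learning-rate schedule,
# then the optimizer that consumes it.
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())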
@@ -59,7 +59,8 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         config,
         task,
         model=task.build_model(),
-        optimizer=task.create_optimizer(config.trainer, config.runtime),
+        optimizer=task.create_optimizer(config.trainer.optimizer_config,
+                                        config.runtime),
         checkpoint_exporter=ckpt_exporter)
     return trainer
@@ -189,15 +190,18 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
                 }
             })))
     task = mock_task.MockTask(config.task, logging_dir=model_dir)

     def build_losses(labels, model_outputs, aux_losses=None):
       del labels, model_outputs
       return tf.constant([np.nan], tf.float32) + aux_losses

     task.build_losses = build_losses
     trainer = trainer_lib.Trainer(
         config,
         task,
         model=task.build_model(),
-        optimizer=task.create_optimizer(config.trainer, config.runtime))
+        optimizer=task.create_optimizer(config.trainer.optimizer_config,
+                                        config.runtime))
     trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
     with self.assertRaises(RuntimeError):
       _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
...
@@ -134,7 +134,8 @@ def create_trainer(params: config_definitions.ExperimentConfig,
   """Create trainer."""
   logging.info('Running default trainer.')
   model = task.build_model()
-  optimizer = task.create_optimizer(params.trainer, params.runtime)
+  optimizer = task.create_optimizer(params.trainer.optimizer_config,
+                                    params.runtime)
   return trainer_cls(
       params,
       task,
...
@@ -21,10 +21,9 @@ from official.core import base_task
 from official.core import config_definitions
 from official.core import task_factory
 from official.modeling import optimization
-from official.modeling import performance
 from official.modeling.multitask import configs

-TrainerConfig = config_definitions.TrainerConfig
+OptimizationConfig = optimization.OptimizationConfig
 RuntimeConfig = config_definitions.RuntimeConfig
@@ -105,28 +104,11 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     return self._task_weights[task_name]

   @classmethod
-  def create_optimizer(cls, trainer_config: TrainerConfig,
+  def create_optimizer(cls,
+                       optimizer_config: OptimizationConfig,
                        runtime_config: Optional[RuntimeConfig] = None):
-    """Creates an TF optimizer from configurations.
-
-    Args:
-      trainer_config: the parameters of the trainer.
-      runtime_config: the parameters of the runtime.
-
-    Returns:
-      A tf.optimizers.Optimizer object.
-    """
-    opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
-    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
-    # Configuring optimizer when loss_scale is set in runtime config. This helps
-    # avoiding overflow/underflow for float16 computations.
-    if runtime_config and runtime_config.loss_scale:
-      optimizer = performance.configure_optimizer(
-          optimizer,
-          use_float16=runtime_config.mixed_precision_dtype == "float16",
-          loss_scale=runtime_config.loss_scale)
-    return optimizer
+    return base_task.Task.create_optimizer(
+        optimizer_config=optimizer_config, runtime_config=runtime_config)

   def joint_train_step(self, task_inputs, multi_task_model, optimizer,
                        task_metrics):
...
@@ -49,7 +49,8 @@ def run_experiment_wtih_multitask_eval(
   is_training = 'train' in mode
   is_eval = 'eval' in mode
   with distribution_strategy.scope():
-    optimizer = train_task.create_optimizer(params.trainer, params.runtime)
+    optimizer = train_task.create_optimizer(params.trainer.optimizer_config,
+                                            params.runtime)
     model = train_task.build_model()
     if is_training:
       trainer = core_lib.Trainer(
...
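The loss-scaling branch deleted from MultiTask above is not lost: MultiTask.create_optimizer now delegates to base_task.Task.create_optimizer, which keeps that logic, and this is why runtime_config remains a parameter. A sketch of the preserved behavior, wrapped in a hypothetical helper (maybe_apply_loss_scale is not a function in the repo):

from official.modeling import performance

def maybe_apply_loss_scale(optimizer, runtime_config=None):
  # Mirror of the branch shown in the removed MultiTask copy: only wrap the
  # optimizer when a loss scale is configured, to avoid float16
  # overflow/underflow.
  if runtime_config and runtime_config.loss_scale:
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=runtime_config.mixed_precision_dtype == 'float16',
        loss_scale=runtime_config.loss_scale)
  return optimizer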