"doc/git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "8bb58982849818a4439a9cc40b2c5fdf9db34c53"
Commit 6446619f authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 347439073
parent eaf8c8c3
@@ -20,6 +20,13 @@ from typing import Optional
 from absl import logging
 import tensorflow as tf
+from official.core import config_definitions
+from official.modeling import optimization
+from official.modeling import performance
+
+TrainerConfig = config_definitions.TrainerConfig
+RuntimeConfig = config_definitions.RuntimeConfig


 class Task(tf.Module, metaclass=abc.ABCMeta):
   """A single-replica view of training procedure.
@@ -54,6 +61,30 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
   def logging_dir(self) -> str:
     return self._logging_dir

+  @classmethod
+  def create_optimizer(cls, trainer_config: TrainerConfig,
+                       runtime_config: Optional[RuntimeConfig] = None):
+    """Creates a TF optimizer from configurations.
+
+    Args:
+      trainer_config: the parameters of the trainer.
+      runtime_config: the parameters of the runtime.
+
+    Returns:
+      A tf.optimizers.Optimizer object.
+    """
+    opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
+    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
+    # Configure the optimizer when loss_scale is set in the runtime config.
+    # This helps avoid overflow/underflow for float16 computations.
+    if runtime_config and runtime_config.loss_scale:
+      optimizer = performance.configure_optimizer(
+          optimizer,
+          use_float16=runtime_config.mixed_precision_dtype == "float16",
+          loss_scale=runtime_config.loss_scale)
+    return optimizer
+
   def initialize(self, model: tf.keras.Model):
     """[Optional] A callback function used as CheckpointManager's init_fn.
......
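The new classmethod lets any task (or the base `Task` class itself, since no instance is required) build its optimizer straight from the configs. A minimal usage sketch, assuming the Model Garden `official` package is importable; the SGD/constant-learning-rate values below are illustrative assumptions, not taken from this commit:

```python
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling import optimization

# Illustrative optimizer/learning-rate settings (assumed, not from this diff).
trainer_config = cfg.TrainerConfig(
    optimizer_config=optimization.OptimizationConfig({
        'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
        'learning_rate': {'type': 'constant',
                          'constant': {'learning_rate': 0.1}},
    }))

# With loss_scale set, the float16 branch above wraps the optimizer for
# mixed precision so gradients do not overflow/underflow.
runtime_config = cfg.RuntimeConfig(mixed_precision_dtype='float16',
                                   loss_scale='dynamic')

# create_optimizer is a classmethod, so no concrete task instance is needed.
optimizer = base_task.Task.create_optimizer(trainer_config, runtime_config)
```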
@@ -19,7 +19,6 @@ The base trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangeable and independent of model architectures and tasks.
 """
-from typing import Optional

 from absl import logging
 import gin
@@ -28,35 +27,9 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import config_definitions
-from official.modeling import optimization
-from official.modeling import performance

 ExperimentConfig = config_definitions.ExperimentConfig
 TrainerConfig = config_definitions.TrainerConfig
-RuntimeConfig = config_definitions.RuntimeConfig
-
-
-def create_optimizer(trainer_config: TrainerConfig,
-                     runtime_config: Optional[RuntimeConfig] = None):
-  """Creates a TF optimizer from configurations.
-
-  Args:
-    trainer_config: the parameters of the trainer.
-    runtime_config: the parameters of the runtime.
-
-  Returns:
-    A tf.optimizers.Optimizer object.
-  """
-  opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
-  optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
-  # Configure the optimizer when loss_scale is set in the runtime config.
-  # This helps avoid overflow/underflow for float16 computations.
-  if runtime_config and runtime_config.loss_scale:
-    optimizer = performance.configure_optimizer(
-        optimizer,
-        use_float16=runtime_config.mixed_precision_dtype == "float16",
-        loss_scale=runtime_config.loss_scale)
-  return optimizer
-
-
 class Recovery:
......
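With the module-level helper removed, call sites switch to the task's classmethod, which is exactly what the test and `train_utils` hunks below do. A small before/after sketch of the updated construction path (the wrapper function and variable names are illustrative, not part of this commit; the `Trainer` keyword arguments mirror the test code in this diff):

```python
from official.core import base_trainer

def build_trainer(config, task):
  """Sketch of trainer construction after this change (names illustrative)."""
  # Before this commit:
  #   optimizer = base_trainer.create_optimizer(config.trainer, config.runtime)
  # After this commit, the optimizer comes from the task:
  optimizer = task.create_optimizer(config.trainer, config.runtime)
  return base_trainer.Trainer(
      config,
      task,
      model=task.build_model(),
      optimizer=optimizer)
```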
@@ -61,7 +61,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         config,
         task,
         model=task.build_model(),
-        optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime),
+        optimizer=task.create_optimizer(config.trainer, config.runtime),
         checkpoint_exporter=ckpt_exporter)
     return trainer
@@ -180,7 +180,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         config,
         task,
         model=task.build_model(),
-        optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime))
+        optimizer=task.create_optimizer(config.trainer, config.runtime))
     trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
     with self.assertRaises(RuntimeError):
       _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
......
@@ -134,7 +134,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
   """Create trainer."""
   logging.info('Running default trainer.')
   model = task.build_model()
-  optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
+  optimizer = task.create_optimizer(params.trainer, params.runtime)
   return trainer_cls(
       params,
       task,
......