"torchvision/vscode:/vscode.git/clone" did not exist on "f1d7c92d42324b28c6d9d62094f08811106a5ebb"
Commit 6446619f authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 347439073
parent eaf8c8c3
@@ -20,6 +20,13 @@ from typing import Optional
from absl import logging
import tensorflow as tf

from official.core import config_definitions
from official.modeling import optimization
from official.modeling import performance

TrainerConfig = config_definitions.TrainerConfig
RuntimeConfig = config_definitions.RuntimeConfig


class Task(tf.Module, metaclass=abc.ABCMeta):
  """A single-replica view of training procedure.
@@ -54,6 +61,30 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
  def logging_dir(self) -> str:
    return self._logging_dir

  @classmethod
  def create_optimizer(cls, trainer_config: TrainerConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      trainer_config: the parameters of the trainer.
      runtime_config: the parameters of the runtime.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    opt_factory = optimization.OptimizerFactory(
        trainer_config.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Configuring optimizer when loss_scale is set in runtime config. This
    # helps avoid overflow/underflow for float16 computations.
    if runtime_config and runtime_config.loss_scale:
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=runtime_config.mixed_precision_dtype == "float16",
          loss_scale=runtime_config.loss_scale)
    return optimizer

  def initialize(self, model: tf.keras.Model):
    """[Optional] A callback function used as CheckpointManager's init_fn.
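For orientation, the new classmethod can be exercised without a concrete task, since a @classmethod needs no instance of the abstract Task. A minimal sketch, assuming the OptimizationConfig dict schema used elsewhere in the Model Garden tests; the SGD and constant learning-rate choices are illustrative, not part of this commit:

from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization

# Illustrative trainer config: plain SGD with a constant learning rate.
trainer_config = config_definitions.TrainerConfig(
    optimizer_config=optimization.OptimizationConfig({
        'optimizer': {'type': 'sgd'},
        'learning_rate': {'type': 'constant'},
    }))

# No Task subclass or instance is needed to build the optimizer itself.
optimizer = base_task.Task.create_optimizer(trainer_config)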
@@ -19,7 +19,6 @@ The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
from typing import Optional
from absl import logging
import gin
@@ -28,35 +27,9 @@ import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization
from official.modeling import performance

ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
RuntimeConfig = config_definitions.RuntimeConfig


def create_optimizer(trainer_config: TrainerConfig,
                     runtime_config: Optional[RuntimeConfig] = None):
  """Creates a TF optimizer from configurations.

  Args:
    trainer_config: the parameters of the trainer.
    runtime_config: the parameters of the runtime.

  Returns:
    A tf.optimizers.Optimizer object.
  """
  opt_factory = optimization.OptimizerFactory(trainer_config.optimizer_config)
  optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
  # Configuring optimizer when loss_scale is set in runtime config. This
  # helps avoid overflow/underflow for float16 computations.
  if runtime_config and runtime_config.loss_scale:
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=runtime_config.mixed_precision_dtype == "float16",
        loss_scale=runtime_config.loss_scale)
  return optimizer
class Recovery:
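The loss-scale branch retained by the moved code is easiest to see with a runtime config. A hedged sketch, assuming RuntimeConfig accepts the mixed_precision_dtype and loss_scale fields referenced above as keyword arguments; the 'dynamic' loss scale is illustrative:

from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization

trainer_config = config_definitions.TrainerConfig(
    optimizer_config=optimization.OptimizationConfig({
        'optimizer': {'type': 'sgd'},
        'learning_rate': {'type': 'constant'},
    }))
# When loss_scale is set, create_optimizer wraps the optimizer via
# performance.configure_optimizer for float16 loss scaling.
runtime_config = config_definitions.RuntimeConfig(
    mixed_precision_dtype='float16', loss_scale='dynamic')

optimizer = base_task.Task.create_optimizer(trainer_config, runtime_config)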
@@ -61,7 +61,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
        config,
        task,
        model=task.build_model(),
-       optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime),
+       optimizer=task.create_optimizer(config.trainer, config.runtime),
        checkpoint_exporter=ckpt_exporter)
    return trainer
@@ -180,7 +180,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
        config,
        task,
        model=task.build_model(),
-       optimizer=trainer_lib.create_optimizer(config.trainer, config.runtime))
+       optimizer=task.create_optimizer(config.trainer, config.runtime))
    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
    with self.assertRaises(RuntimeError):
      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
@@ -134,7 +134,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
  """Create trainer."""
  logging.info('Running default trainer.')
  model = task.build_model()
-  optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
+  optimizer = task.create_optimizer(params.trainer, params.runtime)
  return trainer_cls(
      params,
      task,
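Since create_optimizer now lives on Task as a classmethod, a task can supply its own optimizer and create_trainer picks it up automatically. A hypothetical sketch (MyTask and the Adam settings are illustrative; a real task must also implement Task's abstract methods, omitted here because the classmethod needs no instance):

import tensorflow as tf

from official.core import base_task
from official.core import config_definitions


class MyTask(base_task.Task):
  """Hypothetical task that supplies its own optimizer."""

  @classmethod
  def create_optimizer(cls, trainer_config: config_definitions.TrainerConfig,
                       runtime_config=None):
    # Ignore the config-driven factory and return a fixed optimizer instead.
    return tf.keras.optimizers.Adam(learning_rate=1e-3)


optimizer = MyTask.create_optimizer(config_definitions.TrainerConfig())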