Commit ba8ad4f5 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use non-experimental mixed precision API for official models.

For all modified calls to set_mixed_precision_policy(), the loss_scale argument was removed, since it cannot be passed when the non-experimental API is used. In every such caller, the loss_scale is later used to create a LossScaleOptimizer explicitly, so removing the argument has no impact.
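As a minimal sketch of the resulting call pattern (using the public tf.keras.mixed_precision API; the helper names and defaults in the official models may differ), the policy is set without a loss scale and the optimizer is then wrapped explicitly:

```python
import tensorflow as tf

# Sketch only: the non-experimental policy carries no loss_scale, so the loss
# scale is applied by wrapping the optimizer explicitly afterwards.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

base_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
# Defaults to dynamic loss scaling, the usual choice for float16 training.
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)
```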

Switching to the non-experimental LossScaleOptimizer also has no effect: its behavior is near-identical, and all isinstance checks within the official models already target the non-experimental class.
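For illustration, a hedged toy example of the kind of isinstance check referred to above (the model, data, and training step here are hypothetical, not taken from the official models); the check matches the non-experimental class, so the control flow is unchanged:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD(0.1))

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))
  # Checks like this one target the non-experimental class, so swapping in the
  # non-experimental LossScaleOptimizer does not change the branch taken.
  if isinstance(opt, tf.keras.mixed_precision.LossScaleOptimizer):
    loss = opt.get_scaled_loss(loss)

grads = tape.gradient(loss, model.trainable_variables)
if isinstance(opt, tf.keras.mixed_precision.LossScaleOptimizer):
  grads = opt.get_unscaled_gradients(grads)
opt.apply_gradients(zip(grads, model.trainable_variables))
```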

PiperOrigin-RevId: 368101975
parent e6cda015
@@ -80,8 +80,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
       optimizer = performance.configure_optimizer(
           optimizer,
           use_float16=runtime_config.mixed_precision_dtype == "float16",
-          loss_scale=runtime_config.loss_scale,
-          use_experimental_api=False)
+          loss_scale=runtime_config.loss_scale)
     return optimizer
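For context on the hunk above: with the non-experimental API, a configure_optimizer-style helper attaches the loss scale by wrapping the optimizer rather than via the policy. The sketch below is an assumed, simplified shape of such a helper, not the actual official/modeling/performance.py implementation:

```python
import tensorflow as tf

def configure_optimizer_sketch(optimizer, use_float16=False, loss_scale=None):
  """Illustrative stand-in for performance.configure_optimizer (assumed)."""
  if use_float16:
    if loss_scale in (None, "dynamic"):
      # Dynamic loss scaling adjusts the scale automatically during training.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    else:
      # A numeric loss_scale becomes a fixed (non-dynamic) scale.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          optimizer, dynamic=False, initial_scale=float(loss_scale))
  return optimizer
```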
@@ -277,9 +277,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
             'learning_rate': {
                 'type': 'constant'
             },
-            'use_experimental_api': {
-                'type': False
-            },
         })))
     trainer = self.create_test_trainer(config)
     if mixed_precision_dtype != 'float16':
@@ -45,9 +45,7 @@ def main(_):
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
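Similarly, a hedged sketch of what a set_mixed_precision_policy-style helper reduces to under the non-experimental API (the real helper in official/modeling/performance.py may differ); the key point is that the non-experimental policy API has no loss_scale argument, which is why loss_scale disappears from these call sites:

```python
import tensorflow as tf

def set_mixed_precision_policy_sketch(dtype):
  """Illustrative stand-in for performance.set_mixed_precision_policy."""
  if dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
  elif dtype == tf.bfloat16:
    tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
  else:
    tf.keras.mixed_precision.set_global_policy("float32")
```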
@@ -151,8 +151,7 @@ def run_bert_classifier(strategy,
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite(),
-        use_experimental_api=False)
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model

   # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
@@ -371,8 +370,7 @@ def run_bert(strategy,
   """Run BERT training."""
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype(),
-                                         use_experimental_api=False)
+  performance.set_mixed_precision_policy(common_flags.dtype())

   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
@@ -126,8 +126,7 @@ def run_customized_training(strategy,
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite(),
-        use_experimental_api=False)
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(
@@ -163,8 +162,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
   logging.info('Training using customized training loop TF 2.0 with distributed'
                'strategy.')
-  performance.set_mixed_precision_policy(common_flags.dtype(),
-                                         use_experimental_api=False)
+  performance.set_mixed_precision_policy(common_flags.dtype())

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
   # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
@@ -225,8 +225,7 @@ def train_squad(strategy,
                ' strategy.')
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype(),
-                                         use_experimental_api=False)
+  performance.set_mixed_precision_policy(common_flags.dtype())

   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
@@ -254,8 +253,7 @@ def train_squad(strategy,
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite(),
-        use_experimental_api=False)
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
@@ -107,9 +107,7 @@ def run_continuous_finetune(
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -121,9 +121,7 @@ def main(_):
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -45,9 +45,7 @@ def main(_):
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -176,8 +176,7 @@ class TransformerTask(object):
     else:
       logging.info("Not using any distribution strategy.")

-    performance.set_mixed_precision_policy(params["dtype"],
-                                           use_experimental_api=False)
+    performance.set_mixed_precision_policy(params["dtype"])

   @property
   def use_tpu(self):
@@ -443,8 +442,7 @@ class TransformerTask(object):
         use_float16=params["dtype"] == tf.float16,
         use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
         loss_scale=flags_core.get_loss_scale(
-            self.flags_obj, default_for_fp16="dynamic"),
-        use_experimental_api=False)
+            self.flags_obj, default_for_fp16="dynamic"))
     return opt