Commit ba8ad4f5 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Use nonexperimental mixed precision API for official models.

For all modified calls to set_mixed_precision_policy(), the loss_scale argument was removed, since it cannot be passed when the non-experimental API is used. In every such caller, loss_scale is later used to explicitly create a LossScaleOptimizer, so removing the argument has no impact.

Switching to the non-experimental LossScaleOptimizer also has no effect, as its behavior is near-identical and all isinstance checks within the official models already test for the non-experimental version.
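For reference, a minimal sketch of the two API styles involved (illustrative only; the 'mixed_float16' policy and the SGD optimizer below are placeholder choices, not taken from the modified files):

import tensorflow as tf

# Non-experimental API (what the official models now use): the policy carries
# no loss_scale, so loss scaling is configured by explicitly wrapping the
# optimizer in a LossScaleOptimizer.
tf.keras.mixed_precision.set_global_policy('mixed_float16')
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)  # dynamic loss scaling by default

# Experimental API (what the modified callers used before): the loss_scale
# could be attached to the policy itself.
# policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16', loss_scale='dynamic')
# tf.keras.mixed_precision.experimental.set_policy(policy)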

PiperOrigin-RevId: 368101975
parent e6cda015
@@ -80,8 +80,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
       optimizer = performance.configure_optimizer(
           optimizer,
           use_float16=runtime_config.mixed_precision_dtype == "float16",
-          loss_scale=runtime_config.loss_scale,
-          use_experimental_api=False)
+          loss_scale=runtime_config.loss_scale)
     return optimizer
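As context for the hunks below, a rough sketch of how a helper like configure_optimizer is assumed to consume the now-explicit loss_scale (hypothetical simplification, not the actual implementation in the official models' performance module):

import tensorflow as tf

def configure_optimizer(optimizer, use_float16=False, loss_scale=None,
                        use_graph_rewrite=False):
  # Hypothetical sketch: with the non-experimental API, loss scaling is applied
  # by wrapping the optimizer explicitly rather than through the policy.
  if use_float16:
    if loss_scale in (None, 'dynamic'):
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    else:
      # A fixed numeric loss scale maps to a non-dynamic LossScaleOptimizer.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          optimizer, dynamic=False, initial_scale=loss_scale)
  # use_graph_rewrite (the old graph-rewrite path) is ignored in this sketch.
  return optimizer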
@@ -277,9 +277,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
             'learning_rate': {
                 'type': 'constant'
             },
-            'use_experimental_api': {
-                'type': False
-            },
         })))
     trainer = self.create_test_trainer(config)
     if mixed_precision_dtype != 'float16':
@@ -45,9 +45,7 @@ def main(_):
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -151,8 +151,7 @@ def run_bert_classifier(strategy,
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite(),
-        use_experimental_api=False)
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model
   # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
@@ -371,8 +370,7 @@ def run_bert(strategy,
   """Run BERT training."""
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype(),
-                                         use_experimental_api=False)
+  performance.set_mixed_precision_policy(common_flags.dtype())
   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
@@ -126,8 +126,7 @@ def run_customized_training(strategy,
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite(),
-        use_experimental_api=False)
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return pretrain_model, core_model
   trained_model = model_training_utils.run_customized_training_loop(
@@ -163,8 +162,7 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
   logging.info('Training using customized training loop TF 2.0 with distributed'
                'strategy.')
-  performance.set_mixed_precision_policy(common_flags.dtype(),
-                                         use_experimental_api=False)
+  performance.set_mixed_precision_policy(common_flags.dtype())
   # Only when explicit_allreduce = True, post_allreduce_callbacks and
   # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
@@ -225,8 +225,7 @@ def train_squad(strategy,
                ' strategy.')
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype(),
-                                         use_experimental_api=False)
+  performance.set_mixed_precision_policy(common_flags.dtype())
   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
@@ -254,8 +253,7 @@
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite(),
-        use_experimental_api=False)
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model
   # Only when explicit_allreduce = True, post_allreduce_callbacks and
@@ -107,9 +107,7 @@ def run_continuous_finetune(
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -121,9 +121,7 @@ def main(_):
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -45,9 +45,7 @@ def main(_):
   # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
   # dtype is float16
   if params.runtime.mixed_precision_dtype:
-    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
-                                           params.runtime.loss_scale,
-                                           use_experimental_api=True)
+    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
   distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
@@ -176,8 +176,7 @@ class TransformerTask(object):
     else:
       logging.info("Not using any distribution strategy.")
-    performance.set_mixed_precision_policy(params["dtype"],
-                                           use_experimental_api=False)
+    performance.set_mixed_precision_policy(params["dtype"])
   @property
   def use_tpu(self):
@@ -443,8 +442,7 @@ class TransformerTask(object):
         use_float16=params["dtype"] == tf.float16,
         use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
         loss_scale=flags_core.get_loss_scale(
-            self.flags_obj, default_for_fp16="dynamic"),
-        use_experimental_api=False)
+            self.flags_obj, default_for_fp16="dynamic"))
     return opt