Commit 8dace71e authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342390260
parent 97d84a4c
@@ -107,9 +107,8 @@ def create_model_fn(input_shape, num_classes, use_float16=False):
         tf.reduce_mean(input_layer), name='mean_input', aggregation='mean')
     model.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
     if use_float16:
-      model.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              model.optimizer, loss_scale='dynamic'))
+      model.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
+          model.optimizer)
     return model, sub_model

   return _model_fn
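
Note (not part of the diff): the stable tf.keras.mixed_precision.LossScaleOptimizer uses dynamic loss scaling by default, so the explicit loss_scale='dynamic' argument of the experimental wrapper is no longer needed. A minimal sketch of the equivalence, assuming TF 2.4+ where the non-experimental API exists:

import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

# Older experimental API: dynamic loss scaling had to be requested explicitly.
#   opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
#       opt, loss_scale='dynamic')

# Stable API (TF 2.4+): dynamic loss scaling is the default, so no
# loss_scale argument is needed.
opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)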
@@ -198,8 +197,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_gpu_strategy_combinations())
   def test_train_eager_mixed_precision(self, distribution):
     model_dir = self.create_tempdir().full_path
-    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
     self._model_fn = create_model_fn(
         input_shape=[128], num_classes=3, use_float16=True)
     self.run_training(
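
Note (not part of the diff): the stable API collapses the two-step Policy construction plus set_policy call into a single set_global_policy call that accepts a policy name directly. A minimal sketch, assuming TF 2.4+:

import tensorflow as tf

# Older experimental API: build a Policy object, then install it.
#   policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
#   tf.keras.mixed_precision.experimental.set_policy(policy)

# Stable API: a policy name string is accepted directly.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Layers created afterwards compute in float16 but keep float32 variables.
dense = tf.keras.layers.Dense(8)
print(dense.dtype_policy.name)          # mixed_float16
print(dense.dtype_policy.compute_dtype)  # float16
print(dense.dtype_policy.variable_dtype)  # float32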
@@ -151,7 +151,8 @@ def run_bert_classifier(strategy,
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_graph_rewrite=common_flags.use_graph_rewrite(),
+        use_experimental_api=False)
     return classifier_model, core_model

   # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
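
Note (not part of the diff): performance.configure_optimizer is the model garden helper that wraps the optimizer for float16 training; its implementation is not shown in this commit. The following is a hypothetical sketch of what the use_experimental_api=False path might look like; the function name configure_optimizer_sketch and the overall structure are assumptions for illustration, not the library's actual code:

import tensorflow as tf

def configure_optimizer_sketch(optimizer,
                               use_float16=False,
                               use_graph_rewrite=False,
                               use_experimental_api=False):
  """Hypothetical sketch of an optimizer-wrapping helper."""
  if use_float16:
    if use_experimental_api:
      # Older experimental Keras loss-scale wrapper.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale='dynamic')
    else:
      # Stable Keras wrapper; dynamic loss scaling is the default.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
  if use_graph_rewrite:
    # Optionally enable the mixed-precision graph rewrite on the wrapped
    # optimizer.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  return optimizer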
@@ -348,7 +349,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
     raise ValueError('Export path is not specified: %s' % model_dir)

   # Export uses float32 for now, even if training uses mixed precision.
-  tf.keras.mixed_precision.experimental.set_policy('float32')
+  tf.keras.mixed_precision.set_global_policy('float32')
   classifier_model = bert_models.classifier_model(
       bert_config,
       input_meta_data.get('num_labels', 1),
@@ -370,7 +371,8 @@ def run_bert(strategy,
   """Run BERT training."""
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype())
+  performance.set_mixed_precision_policy(common_flags.dtype(),
+                                         use_experimental_api=False)
   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
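
Note (not part of the diff): performance.set_mixed_precision_policy maps the training dtype to a global Keras dtype policy; its implementation is not shown here. A hypothetical sketch of the stable-API path (use_experimental_api=False), assuming the dtype arrives as a tf.DType such as tf.float16:

import tensorflow as tf

def set_mixed_precision_policy_sketch(dtype, use_experimental_api=False):
  """Hypothetical sketch: install a global Keras dtype policy for dtype."""
  del use_experimental_api  # This sketch covers only the stable-API path.
  if dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
  elif dtype == tf.bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
  elif dtype == tf.float32:
    tf.keras.mixed_precision.set_global_policy('float32')
  else:
    raise ValueError('Unexpected dtype: %s' % dtype)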
@@ -126,7 +126,8 @@ def run_customized_training(strategy,
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_graph_rewrite=common_flags.use_graph_rewrite(),
+        use_experimental_api=False)
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(
@@ -162,7 +163,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
   logging.info('Training using customized training loop TF 2.0 with distributed'
                'strategy.')
-  performance.set_mixed_precision_policy(common_flags.dtype())
+  performance.set_mixed_precision_policy(common_flags.dtype(),
+                                         use_experimental_api=False)

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
   # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
@@ -160,7 +160,7 @@ def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
   """Gets a squad model to make predictions."""
   with strategy.scope():
     # Prediction always uses float32, even if training uses mixed precision.
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')
     squad_model, _ = bert_models.squad_model(
         bert_config,
         input_meta_data['max_seq_length'],
@@ -225,7 +225,8 @@ def train_squad(strategy,
                ' strategy.')
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype())
+  performance.set_mixed_precision_policy(common_flags.dtype(),
+                                         use_experimental_api=False)
   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
@@ -253,7 +254,8 @@ def train_squad(strategy,
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_graph_rewrite=common_flags.use_graph_rewrite(),
+        use_experimental_api=False)
     return squad_model, core_model

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
@@ -465,7 +467,7 @@ def export_squad(model_export_path, input_meta_data, bert_config):
   if not model_export_path:
     raise ValueError('Export path is not specified: %s' % model_export_path)
   # Export uses float32 for now, even if training uses mixed precision.
-  tf.keras.mixed_precision.experimental.set_policy('float32')
+  tf.keras.mixed_precision.set_global_policy('float32')
   squad_model, _ = bert_models.squad_model(bert_config,
                                            input_meta_data['max_seq_length'])
   model_saving_utils.export_bert_model(
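
Note (not part of the diff): export and prediction paths reset the global policy to float32 so the SavedModel is built with plain float32 layers even when training ran under mixed precision. A minimal sketch of that pattern; the tiny Sequential model stands in for bert_models.squad_model and is an assumption for illustration:

import tensorflow as tf

# Training may have installed a mixed-precision policy earlier.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Before building the graph to export, fall back to float32 so the exported
# model does not rely on float16 kernels at serving time.
tf.keras.mixed_precision.set_global_policy('float32')
export_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(2, input_shape=(4,))])  # stand-in model
tf.saved_model.save(export_model, '/tmp/float32_export_sketch')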