Commit 8dace71e authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342390260
parent 97d84a4c
@@ -107,9 +107,8 @@ def create_model_fn(input_shape, num_classes, use_float16=False):
         tf.reduce_mean(input_layer), name='mean_input', aggregation='mean')
     model.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
     if use_float16:
-      model.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              model.optimizer, loss_scale='dynamic'))
+      model.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
+          model.optimizer)
     return model, sub_model

   return _model_fn
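The rewritten branch leans on the stable API's defaults: in `tf.keras.mixed_precision.LossScaleOptimizer` (TF 2.4+), dynamic loss scaling is the default, so dropping the explicit `loss_scale='dynamic'` argument does not change behavior. A minimal before/after sketch, assuming TF 2.4 or newer:

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

# Old, experimental API (removed in later releases):
#   opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
#       opt, loss_scale='dynamic')

# Stable API: dynamic loss scaling is the default, so no argument is needed.
opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)

# In a custom training loop the wrapper supplies the scaling helpers:
#   scaled_loss = opt.get_scaled_loss(loss)
#   scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
#   grads = opt.get_unscaled_gradients(scaled_grads)
```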
@@ -198,8 +197,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(eager_gpu_strategy_combinations())
   def test_train_eager_mixed_precision(self, distribution):
     model_dir = self.create_tempdir().full_path
-    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
     self._model_fn = create_model_fn(
         input_shape=[128], num_classes=3, use_float16=True)
     self.run_training(
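`set_global_policy` collapses the old two-step `Policy(...)` plus `set_policy(...)` pattern into a single call. Under `'mixed_float16'`, layers built afterwards compute in float16 while keeping float32 variables; a quick sketch, assuming TF 2.4+:

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

layer = tf.keras.layers.Dense(8)
print(layer.compute_dtype)  # 'float16' -- activations
print(layer.dtype)          # 'float32' -- variables

# Tests that set the global policy should reset it afterwards, so later
# test cases are not silently run in mixed precision:
tf.keras.mixed_precision.set_global_policy('float32')
```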
......
@@ -151,7 +151,8 @@ def run_bert_classifier(strategy,
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_graph_rewrite=common_flags.use_graph_rewrite(),
+        use_experimental_api=False)
     return classifier_model, core_model

   # tf.keras.losses objects accept optional sample_weight arguments (eg. coming
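`performance.configure_optimizer` is a model-garden helper, and the new `use_experimental_api=False` argument pins it to the stable Keras loss-scale wrapper during the migration. The following is only a sketch of the idea under those assumptions, not the helper's real implementation:

```python
import tensorflow as tf

def configure_optimizer_sketch(optimizer,
                               use_float16=False,
                               use_graph_rewrite=False,
                               use_experimental_api=False):
  """Sketch only: wraps an optimizer for mixed-precision training."""
  if use_float16:
    if use_experimental_api:
      # Legacy path kept for the migration window.
      optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
          optimizer, loss_scale='dynamic')
    else:
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
  if use_graph_rewrite:
    # The rewrite must wrap the (possibly loss-scaled) optimizer last.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  return optimizer
```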
@@ -348,7 +349,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
     raise ValueError('Export path is not specified: %s' % model_dir)

   # Export uses float32 for now, even if training uses mixed precision.
-  tf.keras.mixed_precision.experimental.set_policy('float32')
+  tf.keras.mixed_precision.set_global_policy('float32')
   classifier_model = bert_models.classifier_model(
       bert_config,
       input_meta_data.get('num_labels', 1),
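Export and prediction intentionally reset the policy: a model built under `'float32'` computes and serves float32 end to end, regardless of how the weights were trained. A hedged sketch of the pattern (the path and layer sizes are placeholders):

```python
import tensorflow as tf

# Reset the global policy before building the inference graph so every
# layer created afterwards computes and stores float32.
tf.keras.mixed_precision.set_global_policy('float32')

model = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(128,))])
model.save('/tmp/export_float32')  # placeholder path
```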
@@ -370,7 +371,8 @@ def run_bert(strategy,
   """Run BERT training."""
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype())
+  performance.set_mixed_precision_policy(common_flags.dtype(),
+                                         use_experimental_api=False)

   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
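`performance.set_mixed_precision_policy` maps the `--dtype` flag onto a global Keras policy, with `use_experimental_api=False` again opting into the stable API. A sketch of the plausible mapping, not the model-garden source:

```python
import tensorflow as tf

def set_mixed_precision_policy_sketch(dtype):
  """Sketch only: flag dtype -> global Keras mixed-precision policy."""
  if dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
  elif dtype == tf.bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
  elif dtype == tf.float32:
    tf.keras.mixed_precision.set_global_policy('float32')
  else:
    raise ValueError('Unexpected dtype: %s' % dtype)
```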
......
@@ -126,7 +126,8 @@ def run_customized_training(strategy,
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_graph_rewrite=common_flags.use_graph_rewrite(),
+        use_experimental_api=False)
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(
@@ -162,7 +163,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
   logging.info('Training using customized training loop TF 2.0 with distributed'
                'strategy.')
-  performance.set_mixed_precision_policy(common_flags.dtype())
+  performance.set_mixed_precision_policy(common_flags.dtype(),
+                                         use_experimental_api=False)

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
   # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
......
@@ -160,7 +160,7 @@ def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
   """Gets a squad model to make predictions."""
   with strategy.scope():
     # Prediction always uses float32, even if training uses mixed precision.
-    tf.keras.mixed_precision.experimental.set_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')
     squad_model, _ = bert_models.squad_model(
         bert_config,
         input_meta_data['max_seq_length'],
@@ -225,7 +225,8 @@ def train_squad(strategy,
                ' strategy.')
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_session_config(FLAGS.enable_xla)
-  performance.set_mixed_precision_policy(common_flags.dtype())
+  performance.set_mixed_precision_policy(common_flags.dtype(),
+                                         use_experimental_api=False)

   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
@@ -253,7 +254,8 @@ def train_squad(strategy,
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
-        use_graph_rewrite=common_flags.use_graph_rewrite())
+        use_graph_rewrite=common_flags.use_graph_rewrite(),
+        use_experimental_api=False)
     return squad_model, core_model

   # Only when explicit_allreduce = True, post_allreduce_callbacks and
@@ -465,7 +467,7 @@ def export_squad(model_export_path, input_meta_data, bert_config):
   if not model_export_path:
     raise ValueError('Export path is not specified: %s' % model_export_path)
   # Export uses float32 for now, even if training uses mixed precision.
-  tf.keras.mixed_precision.experimental.set_policy('float32')
+  tf.keras.mixed_precision.set_global_policy('float32')
   squad_model, _ = bert_models.squad_model(bert_config,
                                            input_meta_data['max_seq_length'])
   model_saving_utils.export_bert_model(
......