Commit 69837027 authored by Hongkun Yu, committed by A. Unique TensorFlower

Uses performance.configure_optimizer to reduce duplicate code.

PiperOrigin-RevId: 301308623
parent 0b7b3b4e
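For context, the per-binary code this change removes configured mixed precision on the optimizer by hand. A minimal sketch of what a consolidating helper like performance.configure_optimizer has to do, reconstructed from the lines removed below (the exact internals of official.modeling.performance are an assumption here):

import tensorflow as tf

def configure_optimizer(optimizer,
                        use_float16=False,
                        use_graph_rewrite=False,
                        loss_scale='dynamic'):
  """Sketch: wrap `optimizer` for the requested fp16 implementation."""
  if use_float16:
    # Under the 'mixed_float16' Keras policy, compile() would wrap the
    # optimizer automatically; a custom training loop must do it manually.
    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        optimizer, loss_scale=loss_scale)
  if use_graph_rewrite:
    # The graph-rewrite path keeps the Keras dtype at float32, so the policy
    # and the rewrite never double up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  return optimizer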
@@ -24,6 +24,7 @@ from absl import logging
 import tensorflow as tf
 from official.modeling import model_training_utils
+from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
@@ -194,8 +195,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
         start_logits=start_logits,
         end_logits=end_logits)

-    outputs = strategy.run(
-        _replicated_step, args=(next(iterator),))
+    outputs = strategy.run(_replicated_step, args=(next(iterator),))
     return tf.nest.map_structure(strategy.experimental_local_results, outputs)

   all_results = []
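The reflowed strategy.run call above is the standard distributed-prediction pattern: run a step function on every replica, then unpack the per-replica values locally. A self-contained toy version, where the doubling step function is a stand-in for the real BERT SQuAD forward pass:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default no-op strategy for the demo
dataset = tf.data.Dataset.from_tensor_slices(tf.range(8.)).batch(2)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def predict_step(iterator):

  def _replicated_step(inputs):
    return inputs * 2.0  # stand-in for the SQuAD model forward pass

  # Run on every replica, then unpack the per-replica results locally.
  outputs = strategy.run(_replicated_step, args=(next(iterator),))
  return tf.nest.map_structure(strategy.experimental_local_results, outputs)

print(predict_step(iterator))  # first batch [0., 1.] doubled -> ([0., 2.],)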
@@ -219,10 +219,7 @@ def train_squad(strategy,
                      ' strategy.')
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_config_v2(FLAGS.enable_xla)
-
-  use_float16 = common_flags.use_float16()
-  if use_float16:
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+  performance.set_mixed_precision_policy(common_flags.dtype())

   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
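performance.set_mixed_precision_policy replaces the float16 special case removed above. Presumably (the body below is an assumption, not the library source) it maps the dtype from the --dtype flag onto the matching Keras policy, along these lines:

import tensorflow as tf

def set_mixed_precision_policy(dtype):
  """Sketch: choose a Keras mixed-precision policy from the flag dtype."""
  if dtype == tf.float16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
  elif dtype == tf.bfloat16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
  elif dtype == tf.float32:
    tf.keras.mixed_precision.experimental.set_policy('float32')
  else:
    raise ValueError('Unexpected dtype: %s' % dtype)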
@@ -242,23 +239,14 @@ def train_squad(strategy,
         max_seq_length,
         hub_module_url=FLAGS.hub_module_url,
         hub_module_trainable=FLAGS.hub_module_trainable)
-    squad_model.optimizer = optimization.create_optimizer(
-        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
-    if use_float16:
-      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
-      # in compile() with the "mixed_float16" policy, but since we do not call
-      # compile(), we must wrap the optimizer manually.
-      squad_model.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          squad_model.optimizer)
+    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
+                                              steps_per_epoch * epochs,
+                                              warmup_steps)
+
+    squad_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

   # The original BERT model does not scale the loss by
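After this change, call sites reduce to a create-then-configure pair. A hypothetical usage sketch mirroring the new diff lines (the flag values are stand-ins); note that with use_float16=True the result is a LossScaleOptimizer, so a custom train step scales the loss and unscales the gradients itself:

from official.modeling import performance
from official.nlp import optimization

optimizer = optimization.create_optimizer(
    init_lr=5e-5, num_train_steps=1000, num_warmup_steps=100)
optimizer = performance.configure_optimizer(
    optimizer, use_float16=True, use_graph_rewrite=False)

# Sketch of the matching custom train step when the optimizer is a
# LossScaleOptimizer (`loss`, `tape`, and `model` are placeholders):
#   scaled_loss = optimizer.get_scaled_loss(loss)
#   scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
#   grads = optimizer.get_unscaled_gradients(scaled_grads)
#   optimizer.apply_gradients(zip(grads, model.trainable_variables))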