Commit 69837027 authored by Hongkun Yu, committed by A. Unique TensorFlower

Uses performance.configure_optimizer to reduce duplicate code.

PiperOrigin-RevId: 301308623
parent 0b7b3b4e
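The commit message refers to performance.configure_optimizer, whose body is not shown in this diff. As a rough, hedged idea of what such a helper consolidates, here is a minimal sketch reconstructed only from the two branches removed in the last hunk below; the real helper in official/modeling/performance.py may differ, and the loss_scale argument and its default here are assumptions rather than the actual signature.

# Hedged sketch only: a plausible shape for a configure_optimizer helper,
# reconstructed from the removed LossScaleOptimizer and graph-rewrite branches
# in the diff below. Not the actual official/modeling/performance.py code.
import tensorflow as tf


def configure_optimizer_sketch(optimizer,
                               use_float16=False,
                               use_graph_rewrite=False,
                               loss_scale='dynamic'):
  """Wraps `optimizer` for mixed precision training."""
  if use_float16:
    # The "mixed_float16" policy makes compile() do this wrapping
    # automatically, but a custom training loop never calls compile(),
    # so the optimizer has to be wrapped by hand.
    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        optimizer, loss_scale=loss_scale)
  if use_graph_rewrite:
    # With the graph rewrite the Keras dtype policy stays float32, so the
    # policy and the rewrite do not double up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  return optimizer

With a helper of this shape, the call site in train_squad collapses to the single configure call visible in the final hunk of the diff.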
@@ -24,6 +24,7 @@ from absl import logging
 import tensorflow as tf

 from official.modeling import model_training_utils
+from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
@@ -194,8 +195,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
           start_logits=start_logits,
           end_logits=end_logits)

-    outputs = strategy.run(
-        _replicated_step, args=(next(iterator),))
+    outputs = strategy.run(_replicated_step, args=(next(iterator),))
     return tf.nest.map_structure(strategy.experimental_local_results, outputs)

   all_results = []
@@ -219,10 +219,7 @@ def train_squad(strategy,
                  ' strategy.')
   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_config_v2(FLAGS.enable_xla)
-
-  use_float16 = common_flags.use_float16()
-  if use_float16:
-    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+  performance.set_mixed_precision_policy(common_flags.dtype())

   epochs = FLAGS.num_train_epochs
   num_train_examples = input_meta_data['train_data_size']
@@ -242,23 +239,14 @@ def train_squad(strategy,
         max_seq_length,
         hub_module_url=FLAGS.hub_module_url,
         hub_module_trainable=FLAGS.hub_module_trainable)
-    squad_model.optimizer = optimization.create_optimizer(
-        FLAGS.learning_rate, steps_per_epoch * epochs, warmup_steps)
-    if use_float16:
-      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
-      # in compile() with the "mixed_float16" policy, but since we do not call
-      # compile(), we must wrap the optimizer manually.
-      squad_model.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          squad_model.optimizer)
+    optimizer = optimization.create_optimizer(FLAGS.learning_rate,
+                                              steps_per_epoch * epochs,
+                                              warmup_steps)
+    squad_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

   # The original BERT model does not scale the loss by
......
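The third hunk above also swaps the direct tf.keras.mixed_precision.experimental.set_policy('mixed_float16') call for performance.set_mixed_precision_policy(common_flags.dtype()). That helper's body is likewise not part of this diff; a minimal sketch of what it might do, assuming it simply dispatches on the dtype flag (the bfloat16 and float32 branches are assumptions), is:

# Minimal sketch, not the actual official.modeling.performance implementation:
# choose a Keras mixed precision policy from a tf.DType flag value.
import tensorflow as tf


def set_mixed_precision_policy_sketch(dtype):
  """Sets the global Keras mixed precision policy for the given dtype."""
  if dtype == tf.float16:
    # Same policy the removed code set directly for fp16 runs.
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
  elif dtype == tf.bfloat16:
    tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')  # assumption
  elif dtype == tf.float32:
    tf.keras.mixed_precision.experimental.set_policy('float32')  # assumption
  else:
    raise ValueError('Unexpected dtype: %s' % dtype)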