Commit 0adbb1c9 authored by Vinh Nguyen's avatar Vinh Nguyen
Browse files

adding automatic mixed precision to SQuAD

parent c162f7ab
......@@ -226,6 +226,14 @@ def train_squad(strategy,
squad_model.optimizer = (
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
squad_model.optimizer, loss_scale=common_flags.get_loss_scale()))
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
squad_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
squad_model.optimizer)
return squad_model, core_model
# The original BERT model does not scale the loss by
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment