Commit c162f7ab authored by Vinh Nguyen's avatar Vinh Nguyen
Browse files

adding automatic mixed precision to BERT classifier task

parent 0f6ff657
......@@ -116,6 +116,14 @@ def run_customized_training(strategy,
max_seq_length))
classifier_model.optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
if FLAGS.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
# which will ensure tf.compat.v2.keras.mixed_precision and
# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
# up.
classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
classifier_model.optimizer)
return classifier_model, core_model
loss_fn = get_loss_fn(num_classes, loss_scale=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment