"experiments/recognition/resnet50_baseline.sh" did not exist on "17be9e1607d81936e51c7fde112086cd31d2588e"
Commit 0f6ff657 authored by Vinh Nguyen
Browse files

adding fp16 implementation for BERT

parent bd211e3e
......@@ -65,7 +65,8 @@ def define_common_bert_flags():
loss_scale=True,
all_reduce_alg=False,
num_packs=False,
-enable_xla=True
+enable_xla=True,
+fp16_implementation=True,
)
......
......@@ -127,6 +127,14 @@ def run_customized_training(strategy,
bert_config, max_seq_length, max_predictions_per_seq)
pretrain_model.optimizer = optimization.create_optimizer(
initial_lr, steps_per_epoch * epochs, warmup_steps)
+if FLAGS.fp16_implementation == 'graph_rewrite':
+# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
+# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
+# which will ensure tf.compat.v2.keras.mixed_precision and
+# tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
+# up.
+pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
+pretrain_model.optimizer)
return pretrain_model, core_model
trained_model = model_training_utils.run_customized_training_loop(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment