Commit 306e3e14 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Do not scale loss manually for BERT classifier when compile/fit() API is used.

PiperOrigin-RevId: 275142626
parent 06412123
...@@ -132,10 +132,17 @@ def run_bert_classifier(strategy, ...@@ -132,10 +132,17 @@ def run_bert_classifier(strategy,
classifier_model.optimizer) classifier_model.optimizer)
return classifier_model, core_model return classifier_model, core_model
loss_fn = get_loss_fn( # During distributed training, loss used for gradient computation is
num_classes, # summed over from all replicas. When Keras compile/fit() API is used,
loss_factor=1.0 / # the fit() API internally normalizes the loss by dividing the loss by
strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0) # the number of replicas used for computation. However, when custom
# training loop is used this is not done automatically and should be
# done manually by the end user.
loss_multiplier = 1.0
if FLAGS.scale_loss and not use_keras_compile_fit:
loss_multiplier = 1.0 / strategy.num_replicas_in_sync
loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)
# Defines evaluation metrics function, which will create metrics in the # Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope. # correct device and strategy scope.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment