Commit 306e3e14 authored by A. Unique TensorFlower

Do not scale loss manually for BERT classifier when compile/fit() API is used.

PiperOrigin-RevId: 275142626
parent 06412123
@@ -132,10 +132,17 @@ def run_bert_classifier(strategy,
             classifier_model.optimizer)
     return classifier_model, core_model
 
-  loss_fn = get_loss_fn(
-      num_classes,
-      loss_factor=1.0 /
-      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
+  # During distributed training, the loss used for gradient computation is
+  # summed over all replicas. When the Keras compile/fit() API is used,
+  # fit() internally normalizes the loss by dividing it by the number of
+  # replicas used for computation. However, when a custom training loop is
+  # used, this is not done automatically and should be done manually by
+  # the end user.
+  loss_multiplier = 1.0
+  if FLAGS.scale_loss and not use_keras_compile_fit:
+    loss_multiplier = 1.0 / strategy.num_replicas_in_sync
+
+  loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
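For context, here is a minimal sketch of the custom-training-loop case the new comment describes, where the 1/num_replicas_in_sync factor must still be applied by hand. It uses tf.distribute.MirroredStrategy with a stand-in dense model; the model, optimizer, and dataset below are illustrative assumptions, not code from this repository or the BERT classifier itself.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 32
NUM_CLASSES = 2

with strategy.scope():
  # Stand-in model and optimizer; the real change applies to the BERT
  # classifier built elsewhere in this file.
  model = tf.keras.Sequential([tf.keras.layers.Dense(NUM_CLASSES)])
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  # Reduction.NONE is required inside strategy.run(); reduction across
  # the batch and replicas is handled explicitly below.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)


@tf.function
def train_step(dist_inputs):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      per_example_loss = loss_object(labels, logits)
      # Mean loss over this replica's local batch.
      loss = tf.reduce_mean(per_example_loss)
      # Custom-loop path: gradients are summed across replicas, so the
      # per-replica loss is scaled down by the replica count here. This
      # is exactly the factor the commit now skips when fit() is used,
      # because fit() applies an equivalent normalization internally.
      loss = loss / strategy.num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


# Illustrative usage with synthetic data.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 8]),
     tf.random.uniform([64], maxval=NUM_CLASSES, dtype=tf.int32))
).batch(GLOBAL_BATCH_SIZE)
for batch in strategy.experimental_distribute_dataset(dataset):
  train_step(batch)
```

With compile/fit() on the same strategy, the explicit division by strategy.num_replicas_in_sync would be dropped, which is what the `use_keras_compile_fit` check in this change does: it keeps loss_multiplier at 1.0 on the fit() path so the loss is not normalized twice.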