Commit 5b81bb59 authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

scale loss by num replicas

parent 346b570f
......@@ -235,10 +235,11 @@ def _get_keras_model(params):
from_logits=True,
reduction="sum")
loss_scale_factor = batch_size * tf.distribute.get_strategy().num_replicas_in_sync
keras_model.add_loss(loss_obj(
y_true=label_input,
y_pred=softmax_logits,
sample_weight=valid_pt_mask_input) * 1.0 / batch_size)
sample_weight=valid_pt_mask_input) * 1.0 / loss_scale_factor)
keras_model.summary()
return keras_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment