Commit 6da769b1 authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

fix loss scaling

parent 71c6a697
......@@ -235,12 +235,10 @@ def _get_keras_model(params):
from_logits=True,
reduction="sum")
loss_scale_factor = (batch_size) #*
#tf.distribute.get_strategy().num_replicas_in_sync)
keras_model.add_loss(loss_obj(
y_true=label_input,
y_pred=softmax_logits,
sample_weight=valid_pt_mask_input) * 1.0 / loss_scale_factor)
sample_weight=valid_pt_mask_input) * 1.0 / batch_size)
keras_model.summary()
return keras_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment