Unverified Commit 589ac399 authored by rxsang's avatar rxsang Committed by GitHub
Browse files

Manually scale the loss in Resnet DS model (#6195)

* Manually scale the loss in Resnet DS model

* Update resnet_run_loop.py
parent e334f3e2
@@ -357,11 +357,16 @@ def resnet_model_fn(features, labels, mode, model_class,
     return 'batch_normalization' not in name
   loss_filter_fn = loss_filter_fn or exclude_batch_norm

-  # Add weight decay to the loss.
+  # Add weight decay to the loss. We need to scale the regularization loss
+  # manually, as losses other than those in tf.losses and tf.keras.losses
+  # are not scaled automatically.
   l2_loss = weight_decay * tf.add_n(
       # loss is computed using fp32 for numerical stability.
-      [tf.nn.l2_loss(tf.cast(v, tf.float32))
-       for v in tf.compat.v1.trainable_variables() if loss_filter_fn(v.name)])
+      [
+          tf.nn.l2_loss(tf.cast(v, tf.float32))
+          for v in tf.trainable_variables()
+          if loss_filter_fn(v.name)
+      ]) / tf.distribute.get_strategy().num_replicas_in_sync
   tf.compat.v1.summary.scalar('l2_loss', l2_loss)
   loss = cross_entropy + l2_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment