"...git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "43fde05a69a4eeab1c116b849aaf9017cad4e46f"
Unverified Commit c33d3ef4 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Update logic to rescale L2 loss in distribution strategy (#6601)

parent 2ae6d37a
......@@ -419,16 +419,14 @@ def resnet_model_fn(features, labels, mode, model_class,
return 'batch_normalization' not in name
loss_filter_fn = loss_filter_fn or exclude_batch_norm
# Add weight decay to the loss. We need to scale the regularization loss
# manually as losses other than in tf.losses and tf.keras.losses don't scale
# automatically.
# Add weight decay to the loss.
l2_loss = weight_decay * tf.add_n(
# loss is computed using fp32 for numerical stability.
[
tf.nn.l2_loss(tf.cast(v, tf.float32))
for v in tf.compat.v1.trainable_variables()
if loss_filter_fn(v.name)
]) / tf.distribute.get_strategy().num_replicas_in_sync
])
tf.compat.v1.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy + l2_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment