Commit d821ab89 authored by Neal Wu

Add a tf.summary.scalar call for L2 loss in resnet

parent 4b85dab1
@@ -235,9 +235,11 @@ def resnet_model_fn(features, labels, mode, model_class,
   loss_filter_fn = loss_filter_fn or exclude_batch_norm
 
   # Add weight decay to the loss.
-  loss = cross_entropy + weight_decay * tf.add_n(
+  l2_loss = weight_decay * tf.add_n(
       [tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if loss_filter_fn(v.name)])
+  tf.summary.scalar('l2_loss', l2_loss)
+  loss = cross_entropy + l2_loss
 
   if mode == tf.estimator.ModeKeys.TRAIN:
     global_step = tf.train.get_or_create_global_step()
...
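
For reference, here is a minimal standalone sketch of the pattern this commit introduces, written against the same TF1-style API used in the diff. The weight_decay value and the batch-norm filter below are illustrative assumptions, not values taken from the commit; the point is simply that splitting the weight-decay term into its own l2_loss tensor lets tf.summary.scalar log it to TensorBoard separately from the cross-entropy term.

import tensorflow as tf  # TF1-style API, matching the diff above

def total_loss_with_l2(cross_entropy, weight_decay=1e-4,
                       loss_filter_fn=lambda name: 'batch_normalization' not in name):
  # Illustrative helper (not from the commit): the weight_decay and
  # loss_filter_fn defaults here are example choices.
  # Weight-decay (L2) term over the filtered trainable variables.
  l2_loss = weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
       if loss_filter_fn(v.name)])
  # Log the L2 term on its own so it shows up as a scalar in TensorBoard.
  tf.summary.scalar('l2_loss', l2_loss)
  # Total loss is cross entropy plus the weight-decay term, as before.
  return cross_entropy + l2_loss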