Commit c0bdb378 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Using better version of l2 loss to avoid reshape, concat and split ops.

Also adding support for CTL mode in the borg file.

PiperOrigin-RevId: 284075404
parent f079ed2e
......@@ -274,13 +274,12 @@ def run(flags_obj):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
filtered_variables = [
tf.reshape(v, (-1,))
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
tf.nn.l2_loss(v)
for v in trainable_variables
if 'bn' not in v.name
]
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
])
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment