Commit 12856c3b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

SpectralNormalization: Avoid errors in export by only applying update ops when training=True.

PiperOrigin-RevId: 387447417
parent 62487257
......@@ -106,6 +106,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
def call(self, inputs, *, training=None):
training = self.do_power_iteration if training is None else training
if training:
u_update_op, v_update_op, w_update_op = self.update_weights(
training=training)
output = self.layer(inputs)
......@@ -116,6 +117,8 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.add_update(v_update_op)
self.add_update(w_update_op)
self.add_update(w_restore_op)
else:
output = self.layer(inputs)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment