Commit 12856c3b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

SpectralNormalization: Avoid errors in export by only applying update ops when training=True.

PiperOrigin-RevId: 387447417
parent 62487257
...@@ -106,16 +106,19 @@ class SpectralNormalization(tf.keras.layers.Wrapper): ...@@ -106,16 +106,19 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
def call(self, inputs, *, training=None): def call(self, inputs, *, training=None):
training = self.do_power_iteration if training is None else training training = self.do_power_iteration if training is None else training
u_update_op, v_update_op, w_update_op = self.update_weights( if training:
training=training) u_update_op, v_update_op, w_update_op = self.update_weights(
output = self.layer(inputs) training=training)
w_restore_op = self.restore_weights() output = self.layer(inputs)
w_restore_op = self.restore_weights()
# Register update ops.
self.add_update(u_update_op) # Register update ops.
self.add_update(v_update_op) self.add_update(u_update_op)
self.add_update(w_update_op) self.add_update(v_update_op)
self.add_update(w_restore_op) self.add_update(w_update_op)
self.add_update(w_restore_op)
else:
output = self.layer(inputs)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment