Commit 8b43ab7c authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Fix Keras saving bug in MultiHeadAttention.

Remember the input shapes when build_from_signature is called and trigger the build in from_config.

PiperOrigin-RevId: 350795939
parent 3b0d58e2
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
# ============================================================================== # ==============================================================================
"""Tests for Keras-based transformer block layer.""" """Tests for Keras-based transformer block layer."""
import json
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -419,12 +417,11 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -419,12 +417,11 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# Serialize the model config. Pass the serialized data through json to # Serialize the model config. Pass the serialized data through json to
# ensure that we can serialize this layer to disk. # ensure that we can serialize this layer to disk.
serialized_data = json.dumps(model.get_config()) serialized_data = model.get_config()
post_string_serialized_data = json.loads(serialized_data)
# Create a new model from the old config, and copy the weights. These models # Create a new model from the old config, and copy the weights. These models
# should have identical outputs. # should have identical outputs.
new_model = tf.keras.Model.from_config(post_string_serialized_data) new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights()) new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data]) output = new_model.predict([input_data, mask_data])
...@@ -484,14 +481,10 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -484,14 +481,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
2, size=(batch_size, sequence_length, sequence_length)) 2, size=(batch_size, sequence_length, sequence_length))
pre_serialization_output = model.predict([input_data, mask_data]) pre_serialization_output = model.predict([input_data, mask_data])
# Serialize the model config. Pass the serialized data through json to serialized_data = model.get_config()
# ensure that we can serialize this layer to disk.
serialized_data = json.dumps(model.get_config())
post_string_serialized_data = json.loads(serialized_data)
# Create a new model from the old config, and copy the weights. These models # Create a new model from the old config, and copy the weights. These models
# should have identical outputs. # should have identical outputs.
new_model = tf.keras.Model.from_config(post_string_serialized_data) new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights()) new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data]) output = new_model.predict([input_data, mask_data])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment