Commit 4523c446 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix keras saving bug in multiheadattention.

Remember the input shapes when build_from_signature is called and trigger the build in from_config.

PiperOrigin-RevId: 350795939
parent bbe3df01
......@@ -14,8 +14,6 @@
# ==============================================================================
"""Tests for Keras-based transformer block layer."""
import json
import numpy as np
import tensorflow as tf
......@@ -419,12 +417,11 @@ class TransformerLayerTest(keras_parameterized.TestCase):
# Serialize the model config. Pass the serialized data through json to
# ensure that we can serialize this layer to disk.
serialized_data = json.dumps(model.get_config())
post_string_serialized_data = json.loads(serialized_data)
serialized_data = model.get_config()
# Create a new model from the old config, and copy the weights. These models
# should have identical outputs.
new_model = tf.keras.Model.from_config(post_string_serialized_data)
new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data])
......@@ -484,14 +481,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
2, size=(batch_size, sequence_length, sequence_length))
pre_serialization_output = model.predict([input_data, mask_data])
# Serialize the model config. Pass the serialized data through json to
# ensure that we can serialize this layer to disk.
serialized_data = json.dumps(model.get_config())
post_string_serialized_data = json.loads(serialized_data)
serialized_data = model.get_config()
# Create a new model from the old config, and copy the weights. These models
# should have identical outputs.
new_model = tf.keras.Model.from_config(post_string_serialized_data)
new_model = tf.keras.Model.from_config(serialized_data)
new_model.set_weights(model.get_weights())
output = new_model.predict([input_data, mask_data])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment