Commit 8309ff13 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Adds test for encoder save()/load()

PiperOrigin-RevId: 329963956
parent 4412001f
...@@ -181,7 +181,6 @@ class BertEncoderTest(keras_parameterized.TestCase): ...@@ -181,7 +181,6 @@ class BertEncoderTest(keras_parameterized.TestCase):
self.assertTrue(hasattr(test_network, "_embedding_projection")) self.assertTrue(hasattr(test_network, "_embedding_projection"))
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
kwargs = dict( kwargs = dict(
vocab_size=100, vocab_size=100,
...@@ -199,23 +198,26 @@ class BertEncoderTest(keras_parameterized.TestCase): ...@@ -199,23 +198,26 @@ class BertEncoderTest(keras_parameterized.TestCase):
output_range=-1, output_range=-1,
embedding_width=16) embedding_width=16)
network = bert_encoder.BertEncoder(**kwargs) network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs) expected_config = dict(kwargs)
expected_config["activation"] = tf.keras.activations.serialize( expected_config["activation"] = tf.keras.activations.serialize(
tf.keras.activations.get(expected_config["activation"])) tf.keras.activations.get(expected_config["activation"]))
expected_config["initializer"] = tf.keras.initializers.serialize( expected_config["initializer"] = tf.keras.initializers.serialize(
tf.keras.initializers.get(expected_config["initializer"])) tf.keras.initializers.get(expected_config["initializer"]))
self.assertEqual(network.get_config(), expected_config) self.assertEqual(network.get_config(), expected_config)
# Create another network object from the first object's config. # Create another network object from the first object's config.
new_network = bert_encoder.BertEncoder.from_config(network.get_config()) new_network = bert_encoder.BertEncoder.from_config(network.get_config())
# Validate that the config can be forced to JSON. # Validate that the config can be forced to JSON.
_ = new_network.to_json() _ = network.to_json()
# If the serialization was successful, the new config should match the old. # If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config()) self.assertAllEqual(network.get_config(), new_network.get_config())
# Tests model saving/loading.
model_path = self.get_temp_dir() + "/model"
network.save(model_path)
_ = tf.keras.models.load_model(model_path)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment