Commit 84c5fcd3 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use self.get_temp_dir and simplify the test

PiperOrigin-RevId: 332512380
parent 02275d26
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
"""Tests for ExpandCondense tensor network layer.""" """Tests for ExpandCondense tensor network layer."""
import os import os
import shutil
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
...@@ -166,22 +165,12 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase): ...@@ -166,22 +165,12 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
# Train the model for 5 epochs # Train the model for 5 epochs
model.fit(data, self.labels, epochs=5, batch_size=32) model.fit(data, self.labels, epochs=5, batch_size=32)
for save_path in ['/test_model', '/test_model.h5']: save_path = os.path.join(self.get_temp_dir(), 'test_model')
# Save model to a SavedModel folder or h5 file, then load model model.save(save_path)
save_path = os.environ['TEST_UNDECLARED_OUTPUTS_DIR'] + save_path loaded_model = tf.keras.models.load_model(save_path)
model.save(save_path)
loaded_model = tf.keras.models.load_model(save_path)
# Clean up SavedModel folder # Compare model predictions and loaded_model predictions
if os.path.isdir(save_path): self.assertAllEqual(model.predict(data), loaded_model.predict(data))
shutil.rmtree(save_path)
# Clean up h5 file
if os.path.exists(save_path):
os.remove(save_path)
# Compare model predictions and loaded_model predictions
self.assertAllEqual(model.predict(data), loaded_model.predict(data))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment