Commit 674506c4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use self.get_temp_dir and simplify the test

PiperOrigin-RevId: 332512380
parent 96ca6e05
......@@ -15,7 +15,6 @@
"""Tests for ExpandCondense tensor network layer."""
import os
import shutil
from absl.testing import parameterized
import numpy as np
......@@ -166,22 +165,12 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
# Train the model for 5 epochs
model.fit(data, self.labels, epochs=5, batch_size=32)
for save_path in ['/test_model', '/test_model.h5']:
# Save model to a SavedModel folder or h5 file, then load model
save_path = os.environ['TEST_UNDECLARED_OUTPUTS_DIR'] + save_path
model.save(save_path)
loaded_model = tf.keras.models.load_model(save_path)
save_path = os.path.join(self.get_temp_dir(), 'test_model')
model.save(save_path)
loaded_model = tf.keras.models.load_model(save_path)
# Clean up SavedModel folder
if os.path.isdir(save_path):
shutil.rmtree(save_path)
# Clean up h5 file
if os.path.exists(save_path):
os.remove(save_path)
# Compare model predictions and loaded_model predictions
self.assertAllEqual(model.predict(data), loaded_model.predict(data))
# Compare model predictions and loaded_model predictions
self.assertAllEqual(model.predict(data), loaded_model.predict(data))
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment