Commit b37c3fc1 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Recover global policy between tests.

PiperOrigin-RevId: 286139970
parent 8f4bf01f
......@@ -53,10 +53,13 @@ class CtlImagenetTest(tf.test.TestCase):
def setUp(self):
super(CtlImagenetTest, self).setUp()
imagenet_preprocessing.NUM_IMAGES['validation'] = 4
self.policy = \
tf.compat.v2.keras.mixed_precision.experimental.global_policy()
def tearDown(self):
super(CtlImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.policy)
def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy."""
......
......@@ -45,10 +45,13 @@ class KerasImagenetTest(tf.test.TestCase):
def setUp(self):
super(KerasImagenetTest, self).setUp()
imagenet_preprocessing.NUM_IMAGES["validation"] = 4
self.policy = \
tf.compat.v2.keras.mixed_precision.experimental.global_policy()
def tearDown(self):
super(KerasImagenetTest, self).tearDown()
tf.io.gfile.rmtree(self.get_temp_dir())
tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.policy)
def test_end_to_end_no_dist_strat(self):
"""Test Keras model with 1 GPU, no distribution strategy."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment