Commit f85ab4c8 authored by Taylor Robie's avatar Taylor Robie Committed by Karmel Allison
Browse files

Improve directory treatment in ResNet end-to-end tests. (#3651)

* use proper temp directory for end to end tests.

* add supers to tearDown
parent 646c3f75
......@@ -35,6 +35,10 @@ _NUM_CHANNELS = 3
class BaseTest(tf.test.TestCase):
def tearDown(self):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
def test_dataset_input_fn(self):
fake_data = bytearray()
fake_data.append(7)
......@@ -137,10 +141,16 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '1'])
integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
)
def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '2'])
integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
)
if __name__ == '__main__':
......
......@@ -32,6 +32,10 @@ _LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase):
def tearDown(self):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
"""Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape):
......@@ -244,26 +248,40 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '1'])
integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
)
def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '2'])
integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
)
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '1', '-rs', '18'])
integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '18']
)
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '2', '-rs', '18'])
integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '18']
)
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '1', '-rs', '200'])
integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '200']
)
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '2', '-rs', '200'])
integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '200']
)
if __name__ == '__main__':
tf.test.main()
......
......@@ -23,10 +23,10 @@ from __future__ import print_function
import os
import shutil
import sys
import time
import tempfile
def run_synthetic(main, extra_flags=None):
def run_synthetic(main, tmp_root, extra_flags=None):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
......@@ -35,14 +35,13 @@ def run_synthetic(main, extra_flags=None):
Args:
main: The primary function used to excercise a code path. Generally this
function is "<MODULE>.main(argv)".
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the the caller of this function.
"""
extra_flags = [] if extra_flags is None else extra_flags
model_dir = "/tmp/model_test_{}".format(hash(time.time()))
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
model_dir = tempfile.mkdtemp(dir=tmp_root)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_per_eval", "1", "--use_synthetic_data",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment