"app/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c4ba1921870f40659d0b12ca00b3bcc6aa6cece9"
Commit f85ab4c8 authored by Taylor Robie's avatar Taylor Robie Committed by Karmel Allison
Browse files

Improve directory treatment in ResNet end-to-end tests. (#3651)

* use proper temp directory for end to end tests.

* add supers to tearDown
parent 646c3f75
...@@ -35,6 +35,10 @@ _NUM_CHANNELS = 3 ...@@ -35,6 +35,10 @@ _NUM_CHANNELS = 3
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
def tearDown(self):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
def test_dataset_input_fn(self): def test_dataset_input_fn(self):
fake_data = bytearray() fake_data = bytearray()
fake_data.append(7) fake_data.append(7)
...@@ -137,10 +141,16 @@ class BaseTest(tf.test.TestCase): ...@@ -137,10 +141,16 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '1']) integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
)
def test_cifar10_end_to_end_synthetic_v2(self): def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '2']) integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -32,6 +32,10 @@ _LABEL_CLASSES = 1001 ...@@ -32,6 +32,10 @@ _LABEL_CLASSES = 1001
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
def tearDown(self):
super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir())
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False): def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
"""Checks the tensor shapes after each phase of the ResNet model.""" """Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape): def reshape(shape):
...@@ -244,26 +248,40 @@ class BaseTest(tf.test.TestCase): ...@@ -244,26 +248,40 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '1']) integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
)
def test_imagenet_end_to_end_synthetic_v2(self): def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '2']) integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
)
def test_imagenet_end_to_end_synthetic_v1_tiny(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(main=imagenet_main.main, integration.run_synthetic(
extra_flags=['-v', '1', '-rs', '18']) main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '18']
)
def test_imagenet_end_to_end_synthetic_v2_tiny(self): def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(main=imagenet_main.main, integration.run_synthetic(
extra_flags=['-v', '2', '-rs', '18']) main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '18']
)
def test_imagenet_end_to_end_synthetic_v1_huge(self): def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(main=imagenet_main.main, integration.run_synthetic(
extra_flags=['-v', '1', '-rs', '200']) main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '200']
)
def test_imagenet_end_to_end_synthetic_v2_huge(self): def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(main=imagenet_main.main, integration.run_synthetic(
extra_flags=['-v', '2', '-rs', '200']) main=imagenet_main.main, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '200']
)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
......
...@@ -23,10 +23,10 @@ from __future__ import print_function ...@@ -23,10 +23,10 @@ from __future__ import print_function
import os import os
import shutil import shutil
import sys import sys
import time import tempfile
def run_synthetic(main, extra_flags=None): def run_synthetic(main, tmp_root, extra_flags=None):
"""Performs a minimal run of a model. """Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A This function is intended to test for syntax errors throughout a model. A
...@@ -35,14 +35,13 @@ def run_synthetic(main, extra_flags=None): ...@@ -35,14 +35,13 @@ def run_synthetic(main, extra_flags=None):
Args: Args:
main: The primary function used to excercise a code path. Generally this main: The primary function used to excercise a code path. Generally this
function is "<MODULE>.main(argv)". function is "<MODULE>.main(argv)".
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the the caller of this function. extra_flags: Additional flags passed by the the caller of this function.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
model_dir = "/tmp/model_test_{}".format(hash(time.time())) model_dir = tempfile.mkdtemp(dir=tmp_root)
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1", args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_per_eval", "1", "--use_synthetic_data", "--epochs_per_eval", "1", "--use_synthetic_data",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment