Unverified Commit 98cb4ead authored by Alex Lin's avatar Alex Lin Committed by GitHub
Browse files

Replace get_tmp_dir() with tmpdir fixture in tests (#4280)



* Replace in test_datasets*

* Replace in test_image.py

* Replace in test_transforms_tensor.py

* Replace in test_internet.py and test_io.py

* get_list_of_videos is util function still use get_tmp_dir

* Fix get_list_of_videos siginiture

* Add get_tmp_dir import

* Modify test_datasets_video_utils.py for test to pass

* Fix indentation

* Replace get_tmp_dir in util functions in test_dataset_sampler.py

* Replace get_tmp_dir in util functions in test_dataset_video_utils.py

* Move get_tmp_dir() to datasets_utils.py and refactor

* Fix pylint, indentation and imports

* import shutil to common_util.py

* Fix function signiture

* Remove get_list_of_videos under context manager

* Move get_list_of_videos to common_utils.py

* Move get_tmp_dir() back to common_utils.py

* Fix pylint and imports
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent f3aff2fa
......@@ -15,6 +15,7 @@ import functools
from numbers import Number
from torch._six import string_classes
from collections import OrderedDict
from torchvision import io
import numpy as np
from PIL import Image
......@@ -147,6 +148,25 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
names = []
for i in range(num_videos):
if sizes is None:
size = 5 * (i + 1)
else:
size = sizes[i]
if fps is None:
f = 5
else:
f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmpdir, "{}.mp4".format(i))
names.append(name)
io.write_video(name, data, fps=f)
return names
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
......
......@@ -22,8 +22,6 @@ from torchvision.datasets.utils import (
USER_AGENT,
)
from common_utils import get_tmp_dir
def limit_requests_per_time(min_secs_between_requests=2.0):
last_requests = {}
......@@ -166,9 +164,8 @@ def assert_url_is_accessible(url, timeout=5.0):
urlopen(request, timeout=timeout)
def assert_file_downloads_correctly(url, md5, timeout=5.0):
with get_tmp_dir() as root:
file = path.join(root, path.basename(url))
def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
file = path.join(tmpdir, path.basename(url))
with assert_server_response_ok():
with open(file, "wb") as fh:
request = Request(url, headers={"User-Agent": USER_AGENT})
......
......@@ -13,34 +13,13 @@ from torchvision.datasets.samplers import (
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend
from common_utils import get_tmp_dir, assert_equal
@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
with get_tmp_dir() as tmp_dir:
names = []
for i in range(num_videos):
if sizes is None:
size = 5 * (i + 1)
else:
size = sizes[i]
if fps is None:
f = 5
else:
f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmp_dir, "{}.mp4".format(i))
names.append(name)
io.write_video(name, data, fps=f)
yield names
from common_utils import get_list_of_videos, assert_equal
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class TestDatasetsSamplers:
def test_random_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
def test_random_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
......@@ -50,8 +29,8 @@ class TestDatasetsSamplers:
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
assert_equal(count, torch.tensor([3, 3, 3]))
def test_random_clip_sampler_unequal(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
def test_random_clip_sampler_unequal(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 2 + 3 + 3
......@@ -67,8 +46,8 @@ class TestDatasetsSamplers:
assert_equal(v_idxs, torch.tensor([0, 1]))
assert_equal(count, torch.tensor([3, 3]))
def test_uniform_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
def test_uniform_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
......@@ -79,16 +58,16 @@ class TestDatasetsSamplers:
assert_equal(count, torch.tensor([3, 3, 3]))
assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
def test_uniform_clip_sampler_insufficient_clips(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
def test_distributed_sampler_and_uniform_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
clip_sampler = UniformClipSampler(video_clips, 3)
......
......@@ -12,7 +12,6 @@ import itertools
import lzma
import contextlib
from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
......@@ -113,7 +112,7 @@ class TestDatasetsUtils:
utils._detect_file_type(file)
@pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
def test_decompress(self, extension):
def test_decompress(self, extension, tmpdir):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}{extension}"
......@@ -124,8 +123,7 @@ class TestDatasetsUtils:
return compressed, file, content
with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
compressed, file, content = create_compressed(tmpdir)
utils._decompress(compressed)
......@@ -138,7 +136,7 @@ class TestDatasetsUtils:
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")
def test_decompress_remove_finished(self):
def test_decompress_remove_finished(self, tmpdir):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
......@@ -148,10 +146,9 @@ class TestDatasetsUtils:
return compressed, file, content
with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
compressed, file, content = create_compressed(tmpdir)
utils.extract_archive(compressed, temp_dir, remove_finished=True)
utils.extract_archive(compressed, tmpdir, remove_finished=True)
assert not os.path.exists(compressed)
......@@ -166,7 +163,7 @@ class TestDatasetsUtils:
mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
def test_extract_zip(self):
def test_extract_zip(self, tmpdir):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")
......@@ -176,10 +173,9 @@ class TestDatasetsUtils:
return archive, file, content
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir)
archive, file, content = create_archive(tmpdir)
utils.extract_archive(archive, temp_dir)
utils.extract_archive(archive, tmpdir)
assert os.path.exists(file)
......@@ -188,7 +184,7 @@ class TestDatasetsUtils:
@pytest.mark.parametrize('extension, mode', [
('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
def test_extract_tar(self, extension, mode):
def test_extract_tar(self, extension, mode, tmpdir):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
......@@ -202,10 +198,9 @@ class TestDatasetsUtils:
return archive, dst, content
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, extension, mode)
archive, file, content = create_archive(tmpdir, extension, mode)
utils.extract_archive(archive, temp_dir)
utils.extract_archive(archive, tmpdir)
assert os.path.exists(file)
......
......@@ -6,28 +6,7 @@ import pytest
from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold
from common_utils import get_tmp_dir, assert_equal
@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
with get_tmp_dir() as tmp_dir:
names = []
for i in range(num_videos):
if sizes is None:
size = 5 * (i + 1)
else:
size = sizes[i]
if fps is None:
f = 5
else:
f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmp_dir, "{}.mp4".format(i))
names.append(name)
io.write_video(name, data, fps=f)
yield names
from common_utils import get_list_of_videos, assert_equal
class TestVideo:
......@@ -58,8 +37,8 @@ class TestVideo:
assert_equal(r, expected)
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips(self):
with get_list_of_videos(num_videos=3) as video_list:
def test_video_clips(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3)
video_clips = VideoClips(video_list, 5, 5, num_workers=2)
assert video_clips.num_clips() == 1 + 2 + 3
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
......@@ -82,8 +61,8 @@ class TestVideo:
assert clip_idx == c_idx
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips_custom_fps(self):
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
def test_video_clips_custom_fps(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6])
num_frames = 4
for fps in [1, 3, 4, 10]:
video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2)
......
......@@ -9,7 +9,7 @@ import numpy as np
import torch
from PIL import Image, __version__ as PILLOW_VERSION
import torchvision.transforms.functional as F
from common_utils import get_tmp_dir, needs_cuda, assert_equal
from common_utils import needs_cuda, assert_equal
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
......@@ -197,14 +197,13 @@ def test_encode_png_errors():
pytest.param(png_path, id=_get_safe_image_name(png_path))
for png_path in get_images(IMAGE_DIR, ".png")
])
def test_write_png(img_path):
with get_tmp_dir() as d:
def test_write_png(img_path, tmpdir):
pil_image = Image.open(img_path)
img_pil = torch.from_numpy(np.array(pil_image))
img_pil = img_pil.permute(2, 0, 1)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
torch_png = os.path.join(tmpdir, '{0}_torch.png'.format(filename))
write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1)
......@@ -212,10 +211,9 @@ def test_write_png(img_path):
assert_equal(img_pil, saved_image)
def test_read_file():
with get_tmp_dir() as d:
def test_read_file(tmpdir):
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
fpath = os.path.join(tmpdir, fname)
with open(fpath, 'wb') as f:
f.write(content)
......@@ -228,10 +226,9 @@ def test_read_file():
read_file('tst')
def test_read_file_non_ascii():
with get_tmp_dir() as d:
def test_read_file_non_ascii(tmpdir):
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
fpath = os.path.join(tmpdir, fname)
with open(fpath, 'wb') as f:
f.write(content)
......@@ -241,10 +238,9 @@ def test_read_file_non_ascii():
assert_equal(data, expected)
def test_write_file():
with get_tmp_dir() as d:
def test_write_file(tmpdir):
fname, content = 'test1.bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
......@@ -254,10 +250,9 @@ def test_write_file():
assert content == saved_content
def test_write_file_non_ascii():
with get_tmp_dir() as d:
def test_write_file_non_ascii(tmpdir):
fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
fpath = os.path.join(d, fname)
fpath = os.path.join(tmpdir, fname)
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
write_file(fpath, content_tensor)
......@@ -272,10 +267,9 @@ def test_write_file_non_ascii():
(60, 60),
(105, 105),
])
def test_read_1_bit_png(shape):
def test_read_1_bit_png(shape, tmpdir):
np_rng = np.random.RandomState(0)
with get_tmp_dir() as root:
image_path = os.path.join(root, f'test_{shape}.png')
image_path = os.path.join(tmpdir, f'test_{shape}.png')
pixels = np_rng.rand(*shape) > 0.5
img = Image.fromarray(pixels)
img.save(image_path)
......@@ -293,10 +287,9 @@ def test_read_1_bit_png(shape):
ImageReadMode.UNCHANGED,
ImageReadMode.GRAY,
])
def test_read_1_bit_png_consistency(shape, mode):
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
np_rng = np.random.RandomState(0)
with get_tmp_dir() as root:
image_path = os.path.join(root, f'test_{shape}.png')
image_path = os.path.join(tmpdir, f'test_{shape}.png')
pixels = np_rng.rand(*shape) > 0.5
img = Image.fromarray(pixels)
img.save(image_path)
......@@ -427,16 +420,15 @@ def test_encode_jpeg_reference(img_path):
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg_reference(img_path):
def test_write_jpeg_reference(img_path, tmpdir):
# FIXME: Remove this eventually, see test_encode_jpeg_reference
with get_tmp_dir() as d:
data = read_file(img_path)
img = decode_jpeg(data)
basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(
d, '{0}_torch.jpg'.format(filename))
tmpdir, '{0}_torch.jpg'.format(filename))
pil_jpeg = os.path.join(
basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
......@@ -481,14 +473,13 @@ def test_encode_jpeg(img_path):
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
])
def test_write_jpeg(img_path):
with get_tmp_dir() as d:
d = Path(d)
def test_write_jpeg(img_path, tmpdir):
tmpdir = Path(tmpdir)
img = read_image(img_path)
pil_img = F.to_pil_image(img)
torch_jpeg = str(d / 'torch.jpg')
pil_jpeg = str(d / 'pil.jpg')
torch_jpeg = str(tmpdir / 'torch.jpg')
pil_jpeg = str(tmpdir / 'pil.jpg')
write_jpeg(img, torch_jpeg, quality=75)
pil_img.save(pil_jpeg, quality=75)
......
......@@ -11,35 +11,31 @@ import warnings
from urllib.error import URLError
import torchvision.datasets.utils as utils
from common_utils import get_tmp_dir
class TestDatasetUtils:
def test_download_url(self):
with get_tmp_dir() as temp_dir:
def test_download_url(self, tmpdir):
url = "http://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, temp_dir)
assert len(os.listdir(temp_dir)) != 0
utils.download_url(url, tmpdir)
assert len(os.listdir(tmpdir)) != 0
except URLError:
pytest.skip(f"could not download test file '{url}'")
def test_download_url_retry_http(self):
with get_tmp_dir() as temp_dir:
def test_download_url_retry_http(self, tmpdir):
url = "https://github.com/pytorch/vision/archive/master.zip"
try:
utils.download_url(url, temp_dir)
assert len(os.listdir(temp_dir)) != 0
utils.download_url(url, tmpdir)
assert len(os.listdir(tmpdir)) != 0
except URLError:
pytest.skip(f"could not download test file '{url}'")
def test_download_url_dont_exist(self):
with get_tmp_dir() as temp_dir:
def test_download_url_dont_exist(self, tmpdir):
url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
with pytest.raises(URLError):
utils.download_url(url, temp_dir)
utils.download_url(url, tmpdir)
def test_download_url_dispatch_download_from_google_drive(self, mocker):
def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
......@@ -47,10 +43,9 @@ class TestDatasetUtils:
md5 = "md5"
mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive')
with get_tmp_dir() as root:
utils.download_url(url, root, filename, md5)
utils.download_url(url, tmpdir, filename, md5)
mocked.assert_called_once_with(id, root, filename, md5)
mocked.assert_called_once_with(id, tmpdir, filename, md5)
if __name__ == '__main__':
......
......@@ -9,7 +9,7 @@ from torchvision import get_video_backend
import warnings
from urllib.error import URLError
from common_utils import get_tmp_dir, assert_equal
from common_utils import assert_equal
try:
......@@ -255,11 +255,10 @@ class TestVideo:
assert_equal(video, data)
@pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows')
def test_write_video_with_audio(self):
def test_write_video_with_audio(self, tmpdir):
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
with get_tmp_dir() as tmpdir:
out_f_name = os.path.join(tmpdir, "testing.mp4")
io.video.write_video(
out_f_name,
......
......@@ -230,7 +230,7 @@ def test_crop_pad(size, padding_config, device):
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_center_crop(device):
def test_center_crop(device, tmpdir):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
_test_op(
......@@ -259,8 +259,7 @@ def test_center_crop(device):
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))
scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -309,11 +308,10 @@ def test_x_crop(fn, method, out_length, size, device):
@pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"])
def test_x_crop_save(method):
def test_x_crop_save(method, tmpdir):
fn = getattr(T, method)(size=[5, ])
scripted_fn = torch.jit.script(fn)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method)))
class TestResize:
......@@ -349,11 +347,10 @@ class TestResize:
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_resize_save(self):
def test_resize_save(self, tmpdir):
transform = T.Resize(size=[32, ])
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]])
......@@ -368,11 +365,10 @@ class TestResize:
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_resized_crop_save(self):
def test_resized_crop_save(self, tmpdir):
transform = T.RandomResizedCrop(size=[32, ])
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))
s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt"))
def _test_random_affine_helper(device, **kwargs):
......@@ -386,11 +382,10 @@ def _test_random_affine_helper(device, **kwargs):
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_random_affine(device):
def test_random_affine(device, tmpdir):
transform = T.RandomAffine(degrees=45.0)
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
s_transform.save(os.path.join(tmpdir, "t_random_affine.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -447,11 +442,10 @@ def test_random_rotate(device, center, expand, degrees, interpolation, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_random_rotate_save():
def test_random_rotate_save(tmpdir):
transform = T.RandomRotation(degrees=45.0)
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -473,11 +467,10 @@ def test_random_perspective(device, distortion_scale, interpolation, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_random_perspective_save():
def test_random_perspective_save(tmpdir):
transform = T.RandomPerspective()
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
s_transform.save(os.path.join(tmpdir, "t_perspective.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -519,11 +512,10 @@ def test_convert_image_dtype(device, in_dtype, out_dtype):
_test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
def test_convert_image_dtype_save():
def test_convert_image_dtype_save(tmpdir):
fn = T.ConvertImageDtype(dtype=torch.uint8)
scripted_fn = torch.jit.script(fn)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -541,11 +533,10 @@ def test_autoaugment(device, policy, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_autoaugment_save():
def test_autoaugment_save(tmpdir):
transform = T.AutoAugment()
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......@@ -567,11 +558,10 @@ def test_random_erasing(device, config):
_test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
def test_random_erasing_save():
def test_random_erasing_save(tmpdir):
fn = T.RandomErasing(value=0.2)
scripted_fn = torch.jit.script(fn)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
scripted_fn.save(os.path.join(tmpdir, "t_random_erasing.pt"))
def test_random_erasing_with_invalid_data():
......@@ -583,7 +573,7 @@ def test_random_erasing_with_invalid_data():
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_normalize(device):
def test_normalize(device, tmpdir):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = _create_data(26, 34, device=device)
......@@ -598,12 +588,11 @@ def test_normalize(device):
_test_transform_vs_scripted(fn, scripted_fn, tensor)
_test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
def test_linear_transformation(device):
def test_linear_transformation(device, tmpdir):
c, h, w = 3, 24, 32
tensor, _ = _create_data(h, w, channels=c, device=device)
......@@ -625,8 +614,7 @@ def test_linear_transformation(device):
s_transformed_batch = scripted_fn(batch_tensors)
assert_equal(transformed_batch, s_transformed_batch)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
@pytest.mark.parametrize('device', cpu_and_gpu())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment