Unverified Commit 98cb4ead authored by Alex Lin, committed by GitHub

Replace get_tmp_dir() with tmpdir fixture in tests (#4280)



* Replace in test_datasets*

* Replace in test_image.py

* Replace in test_transforms_tensor.py

* Replace in test_internet.py and test_io.py

* get_list_of_videos is a util function that still uses get_tmp_dir

* Fix get_list_of_videos signature

* Add get_tmp_dir import

* Modify test_datasets_video_utils.py so the tests pass

* Fix indentation

* Replace get_tmp_dir in util functions in test_datasets_samplers.py

* Replace get_tmp_dir in util functions in test_datasets_video_utils.py

* Move get_tmp_dir() to datasets_utils.py and refactor

* Fix pylint, indentation and imports

* Import shutil in common_utils.py

* Fix function signature

* Remove get_list_of_videos under context manager

* Move get_list_of_videos to common_utils.py

* Move get_tmp_dir() back to common_utils.py

* Fix pylint and imports
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent f3aff2fa
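
The change applies one pattern across the test suite: instead of wrapping each test body in a project-specific `get_tmp_dir()` context manager, tests now accept pytest's built-in `tmpdir` fixture as an argument. A minimal before/after sketch (the `get_tmp_dir` below is a simplified stand-in for the helper in test/common_utils.py, and the test names are illustrative, not part of this diff):

```python
import contextlib
import os
import shutil
import tempfile


# Simplified stand-in for the old helper in test/common_utils.py.
@contextlib.contextmanager
def get_tmp_dir():
    tmp_dir = tempfile.mkdtemp()
    try:
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)


# Before: every test created and cleaned up its own temporary directory.
def test_save_something_old():
    with get_tmp_dir() as tmp_dir:
        with open(os.path.join(tmp_dir, "out.bin"), "wb") as f:
            f.write(b"payload")


# After: pytest injects a fresh per-test directory (a py.path.local object)
# and handles cleanup itself, so the extra indentation level disappears.
def test_save_something_new(tmpdir):
    with open(os.path.join(str(tmpdir), "out.bin"), "wb") as f:
        f.write(b"payload")
```

The diffs below repeat this mechanical change file by file; the only structural addition is moving get_list_of_videos into test/common_utils.py so both video test modules can share it.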
@@ -15,6 +15,7 @@ import functools
 from numbers import Number
 from torch._six import string_classes
 from collections import OrderedDict
+from torchvision import io
 import numpy as np
 from PIL import Image
@@ -147,6 +148,25 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
 assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
 
 
+def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
+    names = []
+    for i in range(num_videos):
+        if sizes is None:
+            size = 5 * (i + 1)
+        else:
+            size = sizes[i]
+        if fps is None:
+            f = 5
+        else:
+            f = fps[i]
+        data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
+        name = os.path.join(tmpdir, "{}.mp4".format(i))
+        names.append(name)
+        io.write_video(name, data, fps=f)
+    return names
+
+
 def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
     np_pil_image = np.array(pil_image)
     if np_pil_image.ndim == 2:
...
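With get_list_of_videos now a plain function in test/common_utils.py, callers pass the per-test directory in explicitly. A sketch of how a test would use it together with the `tmpdir` fixture (the test below is illustrative, not part of the diff, and assumes PyAV is available for `io.write_video`):

```python
from common_utils import get_list_of_videos
from torchvision.datasets.video_utils import VideoClips


def test_list_of_videos_usage(tmpdir):
    # Writes three 25-frame dummy clips into the per-test tmpdir and indexes them.
    video_list = get_list_of_videos(str(tmpdir), num_videos=3, sizes=[25, 25, 25])
    video_clips = VideoClips(video_list, clip_length_in_frames=5, frames_between_clips=5)
    # 25 frames split into non-overlapping 5-frame clips -> 5 clips per video.
    assert video_clips.num_clips() == 3 * 5
```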
@@ -22,8 +22,6 @@ from torchvision.datasets.utils import (
     USER_AGENT,
 )
 
-from common_utils import get_tmp_dir
-
 
 def limit_requests_per_time(min_secs_between_requests=2.0):
     last_requests = {}
@@ -166,9 +164,8 @@ def assert_url_is_accessible(url, timeout=5.0):
     urlopen(request, timeout=timeout)
 
 
-def assert_file_downloads_correctly(url, md5, timeout=5.0):
-    with get_tmp_dir() as root:
-        file = path.join(root, path.basename(url))
+def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
+    file = path.join(tmpdir, path.basename(url))
     with assert_server_response_ok():
         with open(file, "wb") as fh:
             request = Request(url, headers={"User-Agent": USER_AGENT})
...
@@ -13,34 +13,13 @@ from torchvision.datasets.samplers import (
 from torchvision.datasets.video_utils import VideoClips, unfold
 from torchvision import get_video_backend
-from common_utils import get_tmp_dir, assert_equal
-
-
-@contextlib.contextmanager
-def get_list_of_videos(num_videos=5, sizes=None, fps=None):
-    with get_tmp_dir() as tmp_dir:
-        names = []
-        for i in range(num_videos):
-            if sizes is None:
-                size = 5 * (i + 1)
-            else:
-                size = sizes[i]
-            if fps is None:
-                f = 5
-            else:
-                f = fps[i]
-            data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
-            name = os.path.join(tmp_dir, "{}.mp4".format(i))
-            names.append(name)
-            io.write_video(name, data, fps=f)
-        yield names
+from common_utils import get_list_of_videos, assert_equal
 
 
 @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
 class TestDatasetsSamplers:
-    def test_random_clip_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+    def test_random_clip_sampler(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
         video_clips = VideoClips(video_list, 5, 5)
         sampler = RandomClipSampler(video_clips, 3)
         assert len(sampler) == 3 * 3
@@ -50,8 +29,8 @@ class TestDatasetsSamplers:
         assert_equal(v_idxs, torch.tensor([0, 1, 2]))
         assert_equal(count, torch.tensor([3, 3, 3]))
 
-    def test_random_clip_sampler_unequal(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
+    def test_random_clip_sampler_unequal(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
         video_clips = VideoClips(video_list, 5, 5)
         sampler = RandomClipSampler(video_clips, 3)
         assert len(sampler) == 2 + 3 + 3
@@ -67,8 +46,8 @@ class TestDatasetsSamplers:
         assert_equal(v_idxs, torch.tensor([0, 1]))
         assert_equal(count, torch.tensor([3, 3]))
 
-    def test_uniform_clip_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+    def test_uniform_clip_sampler(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
         video_clips = VideoClips(video_list, 5, 5)
         sampler = UniformClipSampler(video_clips, 3)
         assert len(sampler) == 3 * 3
@@ -79,16 +58,16 @@ class TestDatasetsSamplers:
         assert_equal(count, torch.tensor([3, 3, 3]))
         assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
 
-    def test_uniform_clip_sampler_insufficient_clips(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
+    def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
         video_clips = VideoClips(video_list, 5, 5)
         sampler = UniformClipSampler(video_clips, 3)
         assert len(sampler) == 3 * 3
         indices = torch.tensor(list(iter(sampler)))
         assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
 
-    def test_distributed_sampler_and_uniform_clip_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+    def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
         video_clips = VideoClips(video_list, 5, 5)
         clip_sampler = UniformClipSampler(video_clips, 3)
...
@@ -12,7 +12,6 @@ import itertools
 import lzma
 import contextlib
 
-from common_utils import get_tmp_dir
 from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
@@ -113,7 +112,7 @@ class TestDatasetsUtils:
             utils._detect_file_type(file)
 
     @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
-    def test_decompress(self, extension):
+    def test_decompress(self, extension, tmpdir):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}{extension}"
@@ -124,8 +123,7 @@ class TestDatasetsUtils:
             return compressed, file, content
 
-        with get_tmp_dir() as temp_dir:
-            compressed, file, content = create_compressed(temp_dir)
+        compressed, file, content = create_compressed(tmpdir)
 
         utils._decompress(compressed)
@@ -138,7 +136,7 @@ class TestDatasetsUtils:
         with pytest.raises(RuntimeError):
             utils._decompress("foo.tar")
 
-    def test_decompress_remove_finished(self):
+    def test_decompress_remove_finished(self, tmpdir):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}.gz"
@@ -148,10 +146,9 @@ class TestDatasetsUtils:
             return compressed, file, content
 
-        with get_tmp_dir() as temp_dir:
-            compressed, file, content = create_compressed(temp_dir)
-
-            utils.extract_archive(compressed, temp_dir, remove_finished=True)
+        compressed, file, content = create_compressed(tmpdir)
+
+        utils.extract_archive(compressed, tmpdir, remove_finished=True)
 
         assert not os.path.exists(compressed)
@@ -166,7 +163,7 @@ class TestDatasetsUtils:
         mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
 
-    def test_extract_zip(self):
+    def test_extract_zip(self, tmpdir):
         def create_archive(root, content="this is the content"):
             file = os.path.join(root, "dst.txt")
             archive = os.path.join(root, "archive.zip")
@@ -176,10 +173,9 @@ class TestDatasetsUtils:
             return archive, file, content
 
-        with get_tmp_dir() as temp_dir:
-            archive, file, content = create_archive(temp_dir)
-
-            utils.extract_archive(archive, temp_dir)
+        archive, file, content = create_archive(tmpdir)
+
+        utils.extract_archive(archive, tmpdir)
 
         assert os.path.exists(file)
@@ -188,7 +184,7 @@ class TestDatasetsUtils:
     @pytest.mark.parametrize('extension, mode', [
         ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
-    def test_extract_tar(self, extension, mode):
+    def test_extract_tar(self, extension, mode, tmpdir):
         def create_archive(root, extension, mode, content="this is the content"):
             src = os.path.join(root, "src.txt")
             dst = os.path.join(root, "dst.txt")
@@ -202,10 +198,9 @@ class TestDatasetsUtils:
            return archive, dst, content
 
-        with get_tmp_dir() as temp_dir:
-            archive, file, content = create_archive(temp_dir, extension, mode)
-
-            utils.extract_archive(archive, temp_dir)
+        archive, file, content = create_archive(tmpdir, extension, mode)
+
+        utils.extract_archive(archive, tmpdir)
 
         assert os.path.exists(file)
...
@@ -6,28 +6,7 @@ import pytest
 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
-from common_utils import get_tmp_dir, assert_equal
-
-
-@contextlib.contextmanager
-def get_list_of_videos(num_videos=5, sizes=None, fps=None):
-    with get_tmp_dir() as tmp_dir:
-        names = []
-        for i in range(num_videos):
-            if sizes is None:
-                size = 5 * (i + 1)
-            else:
-                size = sizes[i]
-            if fps is None:
-                f = 5
-            else:
-                f = fps[i]
-            data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
-            name = os.path.join(tmp_dir, "{}.mp4".format(i))
-            names.append(name)
-            io.write_video(name, data, fps=f)
-        yield names
+from common_utils import get_list_of_videos, assert_equal
 
 
 class TestVideo:
@@ -58,8 +37,8 @@ class TestVideo:
         assert_equal(r, expected)
 
     @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
-    def test_video_clips(self):
-        with get_list_of_videos(num_videos=3) as video_list:
+    def test_video_clips(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3)
         video_clips = VideoClips(video_list, 5, 5, num_workers=2)
         assert video_clips.num_clips() == 1 + 2 + 3
         for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
@@ -82,8 +61,8 @@ class TestVideo:
         assert clip_idx == c_idx
 
     @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
-    def test_video_clips_custom_fps(self):
-        with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
+    def test_video_clips_custom_fps(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6])
         num_frames = 4
         for fps in [1, 3, 4, 10]:
             video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2)
...
@@ -9,7 +9,7 @@ import numpy as np
 import torch
 from PIL import Image, __version__ as PILLOW_VERSION
 import torchvision.transforms.functional as F
-from common_utils import get_tmp_dir, needs_cuda, assert_equal
+from common_utils import needs_cuda, assert_equal
 
 from torchvision.io.image import (
     decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
@@ -197,14 +197,13 @@ def test_encode_png_errors():
     pytest.param(png_path, id=_get_safe_image_name(png_path))
     for png_path in get_images(IMAGE_DIR, ".png")
 ])
-def test_write_png(img_path):
-    with get_tmp_dir() as d:
+def test_write_png(img_path, tmpdir):
     pil_image = Image.open(img_path)
     img_pil = torch.from_numpy(np.array(pil_image))
     img_pil = img_pil.permute(2, 0, 1)
     filename, _ = os.path.splitext(os.path.basename(img_path))
-        torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
+    torch_png = os.path.join(tmpdir, '{0}_torch.png'.format(filename))
     write_png(img_pil, torch_png, compression_level=6)
     saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
     saved_image = saved_image.permute(2, 0, 1)
@@ -212,10 +211,9 @@ def test_write_png(img_path):
     assert_equal(img_pil, saved_image)
 
 
-def test_read_file():
-    with get_tmp_dir() as d:
+def test_read_file(tmpdir):
     fname, content = 'test1.bin', b'TorchVision\211\n'
-        fpath = os.path.join(d, fname)
+    fpath = os.path.join(tmpdir, fname)
     with open(fpath, 'wb') as f:
         f.write(content)
@@ -228,10 +226,9 @@ def test_read_file():
         read_file('tst')
 
 
-def test_read_file_non_ascii():
-    with get_tmp_dir() as d:
+def test_read_file_non_ascii(tmpdir):
     fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
-        fpath = os.path.join(d, fname)
+    fpath = os.path.join(tmpdir, fname)
     with open(fpath, 'wb') as f:
         f.write(content)
@@ -241,10 +238,9 @@ def test_read_file_non_ascii():
     assert_equal(data, expected)
 
 
-def test_write_file():
-    with get_tmp_dir() as d:
+def test_write_file(tmpdir):
     fname, content = 'test1.bin', b'TorchVision\211\n'
-        fpath = os.path.join(d, fname)
+    fpath = os.path.join(tmpdir, fname)
     content_tensor = torch.tensor(list(content), dtype=torch.uint8)
     write_file(fpath, content_tensor)
@@ -254,10 +250,9 @@ def test_write_file():
     assert content == saved_content
 
 
-def test_write_file_non_ascii():
-    with get_tmp_dir() as d:
+def test_write_file_non_ascii(tmpdir):
     fname, content = '日本語(Japanese).bin', b'TorchVision\211\n'
-        fpath = os.path.join(d, fname)
+    fpath = os.path.join(tmpdir, fname)
     content_tensor = torch.tensor(list(content), dtype=torch.uint8)
     write_file(fpath, content_tensor)
@@ -272,10 +267,9 @@ def test_write_file_non_ascii():
     (60, 60),
     (105, 105),
 ])
-def test_read_1_bit_png(shape):
+def test_read_1_bit_png(shape, tmpdir):
     np_rng = np.random.RandomState(0)
-    with get_tmp_dir() as root:
-        image_path = os.path.join(root, f'test_{shape}.png')
+    image_path = os.path.join(tmpdir, f'test_{shape}.png')
     pixels = np_rng.rand(*shape) > 0.5
     img = Image.fromarray(pixels)
     img.save(image_path)
@@ -293,10 +287,9 @@ def test_read_1_bit_png(shape):
     ImageReadMode.UNCHANGED,
     ImageReadMode.GRAY,
 ])
-def test_read_1_bit_png_consistency(shape, mode):
+def test_read_1_bit_png_consistency(shape, mode, tmpdir):
     np_rng = np.random.RandomState(0)
-    with get_tmp_dir() as root:
-        image_path = os.path.join(root, f'test_{shape}.png')
+    image_path = os.path.join(tmpdir, f'test_{shape}.png')
     pixels = np_rng.rand(*shape) > 0.5
     img = Image.fromarray(pixels)
     img.save(image_path)
@@ -427,16 +420,15 @@ def test_encode_jpeg_reference(img_path):
     pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
     for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
 ])
-def test_write_jpeg_reference(img_path):
+def test_write_jpeg_reference(img_path, tmpdir):
     # FIXME: Remove this eventually, see test_encode_jpeg_reference
-    with get_tmp_dir() as d:
     data = read_file(img_path)
     img = decode_jpeg(data)
 
     basedir = os.path.dirname(img_path)
     filename, _ = os.path.splitext(os.path.basename(img_path))
     torch_jpeg = os.path.join(
-            d, '{0}_torch.jpg'.format(filename))
+        tmpdir, '{0}_torch.jpg'.format(filename))
     pil_jpeg = os.path.join(
         basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
@@ -481,14 +473,13 @@ def test_encode_jpeg(img_path):
     pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
     for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
 ])
-def test_write_jpeg(img_path):
-    with get_tmp_dir() as d:
-        d = Path(d)
+def test_write_jpeg(img_path, tmpdir):
+    tmpdir = Path(tmpdir)
 
     img = read_image(img_path)
     pil_img = F.to_pil_image(img)
 
-        torch_jpeg = str(d / 'torch.jpg')
-        pil_jpeg = str(d / 'pil.jpg')
+    torch_jpeg = str(tmpdir / 'torch.jpg')
+    pil_jpeg = str(tmpdir / 'pil.jpg')
 
     write_jpeg(img, torch_jpeg, quality=75)
     pil_img.save(pil_jpeg, quality=75)
...
@@ -11,35 +11,31 @@ import warnings
 from urllib.error import URLError
 
 import torchvision.datasets.utils as utils
-from common_utils import get_tmp_dir
 
 
 class TestDatasetUtils:
-    def test_download_url(self):
-        with get_tmp_dir() as temp_dir:
+    def test_download_url(self, tmpdir):
         url = "http://github.com/pytorch/vision/archive/master.zip"
         try:
-            utils.download_url(url, temp_dir)
-            assert len(os.listdir(temp_dir)) != 0
+            utils.download_url(url, tmpdir)
+            assert len(os.listdir(tmpdir)) != 0
         except URLError:
             pytest.skip(f"could not download test file '{url}'")
 
-    def test_download_url_retry_http(self):
-        with get_tmp_dir() as temp_dir:
+    def test_download_url_retry_http(self, tmpdir):
         url = "https://github.com/pytorch/vision/archive/master.zip"
         try:
-            utils.download_url(url, temp_dir)
-            assert len(os.listdir(temp_dir)) != 0
+            utils.download_url(url, tmpdir)
+            assert len(os.listdir(tmpdir)) != 0
         except URLError:
             pytest.skip(f"could not download test file '{url}'")
 
-    def test_download_url_dont_exist(self):
-        with get_tmp_dir() as temp_dir:
+    def test_download_url_dont_exist(self, tmpdir):
         url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
         with pytest.raises(URLError):
-            utils.download_url(url, temp_dir)
+            utils.download_url(url, tmpdir)
 
-    def test_download_url_dispatch_download_from_google_drive(self, mocker):
+    def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir):
         url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
         id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
@@ -47,10 +43,9 @@ class TestDatasetUtils:
         md5 = "md5"
 
         mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive')
-        with get_tmp_dir() as root:
-            utils.download_url(url, root, filename, md5)
+        utils.download_url(url, tmpdir, filename, md5)
 
-            mocked.assert_called_once_with(id, root, filename, md5)
+        mocked.assert_called_once_with(id, tmpdir, filename, md5)
 
 
 if __name__ == '__main__':
...
@@ -9,7 +9,7 @@ from torchvision import get_video_backend
 import warnings
 from urllib.error import URLError
 
-from common_utils import get_tmp_dir, assert_equal
+from common_utils import assert_equal
 
 
 try:
@@ -255,11 +255,10 @@ class TestVideo:
         assert_equal(video, data)
 
     @pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows')
-    def test_write_video_with_audio(self):
+    def test_write_video_with_audio(self, tmpdir):
         f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
         video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
 
-        with get_tmp_dir() as tmpdir:
         out_f_name = os.path.join(tmpdir, "testing.mp4")
         io.video.write_video(
             out_f_name,
...
@@ -230,7 +230,7 @@ def test_crop_pad(size, padding_config, device):
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
-def test_center_crop(device):
+def test_center_crop(device, tmpdir):
     fn_kwargs = {"output_size": (4, 5)}
     meth_kwargs = {"size": (4, 5), }
     _test_op(
@@ -259,8 +259,7 @@ def test_center_crop(device):
     scripted_fn = torch.jit.script(f)
     scripted_fn(tensor)
 
-    with get_tmp_dir() as tmp_dir:
-        scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))
+    scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -309,11 +308,10 @@ def test_x_crop(fn, method, out_length, size, device):
 
 @pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"])
-def test_x_crop_save(method):
+def test_x_crop_save(method, tmpdir):
     fn = getattr(T, method)(size=[5, ])
     scripted_fn = torch.jit.script(fn)
-    with get_tmp_dir() as tmp_dir:
-        scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
+    scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method)))
 
 
 class TestResize:
@@ -349,11 +347,10 @@ class TestResize:
         _test_transform_vs_scripted(transform, s_transform, tensor)
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
-    def test_resize_save(self):
+    def test_resize_save(self, tmpdir):
         transform = T.Resize(size=[32, ])
         s_transform = torch.jit.script(transform)
-        with get_tmp_dir() as tmp_dir:
-            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
+        s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
 
     @pytest.mark.parametrize('device', cpu_and_gpu())
     @pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]])
@@ -368,11 +365,10 @@ class TestResize:
         _test_transform_vs_scripted(transform, s_transform, tensor)
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
-    def test_resized_crop_save(self):
+    def test_resized_crop_save(self, tmpdir):
         transform = T.RandomResizedCrop(size=[32, ])
         s_transform = torch.jit.script(transform)
-        with get_tmp_dir() as tmp_dir:
-            s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))
+        s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt"))
 
 
 def _test_random_affine_helper(device, **kwargs):
@@ -386,11 +382,10 @@ def _test_random_affine_helper(device, **kwargs):
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
-def test_random_affine(device):
+def test_random_affine(device, tmpdir):
     transform = T.RandomAffine(degrees=45.0)
     s_transform = torch.jit.script(transform)
-    with get_tmp_dir() as tmp_dir:
-        s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
+    s_transform.save(os.path.join(tmpdir, "t_random_affine.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -447,11 +442,10 @@ def test_random_rotate(device, center, expand, degrees, interpolation, fill):
     _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
 
-def test_random_rotate_save():
+def test_random_rotate_save(tmpdir):
     transform = T.RandomRotation(degrees=45.0)
     s_transform = torch.jit.script(transform)
-    with get_tmp_dir() as tmp_dir:
-        s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
+    s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -473,11 +467,10 @@ def test_random_perspective(device, distortion_scale, interpolation, fill):
     _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
 
-def test_random_perspective_save():
+def test_random_perspective_save(tmpdir):
     transform = T.RandomPerspective()
     s_transform = torch.jit.script(transform)
-    with get_tmp_dir() as tmp_dir:
-        s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
+    s_transform.save(os.path.join(tmpdir, "t_perspective.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -519,11 +512,10 @@ def test_convert_image_dtype(device, in_dtype, out_dtype):
     _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
 
 
-def test_convert_image_dtype_save():
+def test_convert_image_dtype_save(tmpdir):
    fn = T.ConvertImageDtype(dtype=torch.uint8)
     scripted_fn = torch.jit.script(fn)
-    with get_tmp_dir() as tmp_dir:
-        scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
+    scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -541,11 +533,10 @@ def test_autoaugment(device, policy, fill):
     _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
 
-def test_autoaugment_save():
+def test_autoaugment_save(tmpdir):
     transform = T.AutoAugment()
     s_transform = torch.jit.script(transform)
-    with get_tmp_dir() as tmp_dir:
-        s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
+    s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
@@ -567,11 +558,10 @@ def test_random_erasing(device, config):
     _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
 
 
-def test_random_erasing_save():
+def test_random_erasing_save(tmpdir):
     fn = T.RandomErasing(value=0.2)
     scripted_fn = torch.jit.script(fn)
-    with get_tmp_dir() as tmp_dir:
-        scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))
+    scripted_fn.save(os.path.join(tmpdir, "t_random_erasing.pt"))
 
 
 def test_random_erasing_with_invalid_data():
@@ -583,7 +573,7 @@ def test_random_erasing_with_invalid_data():
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
-def test_normalize(device):
+def test_normalize(device, tmpdir):
     fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     tensor, _ = _create_data(26, 34, device=device)
@@ -598,12 +588,11 @@ def test_normalize(device):
     _test_transform_vs_scripted(fn, scripted_fn, tensor)
     _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
 
-    with get_tmp_dir() as tmp_dir:
-        scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
+    scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
-def test_linear_transformation(device):
+def test_linear_transformation(device, tmpdir):
     c, h, w = 3, 24, 32
     tensor, _ = _create_data(h, w, channels=c, device=device)
@@ -625,8 +614,7 @@ def test_linear_transformation(device):
     s_transformed_batch = scripted_fn(batch_tensors)
     assert_equal(transformed_batch, s_transformed_batch)
 
-    with get_tmp_dir() as tmp_dir:
-        scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
+    scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))
 
 
 @pytest.mark.parametrize('device', cpu_and_gpu())
...