Unverified Commit b3b51377 authored by Anirudh's avatar Anirudh Committed by GitHub
Browse files

Port test_datasets_utils to pytest (#4114)

parent ab60e538
...@@ -240,23 +240,6 @@ def disable_console_output(): ...@@ -240,23 +240,6 @@ def disable_console_output():
yield yield
def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
callable_or_arg_name = callable_or_arg_names[0]
if callable(callable_or_arg_name):
argspec = inspect.getfullargspec(callable_or_arg_name)
arg_names = argspec.args
if isinstance(callable_or_arg_name, type):
# remove self
arg_names.pop(0)
else:
arg_names = callable_or_arg_names
args, kwargs = call_args
kwargs_only = kwargs.copy()
kwargs_only.update(dict(zip(arg_names, args)))
return kwargs_only
def cpu_and_gpu(): def cpu_and_gpu():
import pytest # noqa import pytest # noqa
return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda)) return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda))
......
import bz2 import bz2
import os import os
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
import unittest import pytest
import unittest.mock
import zipfile import zipfile
import tarfile import tarfile
import gzip import gzip
...@@ -12,31 +11,32 @@ from urllib.error import URLError ...@@ -12,31 +11,32 @@ from urllib.error import URLError
import itertools import itertools
import lzma import lzma
from common_utils import get_tmp_dir, call_args_to_kwargs_only from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
TEST_FILE = get_file_path_2( TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class TestDatasetsUtils:
def test_check_md5(self): def test_check_md5(self):
fpath = TEST_FILE fpath = TEST_FILE
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = '' false_md5 = ''
self.assertTrue(utils.check_md5(fpath, correct_md5)) assert utils.check_md5(fpath, correct_md5)
self.assertFalse(utils.check_md5(fpath, false_md5)) assert not utils.check_md5(fpath, false_md5)
def test_check_integrity(self): def test_check_integrity(self):
existing_fpath = TEST_FILE existing_fpath = TEST_FILE
nonexisting_fpath = '' nonexisting_fpath = ''
correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5 = '' false_md5 = ''
self.assertTrue(utils.check_integrity(existing_fpath, correct_md5)) assert utils.check_integrity(existing_fpath, correct_md5)
self.assertFalse(utils.check_integrity(existing_fpath, false_md5)) assert not utils.check_integrity(existing_fpath, false_md5)
self.assertTrue(utils.check_integrity(existing_fpath)) assert utils.check_integrity(existing_fpath)
self.assertFalse(utils.check_integrity(nonexisting_fpath)) assert not utils.check_integrity(nonexisting_fpath)
def test_get_google_drive_file_id(self): def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view" url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
...@@ -50,8 +50,7 @@ class Tester(unittest.TestCase): ...@@ -50,8 +50,7 @@ class Tester(unittest.TestCase):
assert utils._get_google_drive_file_id(url) is None assert utils._get_google_drive_file_id(url) is None
def test_detect_file_type(self): @pytest.mark.parametrize('file, expected', [
for file, expected in [
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")), ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)), ("foo.tar", (".tar", ".tar", None)),
...@@ -65,29 +64,24 @@ class Tester(unittest.TestCase): ...@@ -65,29 +64,24 @@ class Tester(unittest.TestCase):
("foo.xz", (".xz", None, ".xz")), ("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")), ("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)), ("foo.bar.zip", (".zip", ".zip", None))])
]: def test_detect_file_type(self, file, expected):
with self.subTest(file=file): assert utils._detect_file_type(file) == expected
self.assertSequenceEqual(utils._detect_file_type(file), expected)
@pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"])
def test_detect_file_type_no_ext(self): def test_detect_file_type_incompatible(self, file):
with self.assertRaises(RuntimeError): # tests detect file type for no extension, unknown compression and unknown partial extension
utils._detect_file_type("foo") with pytest.raises(RuntimeError):
utils._detect_file_type(file)
def test_detect_file_type_unknown_compression(self):
with self.assertRaises(RuntimeError): @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
utils._detect_file_type("foo.tar.baz") def test_decompress(self, extension):
def test_detect_file_type_unknown_partial_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar")
def test_decompress_bz2(self):
def create_compressed(root, content="this is the content"): def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file") file = os.path.join(root, "file")
compressed = f"{file}.bz2" compressed = f"{file}{extension}"
compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]
with bz2.open(compressed, "wb") as fh: with compressed_file_opener(compressed, "wb") as fh:
fh.write(content.encode()) fh.write(content.encode())
return compressed, file, content return compressed, file, content
...@@ -97,53 +91,13 @@ class Tester(unittest.TestCase): ...@@ -97,53 +91,13 @@ class Tester(unittest.TestCase):
utils._decompress(compressed) utils._decompress(compressed)
self.assertTrue(os.path.exists(file)) assert os.path.exists(file)
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_decompress_gzip(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
with gzip.open(compressed, "wb") as fh:
fh.write(content.encode())
return compressed, file, content
with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
utils._decompress(compressed)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_decompress_lzma(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.xz"
with lzma.open(compressed, "wb") as fh:
fh.write(content.encode())
return compressed, file, content
with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
utils.extract_archive(compressed, temp_dir)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh: with open(file, "r") as fh:
self.assertEqual(fh.read(), content) assert fh.read() == content
def test_decompress_no_compression(self): def test_decompress_no_compression(self):
with self.assertRaises(RuntimeError): with pytest.raises(RuntimeError):
utils._decompress("foo.tar") utils._decompress("foo.tar")
def test_decompress_remove_finished(self): def test_decompress_remove_finished(self):
...@@ -161,21 +115,18 @@ class Tester(unittest.TestCase): ...@@ -161,21 +115,18 @@ class Tester(unittest.TestCase):
utils.extract_archive(compressed, temp_dir, remove_finished=True) utils.extract_archive(compressed, temp_dir, remove_finished=True)
self.assertFalse(os.path.exists(compressed)) assert not os.path.exists(compressed)
def test_extract_archive_defer_to_decompress(self): @pytest.mark.parametrize('extension', [".gz", ".xz"])
@pytest.mark.parametrize('remove_finished', [True, False])
def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
filename = "foo" filename = "foo"
for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)): file = f"{filename}{extension}"
with self.subTest(ext=ext, remove_finished=remove_finished):
with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock: mocked = mocker.patch("torchvision.datasets.utils._decompress")
file = f"{filename}{ext}"
utils.extract_archive(file, remove_finished=remove_finished) utils.extract_archive(file, remove_finished=remove_finished)
mock.assert_called_once() mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
self.assertEqual(
call_args_to_kwargs_only(mock.call_args, utils._decompress),
dict(from_path=file, to_path=filename, remove_finished=remove_finished),
)
def test_extract_zip(self): def test_extract_zip(self):
def create_archive(root, content="this is the content"): def create_archive(root, content="this is the content"):
...@@ -192,41 +143,18 @@ class Tester(unittest.TestCase): ...@@ -192,41 +143,18 @@ class Tester(unittest.TestCase):
utils.extract_archive(archive, temp_dir) utils.extract_archive(archive, temp_dir)
self.assertTrue(os.path.exists(file)) assert os.path.exists(file)
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_extract_tar(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")
with open(src, "w") as fh:
fh.write(content)
with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))
return archive, dst, content
for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, ext, mode)
utils.extract_archive(archive, temp_dir)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh: with open(file, "r") as fh:
self.assertEqual(fh.read(), content) assert fh.read() == content
def test_extract_tar_xz(self): @pytest.mark.parametrize('extension, mode', [
def create_archive(root, ext, mode, content="this is the content"): ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
def test_extract_tar(self, extension, mode):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt") src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt") dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}") archive = os.path.join(root, f"archive{extension}")
with open(src, "w") as fh: with open(src, "w") as fh:
fh.write(content) fh.write(content)
...@@ -236,22 +164,21 @@ class Tester(unittest.TestCase): ...@@ -236,22 +164,21 @@ class Tester(unittest.TestCase):
return archive, dst, content return archive, dst, content
for ext, mode in zip(['.tar.xz'], ['w:xz']):
with get_tmp_dir() as temp_dir: with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, ext, mode) archive, file, content = create_archive(temp_dir, extension, mode)
utils.extract_archive(archive, temp_dir) utils.extract_archive(archive, temp_dir)
self.assertTrue(os.path.exists(file)) assert os.path.exists(file)
with open(file, "r") as fh: with open(file, "r") as fh:
self.assertEqual(fh.read(), content) assert fh.read() == content
def test_verify_str_arg(self): def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",))) assert "a" == utils.verify_str_arg("a", "arg", ("a",))
self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
self.assertRaises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment