import gzip
import os
import sys
import tarfile
import tempfile
import unittest
import unittest.mock
import warnings
import zipfile
from urllib.error import URLError

import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2

from common_utils import get_tmp_dir


TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
# MD5 digest of TEST_FILE, shared by the checksum tests.
TEST_FILE_MD5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'

# Payload stored in (and expected back from) every archive the extraction tests build.
ARCHIVE_CONTENT = 'this is the content'

# Compare by prefix: a substring test for 'win' would also match 'darwin' and skip the
# archive tests on macOS as well. The Windows skip is presumably needed because these tests
# re-open NamedTemporaryFile objects by name while they are still open — TODO confirm.
IS_WINDOWS = sys.platform.startswith('win')


class Tester(unittest.TestCase):
    """Tests for the download, integrity and extraction helpers in ``torchvision.datasets.utils``."""

    # -- helpers (not collected as tests) -----------------------------------------------------

    def _download_or_skip(self, url):
        """Download ``url`` into a temporary directory and assert that something was saved.

        A ``URLError`` (e.g. no network access) is reported as a warning and the calling test
        is skipped instead of failed.
        """
        with get_tmp_dir() as temp_dir:
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)
            self.assertNotEqual(len(os.listdir(temp_dir)), 0)

    def _assert_file_content(self, path, expected=ARCHIVE_CONTENT):
        """Assert that ``path`` exists and that its text content equals ``expected``."""
        self.assertTrue(os.path.exists(path))
        with open(path, 'r') as nf:
            data = nf.read()
        self.assertEqual(data, expected)

    def _check_tar_roundtrip(self, ext, mode):
        """Pack one file into a tar archive and check ``extract_archive`` restores it.

        Args:
            ext: file suffix of the archive (e.g. ``'.tgz'``), which ``extract_archive`` sees.
            mode: ``tarfile.open`` write mode matching ``ext`` (e.g. ``'w:gz'``).
        """
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write(ARCHIVE_CONTENT.encode())
                # tarfile re-opens ``bf`` by name, so buffered bytes must reach the file first.
                bf.flush()
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as tf:
                        tf.add(bf.name, arcname='file.tst')
                    # Extract only after the tar file is closed, so the archive is complete on disk.
                    utils.extract_archive(f.name, temp_dir)
                    self._assert_file_content(os.path.join(temp_dir, 'file.tst'))

    # -- integrity ----------------------------------------------------------------------------

    def test_check_md5(self):
        self.assertTrue(utils.check_md5(TEST_FILE, TEST_FILE_MD5))
        self.assertFalse(utils.check_md5(TEST_FILE, ''))

    def test_check_integrity(self):
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''
        false_md5 = ''
        self.assertTrue(utils.check_integrity(existing_fpath, TEST_FILE_MD5))
        self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
        # With no md5 given, the result tracks whether the file exists.
        self.assertTrue(utils.check_integrity(existing_fpath))
        self.assertFalse(utils.check_integrity(nonexisting_fpath))

    # -- URL handling -------------------------------------------------------------------------

    def test_get_redirect_url(self):
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"

        actual = utils._get_redirect_url(url)
        self.assertEqual(actual, expected)

    def test_get_redirect_url_max_hops_exceeded(self):
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        with self.assertRaises(RecursionError):
            utils._get_redirect_url(url, max_hops=0)

    def test_get_google_drive_file_id(self):
        url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
        expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"

        actual = utils._get_google_drive_file_id(url)
        self.assertEqual(actual, expected)

    def test_get_google_drive_file_id_invalid_url(self):
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        self.assertIsNone(utils._get_google_drive_file_id(url))

    # -- downloading --------------------------------------------------------------------------

    def test_download_url(self):
        self._download_or_skip("http://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_retry_http(self):
        # Same check as test_download_url, but starting from an https URL.
        self._download_or_skip("https://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_dont_exist(self):
        with get_tmp_dir() as temp_dir:
            url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
            with self.assertRaises(URLError):
                utils.download_url(url, temp_dir)

    @unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive")
    def test_download_url_dispatch_download_from_google_drive(self, mocked_download):
        url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
        file_id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
        filename = "filename"
        md5 = "md5"

        with get_tmp_dir() as root:
            utils.download_url(url, root, filename, md5)

        # A Google Drive URL must be routed to the Drive downloader with the extracted file id.
        mocked_download.assert_called_once_with(file_id, root, filename, md5)

    # -- extraction ---------------------------------------------------------------------------

    @unittest.skipIf(IS_WINDOWS, 'temporarily disabled on Windows')
    def test_extract_zip(self):
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.zip') as f:
                with zipfile.ZipFile(f, 'w') as zf:
                    zf.writestr('file.tst', ARCHIVE_CONTENT)
                utils.extract_archive(f.name, temp_dir)
                self._assert_file_content(os.path.join(temp_dir, 'file.tst'))

    @unittest.skipIf(IS_WINDOWS, 'temporarily disabled on Windows')
    def test_extract_tar(self):
        for ext, mode in [('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz')]:
            with self.subTest(ext=ext):
                self._check_tar_roundtrip(ext, mode)

    @unittest.skipIf(IS_WINDOWS, 'temporarily disabled on Windows')
    def test_extract_tar_xz(self):
        self._check_tar_roundtrip('.tar.xz', 'w:xz')

    @unittest.skipIf(IS_WINDOWS, 'temporarily disabled on Windows')
    def test_extract_gzip(self):
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.gz') as f:
                with gzip.GzipFile(f.name, 'wb') as zf:
                    zf.write(ARCHIVE_CONTENT.encode())
                utils.extract_archive(f.name, temp_dir)
                # The output is expected under the archive's base name with '.gz' stripped.
                f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
                self._assert_file_content(f_name)

    # -- argument validation ------------------------------------------------------------------

    def test_verify_str_arg(self):
        # Positional order is (value, arg name, valid values), as in the first assertion.
        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
        # A non-string value is rejected.
        self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
        # A string outside the valid values is rejected.
        self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))


if __name__ == '__main__':
    unittest.main()