test_datasets_utils.py 5.78 KB
Newer Older
1
import os
2
import sys
Francisco Massa's avatar
Francisco Massa committed
3
4
5
import tempfile
import torchvision.datasets.utils as utils
import unittest
6
7
8
import zipfile
import tarfile
import gzip
9
10
import warnings
from torch._utils_internal import get_file_path_2
11
from urllib.error import URLError
Francisco Massa's avatar
Francisco Massa committed
12

13
14
15
16
from common_utils import get_tmp_dir


# Image asset used by the checksum tests; located relative to this test file.
_ASSET_PARTS = ('assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
TEST_FILE = get_file_path_2(os.path.dirname(os.path.abspath(__file__)), *_ASSET_PARTS)
18

Francisco Massa's avatar
Francisco Massa committed
19
20
21

class Tester(unittest.TestCase):

22
23
24
25
    def test_check_md5(self):
        """check_md5 is true for the asset's expected digest and false for a wrong one."""
        expected_digest = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        self.assertTrue(utils.check_md5(TEST_FILE, expected_digest))
        # The empty string serves as a deliberately wrong digest.
        wrong_digest = ''
        self.assertFalse(utils.check_md5(TEST_FILE, wrong_digest))
28
29
30
31
32
33

    def test_check_integrity(self):
        """Exercise check_integrity with a right/wrong md5, without md5, and with a missing path."""
        good_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        bad_md5 = ''
        missing_path = ''

        # Existing file, checked against a matching and a non-matching digest.
        self.assertTrue(utils.check_integrity(TEST_FILE, good_md5))
        self.assertFalse(utils.check_integrity(TEST_FILE, bad_md5))

        # No digest supplied: the existing file passes, the empty path does not.
        self.assertTrue(utils.check_integrity(TEST_FILE))
        self.assertFalse(utils.check_integrity(missing_path))
38

Francisco Massa's avatar
Francisco Massa committed
39
    def test_download_url(self):
        """Download an archive via an http:// URL; skip if the download raises URLError."""
        url = "http://github.com/pytorch/vision/archive/master.zip"
        with get_tmp_dir() as temp_dir:
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                # A failed download is reported as a warning and a skip,
                # not as a test failure.
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)
            else:
                # Something must have been written into the target directory.
                downloaded = os.listdir(temp_dir)
                self.assertTrue(len(downloaded) != 0)
Francisco Massa's avatar
Francisco Massa committed
49
50

    def test_download_url_retry_http(self):
        """Download an archive via an https:// URL; skip if the download raises URLError.

        NOTE(review): judging by the test name, this presumably covers an
        https -> http retry inside ``utils.download_url`` -- confirm there.
        """
        url = "https://github.com/pytorch/vision/archive/master.zip"
        with get_tmp_dir() as temp_dir:
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                # A failed download is reported as a warning and a skip,
                # not as a test failure.
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)
            else:
                # Something must have been written into the target directory.
                downloaded = os.listdir(temp_dir)
                self.assertTrue(len(downloaded) != 0)
Francisco Massa's avatar
Francisco Massa committed
60

61
62
63
64
65
66
    def test_download_url_dont_exist(self):
        """Requesting an archive that is not there must raise URLError."""
        missing_url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
        with get_tmp_dir() as temp_dir:
            self.assertRaises(URLError, utils.download_url, missing_url, temp_dir)

Francisco Massa's avatar
Francisco Massa committed
67
    # Exact platform match: the former substring test `'win' in sys.platform`
    # is also true for 'darwin' and therefore skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_zip(self):
        """Write a one-member zip archive, extract it, and check the member's content."""
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.zip') as f:
                with zipfile.ZipFile(f, 'w') as zf:
                    zf.writestr('file.tst', 'this is the content')
                utils.extract_archive(f.name, temp_dir)

                extracted = os.path.join(temp_dir, 'file.tst')
                self.assertTrue(os.path.exists(extracted))
                with open(extracted, 'r') as nf:
                    data = nf.read()
                self.assertEqual(data, 'this is the content')
78

Francisco Massa's avatar
Francisco Massa committed
79
    # Exact platform match: the former substring test `'win' in sys.platform`
    # is also true for 'darwin' and therefore skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_tar(self):
        """Build plain and gzip-compressed tar archives, extract each, and check the content."""
        # (archive suffix, tarfile write mode)
        cases = (('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'))
        for ext, mode in cases:
            with get_tmp_dir() as temp_dir:
                with tempfile.NamedTemporaryFile() as bf:
                    bf.write("this is the content".encode())
                    # Rewinding flushes the buffered write, so tarfile (which
                    # reads the payload by file name) sees the data on disk.
                    bf.seek(0)
                    with tempfile.NamedTemporaryFile(suffix=ext) as f:
                        with tarfile.open(f.name, mode=mode) as zf:
                            zf.add(bf.name, arcname='file.tst')
                        utils.extract_archive(f.name, temp_dir)

                        extracted = os.path.join(temp_dir, 'file.tst')
                        self.assertTrue(os.path.exists(extracted))
                        with open(extracted, 'r') as nf:
                            data = nf.read()
                        self.assertEqual(data, 'this is the content')
Ardalan's avatar
Ardalan committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

    # Exact platform match: the former substring test `'win' in sys.platform`
    # is also true for 'darwin' and therefore skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_tar_xz(self):
        """Build an xz-compressed tar archive, extract it, and check the member's content."""
        ext, mode = '.tar.xz', 'w:xz'
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write("this is the content".encode())
                # Rewinding flushes the buffered write, so tarfile (which
                # reads the payload by file name) sees the data on disk.
                bf.seek(0)
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as zf:
                        zf.add(bf.name, arcname='file.tst')
                    utils.extract_archive(f.name, temp_dir)

                    extracted = os.path.join(temp_dir, 'file.tst')
                    self.assertTrue(os.path.exists(extracted))
                    with open(extracted, 'r') as nf:
                        data = nf.read()
                    self.assertEqual(data, 'this is the content')
110

Francisco Massa's avatar
Francisco Massa committed
111
    # Exact platform match: the former substring test `'win' in sys.platform`
    # is also true for 'darwin' and therefore skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_gzip(self):
        """Gzip a short payload, extract it, and check the decompressed content."""
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.gz') as f:
                with gzip.GzipFile(f.name, 'wb') as zf:
                    zf.write('this is the content'.encode())
                utils.extract_archive(f.name, temp_dir)

                # The test expects the output in temp_dir under the archive's
                # base name with the '.gz' suffix removed.
                f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
                self.assertTrue(os.path.exists(f_name))
                with open(f_name, 'r') as nf:
                    data = nf.read()
                self.assertEqual(data, 'this is the content')
123

124
125
126
127
128
    def test_verify_str_arg(self):
        """verify_str_arg returns an accepted value and raises ValueError otherwise."""
        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
        # Use the same (value, arg, valid_values) order as the call above.
        # With the last two swapped, "b" was checked for membership in the
        # string "arg" rather than in the tuple of valid values.
        self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
        self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))

Francisco Massa's avatar
Francisco Massa committed
129
130
131

# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()