# test_datasets_utils.py: tests for torchvision.datasets.utils
import os
import shutil
import tempfile
import torchvision.datasets.utils as utils
import unittest

# Absolute path of the sample image fixture used by the md5/integrity tests.
# It lives in the 'assets' directory next to this test module.
_ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets')
TEST_FILE = os.path.join(_ASSETS_DIR, 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):
    """Tests for the helpers in ``torchvision.datasets.utils``."""

    # MD5 hex digest expected for TEST_FILE.
    _CORRECT_MD5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
    # An empty string cannot match a real digest, so checks against it must fail.
    _FALSE_MD5 = ''

    def _assert_download_populates_dir(self, url):
        """Download ``url`` into a fresh temp dir and assert the dir is not empty.

        The temp dir is registered with ``addCleanup`` so it is removed even
        when ``download_url`` raises or the assertion fails (the previous
        trailing ``shutil.rmtree`` call only ran on success, leaking the dir).
        """
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir)
        utils.download_url(url, temp_dir)
        self.assertNotEqual(len(os.listdir(temp_dir)), 0,
                            msg='The downloaded root directory is empty after download.')

    def test_check_md5(self):
        fpath = TEST_FILE
        self.assertTrue(utils.check_md5(fpath, self._CORRECT_MD5))
        self.assertFalse(utils.check_md5(fpath, self._FALSE_MD5))

    def test_check_integrity(self):
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''
        self.assertTrue(utils.check_integrity(existing_fpath, self._CORRECT_MD5))
        self.assertFalse(utils.check_integrity(existing_fpath, self._FALSE_MD5))
        # With no md5 supplied: the existing file passes, an empty path does not.
        self.assertTrue(utils.check_integrity(existing_fpath))
        self.assertFalse(utils.check_integrity(nonexisting_fpath))

    def test_download_url(self):
        self._assert_download_populates_dir(
            "http://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_retry_http(self):
        # NOTE(review): the name suggests this exercises download_url's
        # https -> http retry path; the retry logic itself is not visible here.
        self._assert_download_populates_dir(
            "https://github.com/pytorch/vision/archive/master.zip")


# Allow running this test module directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()