Unverified Commit 1a46ec94 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix test_extract_(zip|tar|tar_xz|gzip) on windows (#3542)



* fix test_extract_(zip|tar|tar_xz|gzip) on windows

* lint
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent e38b9ee4
import os
import sys
import tempfile
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
......@@ -102,62 +100,95 @@ class Tester(unittest.TestCase):
mock.assert_called_once_with(id, root, filename, md5)
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_extract_zip(self):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")
with zipfile.ZipFile(archive, "w") as zf:
zf.writestr(os.path.basename(file), content)
return archive, file, content
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.zip') as f:
with zipfile.ZipFile(f, 'w') as zf:
zf.writestr('file.tst', 'this is the content')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
archive, file, content = create_archive(temp_dir)
utils.extract_archive(archive, temp_dir)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_extract_tar(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")
with open(src, "w") as fh:
fh.write(content)
with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))
return archive, dst, content
for ext, mode in zip(['.tar', '.tar.gz', '.tgz'], ['w', 'w:gz', 'w:gz']):
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode())
bf.seek(0)
with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
archive, file, content = create_archive(temp_dir, ext, mode)
utils.extract_archive(archive, temp_dir)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_extract_tar_xz(self):
def create_archive(root, ext, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{ext}")
with open(src, "w") as fh:
fh.write(content)
with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))
return archive, dst, content
for ext, mode in zip(['.tar.xz'], ['w:xz']):
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile() as bf:
bf.write("this is the content".encode())
bf.seek(0)
with tempfile.NamedTemporaryFile(suffix=ext) as f:
with tarfile.open(f.name, mode=mode) as zf:
zf.add(bf.name, arcname='file.tst')
utils.extract_archive(f.name, temp_dir)
self.assertTrue(os.path.exists(os.path.join(temp_dir, 'file.tst')))
with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
archive, file, content = create_archive(temp_dir, ext, mode)
utils.extract_archive(archive, temp_dir)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_extract_gzip(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
with gzip.GzipFile(compressed, "wb") as fh:
fh.write(content.encode())
return compressed, file, content
with get_tmp_dir() as temp_dir:
with tempfile.NamedTemporaryFile(suffix='.gz') as f:
with gzip.GzipFile(f.name, 'wb') as zf:
zf.write('this is the content'.encode())
utils.extract_archive(f.name, temp_dir)
f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
self.assertTrue(os.path.exists(f_name))
with open(os.path.join(f_name), 'r') as nf:
data = nf.read()
self.assertEqual(data, 'this is the content')
compressed, file, content = create_compressed(temp_dir)
utils.extract_archive(compressed, temp_dir)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment