Unverified Commit f8a9957c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Separate extraction and decompression logic in datasets.utils.extract_archive (#3443)



* generalize extract_archive

* [test] re-enable extraction tests on windows

* add tests for detect_file_type

* add error messages to detect_file_type

* Revert "[test] re-enable extraction tests on windows"

This reverts commit 7fafebb0f6b4c49bd72c4b5e0a0b4b8c432bce57.

* add utility functions for better mock call checking

* add tests for decompress

* simplify logic by using pathlib

* lint

* Apply suggestions from code review
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* make decompress private

* remove unnecessary checks

* add error message

* fix mocking

* add remaining tests

* lint
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 4560556d
...@@ -10,6 +10,7 @@ import torch ...@@ -10,6 +10,7 @@ import torch
import warnings import warnings
import __main__ import __main__
import random import random
import inspect
from numbers import Number from numbers import Number
from torch._six import string_classes from torch._six import string_classes
...@@ -401,3 +402,20 @@ def disable_console_output(): ...@@ -401,3 +402,20 @@ def disable_console_output():
stack.enter_context(contextlib.redirect_stdout(devnull)) stack.enter_context(contextlib.redirect_stdout(devnull))
stack.enter_context(contextlib.redirect_stderr(devnull)) stack.enter_context(contextlib.redirect_stderr(devnull))
yield yield
def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
    """Normalize a mock ``call_args`` pair into a single kwargs-only dict.

    The parameter names are taken either from the signature of a callable
    passed as first extra argument, or from the extra arguments themselves
    when they are plain strings. Positional call arguments are zipped onto
    those names and merged with the keyword call arguments.
    """
    spec_source = callable_or_arg_names[0]
    if callable(spec_source):
        names = inspect.getfullargspec(spec_source).args
        if isinstance(spec_source, type):
            # For classes the argspec is taken from __init__; drop 'self'.
            names = names[1:]
    else:
        names = list(callable_or_arg_names)

    positional, keyword = call_args
    merged = dict(keyword)
    merged.update(zip(names, positional))
    return merged
...@@ -8,8 +8,10 @@ import gzip ...@@ -8,8 +8,10 @@ import gzip
import warnings import warnings
from torch._utils_internal import get_file_path_2 from torch._utils_internal import get_file_path_2
from urllib.error import URLError from urllib.error import URLError
import itertools
import lzma
from common_utils import get_tmp_dir from common_utils import get_tmp_dir, call_args_to_kwargs_only
TEST_FILE = get_file_path_2( TEST_FILE = get_file_path_2(
...@@ -100,6 +102,114 @@ class Tester(unittest.TestCase): ...@@ -100,6 +102,114 @@ class Tester(unittest.TestCase):
mock.assert_called_once_with(id, root, filename, md5) mock.assert_called_once_with(id, root, filename, md5)
def test_detect_file_type(self):
    """Known file names split into (suffix, archive type, compression)."""
    cases = {
        "foo.tar.xz": (".tar.xz", ".tar", ".xz"),
        "foo.tar": (".tar", ".tar", None),
        "foo.tar.gz": (".tar.gz", ".tar", ".gz"),
        "foo.tgz": (".tgz", ".tar", ".gz"),
        "foo.gz": (".gz", None, ".gz"),
        "foo.zip": (".zip", ".zip", None),
        "foo.xz": (".xz", None, ".xz"),
    }
    for file, expected in cases.items():
        with self.subTest(file=file):
            self.assertSequenceEqual(utils._detect_file_type(file), expected)
def test_detect_file_type_no_ext(self):
    """A file name without any suffix cannot be classified."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo")
def test_detect_file_type_to_many_exts(self):
    """More than two suffixes is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.bar.tar.gz")
def test_detect_file_type_unknown_archive_type(self):
    """An unrecognized archive suffix before a compression suffix is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.bar.gz")
def test_detect_file_type_unknown_compression(self):
    """An unrecognized compression suffix after an archive suffix is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.tar.baz")
def test_detect_file_type_unknown_partial_ext(self):
    """A single suffix that is neither archive type nor compression is rejected."""
    self.assertRaises(RuntimeError, utils._detect_file_type, "foo.bar")
def test_decompress_gzip(self):
    """_decompress should inflate a .gz file next to the original."""
    payload = "this is the content"
    with get_tmp_dir() as temp_dir:
        decompressed = os.path.join(temp_dir, "file")
        compressed = f"{decompressed}.gz"
        with gzip.open(compressed, "wb") as fh:
            fh.write(payload.encode())

        utils._decompress(compressed)

        self.assertTrue(os.path.exists(decompressed))
        with open(decompressed, "r") as fh:
            self.assertEqual(fh.read(), payload)
def test_decompress_lzma(self):
    """extract_archive should decompress a bare .xz file into the target directory."""
    payload = "this is the content"
    with get_tmp_dir() as temp_dir:
        decompressed = os.path.join(temp_dir, "file")
        compressed = f"{decompressed}.xz"
        with lzma.open(compressed, "wb") as fh:
            fh.write(payload.encode())

        utils.extract_archive(compressed, temp_dir)

        self.assertTrue(os.path.exists(decompressed))
        with open(decompressed, "r") as fh:
            self.assertEqual(fh.read(), payload)
def test_decompress_no_compression(self):
    """_decompress refuses a file whose suffix carries no compression."""
    self.assertRaises(RuntimeError, utils._decompress, "foo.tar")
def test_decompress_remove_finished(self):
    """extract_archive(remove_finished=True) deletes the compressed source file."""
    payload = "this is the content"
    with get_tmp_dir() as temp_dir:
        compressed = os.path.join(temp_dir, "file.gz")
        with gzip.open(compressed, "wb") as fh:
            fh.write(payload.encode())

        utils.extract_archive(compressed, temp_dir, remove_finished=True)

        self.assertFalse(os.path.exists(compressed))
def test_extract_archive_defer_to_decompress(self):
    """For pure compressions (no archive container) extract_archive defers to _decompress."""
    filename = "foo"
    for ext, remove_finished in itertools.product((".gz", ".xz"), (True, False)):
        with self.subTest(ext=ext, remove_finished=remove_finished):
            with unittest.mock.patch("torchvision.datasets.utils._decompress") as mock:
                # BUG FIX: build the file name from ``filename`` (the dump had a
                # corrupted literal here); the expected ``to_path`` below is
                # derived from ``filename``, so the two must agree.
                file = f"{filename}{ext}"
                utils.extract_archive(file, remove_finished=remove_finished)

                mock.assert_called_once()
                self.assertEqual(
                    call_args_to_kwargs_only(mock.call_args, utils._decompress),
                    dict(from_path=file, to_path=filename, remove_finished=remove_finished),
                )
def test_extract_zip(self): def test_extract_zip(self):
def create_archive(root, content="this is the content"): def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt") file = os.path.join(root, "dst.txt")
...@@ -170,26 +280,6 @@ class Tester(unittest.TestCase): ...@@ -170,26 +280,6 @@ class Tester(unittest.TestCase):
with open(file, "r") as fh: with open(file, "r") as fh:
self.assertEqual(fh.read(), content) self.assertEqual(fh.read(), content)
def test_extract_gzip(self):
    """extract_archive should decompress a bare .gz file into the target directory."""
    payload = "this is the content"
    with get_tmp_dir() as temp_dir:
        extracted = os.path.join(temp_dir, "file")
        compressed = f"{extracted}.gz"
        with gzip.GzipFile(compressed, "wb") as fh:
            fh.write(payload.encode())

        utils.extract_archive(compressed, temp_dir)

        self.assertTrue(os.path.exists(extracted))
        with open(extracted, "r") as fh:
            self.assertEqual(fh.read(), payload)
def test_verify_str_arg(self): def test_verify_str_arg(self):
self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",))) self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") self.assertRaises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
......
...@@ -4,12 +4,15 @@ import hashlib ...@@ -4,12 +4,15 @@ import hashlib
import gzip import gzip
import re import re
import tarfile import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
import zipfile import zipfile
import lzma
import contextlib
import urllib import urllib
import urllib.request import urllib.request
import urllib.error import urllib.error
import pathlib
import torch import torch
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
...@@ -242,56 +245,145 @@ def _save_response_content( ...@@ -242,56 +245,145 @@ def _save_response_content(
pbar.close() pbar.close()
def _is_tarxz(filename: str) -> bool: def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
return filename.endswith(".tar.xz") with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)
def _is_tar(filename: str) -> bool: _ZIP_COMPRESSION_MAP: Dict[str, int] = {
return filename.endswith(".tar") ".xz": zipfile.ZIP_LZMA,
}
def _is_targz(filename: str) -> bool: def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
return filename.endswith(".tar.gz") with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)
# Maps an archive suffix to the extractor that handles it; each extractor also
# receives an optional compression suffix.
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
# Maps a compression suffix to a function that opens such a file for binary reading.
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open}
# Suffixes that encode archive type and compression in one extension.
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}
def _verify_archive_type(archive_type: str) -> None:
    """Raise ``RuntimeError`` if ``archive_type`` is not a known archive suffix."""
    # Membership test directly on the dict; .keys() is redundant here.
    if archive_type not in _ARCHIVE_EXTRACTORS:
        valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
        raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")
def _verify_compression(compression: str) -> None:
    """Raise ``RuntimeError`` if ``compression`` is not a known compression suffix."""
    # Membership test directly on the dict; .keys() is redundant here.
    if compression not in _COMPRESSED_FILE_OPENERS:
        valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
        raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Detect the archive type and/or compression of a file from its name.

    Args:
        file (str): The file name.

    Returns:
        (tuple): Tuple of full suffix, archive type, and compression. Archive
            type and compression are ``None`` when not applicable.

    Raises:
        RuntimeError: If the file has no suffix, more than two suffixes, or a
            suffix that is neither a known archive type nor compression.
    """
    path = pathlib.Path(file)
    suffix = path.suffix
    # Reuse the Path object instead of constructing it a second time.
    suffixes = path.suffixes
    if not suffixes:
        raise RuntimeError(
            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
        )
    elif len(suffixes) > 2:
        raise RuntimeError(
            "Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
        )
    elif len(suffixes) == 2:
        # if we have exactly two suffixes we assume the first one is the archive type and the second one is the
        # compression
        archive_type, compression = suffixes
        _verify_archive_type(archive_type)
        _verify_compression(compression)
        return "".join(suffixes), archive_type, compression

    # check if the suffix is a known alias
    with contextlib.suppress(KeyError):
        return (suffix, *_FILE_TYPE_ALIASES[suffix])

    # check if the suffix is an archive type
    with contextlib.suppress(RuntimeError):
        _verify_archive_type(suffix)
        return suffix, suffix, None

    # check if the suffix is a compression
    with contextlib.suppress(RuntimeError):
        _verify_compression(suffix)
        return suffix, None, suffix

    raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    r"""Decompress a file.

    The compression is automatically detected from the file name.

    Args:
        from_path (str): Path to the file to be decompressed.
        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the decompressed file.

    Raises:
        RuntimeError: If no compression can be detected from the suffix.
    """
    suffix, archive_type, compression = _detect_file_type(from_path)
    if not compression:
        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

    if to_path is None:
        # e.g. "foo.tar.gz" -> "foo.tar", "foo.gz" -> "foo"
        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]

    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
        wfh.write(rfh.read())

    if remove_finished:
        os.remove(from_path)

    return to_path
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
    """Extract an archive.

    The archive type and a possible compression is automatically detected from the file name. If the file is
    compressed but not an archive, the call is dispatched to :func:`_decompress`.

    Args:
        from_path (str): Path to the file to be extracted.
        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
            used.
        remove_finished (bool): If ``True``, remove the file after the extraction.

    Returns:
        (str): Path to the directory the file was extracted to.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    suffix, archive_type, compression = _detect_file_type(from_path)
    if not archive_type:
        # Pure compression without an archive container, e.g. "file.gz".
        return _decompress(
            from_path,
            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
            remove_finished=remove_finished,
        )

    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
    extractor = _ARCHIVE_EXTRACTORS[archive_type]

    extractor(from_path, to_path, compression)

    return to_path
def download_and_extract_archive( def download_and_extract_archive(
url: str, url: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment