Unverified Commit eb1b9827 authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Add bzip2 file compression support to datasets (#4097)

parent 183a7221
import bz2
import os
import torchvision.datasets.utils as utils
import unittest
......@@ -51,10 +52,14 @@ class Tester(unittest.TestCase):
def test_detect_file_type(self):
for file, expected in [
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)),
("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.tbz", (".tbz", ".tar", ".bz2")),
("foo.tbz2", (".tbz2", ".tar", ".bz2")),
("foo.tgz", (".tgz", ".tar", ".gz")),
("foo.bz2", (".bz2", None, ".bz2")),
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
......@@ -82,6 +87,26 @@ class Tester(unittest.TestCase):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar")
def test_decompress_bz2(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.bz2"
with bz2.open(compressed, "wb") as fh:
fh.write(content.encode())
return compressed, file, content
with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
utils._decompress(compressed)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_decompress_gzip(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
......
import bz2
import os
import os.path
import hashlib
......@@ -262,6 +263,7 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> No
_ZIP_COMPRESSION_MAP: Dict[str, int] = {
".bz2": zipfile.ZIP_BZIP2,
".xz": zipfile.ZIP_LZMA,
}
......@@ -277,8 +279,16 @@ _ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
".bz2": bz2.open,
".gz": gzip.open,
".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
".tbz": (".tar", ".bz2"),
".tbz2": (".tar", ".bz2"),
".tgz": (".tar", ".gz"),
}
def _verify_archive_type(archive_type: str) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment