"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "58d2b10a2e9cd32dd9765dc50aca98690f516287"
Unverified Commit eb1b9827 authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Add bzip2 file compression support to datasets (#4097)

parent 183a7221
import bz2
import os import os
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
import unittest import unittest
...@@ -51,10 +52,14 @@ class Tester(unittest.TestCase): ...@@ -51,10 +52,14 @@ class Tester(unittest.TestCase):
def test_detect_file_type(self): def test_detect_file_type(self):
for file, expected in [ for file, expected in [
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")), ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)), ("foo.tar", (".tar", ".tar", None)),
("foo.tar.gz", (".tar.gz", ".tar", ".gz")), ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.tbz", (".tbz", ".tar", ".bz2")),
("foo.tbz2", (".tbz2", ".tar", ".bz2")),
("foo.tgz", (".tgz", ".tar", ".gz")), ("foo.tgz", (".tgz", ".tar", ".gz")),
("foo.bz2", (".bz2", None, ".bz2")),
("foo.gz", (".gz", None, ".gz")), ("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)), ("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")), ("foo.xz", (".xz", None, ".xz")),
...@@ -82,6 +87,26 @@ class Tester(unittest.TestCase): ...@@ -82,6 +87,26 @@ class Tester(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar") utils._detect_file_type("foo.bar")
def test_decompress_bz2(self):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.bz2"
with bz2.open(compressed, "wb") as fh:
fh.write(content.encode())
return compressed, file, content
with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
utils._decompress(compressed)
self.assertTrue(os.path.exists(file))
with open(file, "r") as fh:
self.assertEqual(fh.read(), content)
def test_decompress_gzip(self): def test_decompress_gzip(self):
def create_compressed(root, content="this is the content"): def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file") file = os.path.join(root, "file")
......
import bz2
import os import os
import os.path import os.path
import hashlib import hashlib
...@@ -262,6 +263,7 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> No ...@@ -262,6 +263,7 @@ def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> No
_ZIP_COMPRESSION_MAP: Dict[str, int] = { _ZIP_COMPRESSION_MAP: Dict[str, int] = {
".bz2": zipfile.ZIP_BZIP2,
".xz": zipfile.ZIP_LZMA, ".xz": zipfile.ZIP_LZMA,
} }
...@@ -277,8 +279,16 @@ _ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { ...@@ -277,8 +279,16 @@ _ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
".tar": _extract_tar, ".tar": _extract_tar,
".zip": _extract_zip, ".zip": _extract_zip,
} }
_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {".gz": gzip.open, ".xz": lzma.open} _COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")} ".bz2": bz2.open,
".gz": gzip.open,
".xz": lzma.open,
}
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
".tbz": (".tar", ".bz2"),
".tbz2": (".tar", ".bz2"),
".tgz": (".tar", ".gz"),
}
def _verify_archive_type(archive_type: str) -> None: def _verify_archive_type(archive_type: str) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment