Unverified Commit 1de7a74a authored by ahmadsharif1, committed by GitHub

Added pathlib support to datasets/utils.py (#8200)

parent a00a72b1
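
The commit widens the path-typed parameters in torchvision.datasets.utils to accept either str or pathlib.Path, and extract_archive now mirrors the caller's input type in its return value. A minimal usage sketch of the new contract; the archive paths below are hypothetical and must exist for the calls to succeed:

import pathlib
from torchvision.datasets.utils import extract_archive

# str in -> str out (the pre-existing behavior is preserved)
out = extract_archive("/tmp/archive.tar.gz", "/tmp/extracted")
assert isinstance(out, str)

# pathlib.Path in -> pathlib.Path out (new in this commit)
out = extract_archive(pathlib.Path("/tmp/archive.tar.gz"), pathlib.Path("/tmp/extracted"))
assert isinstance(out, pathlib.Path)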
--- a/test/test_datasets_utils.py
+++ b/test/test_datasets_utils.py
@@ -58,8 +58,11 @@ class TestDatasetsUtils:
         assert mock.call_count == 1
         assert mock.call_args[0][0].full_url == url

-    def test_check_md5(self):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_check_md5(self, use_pathlib):
         fpath = TEST_FILE
+        if use_pathlib:
+            fpath = pathlib.Path(fpath)
         correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
         false_md5 = ""
         assert utils.check_md5(fpath, correct_md5)
@@ -116,7 +119,8 @@ class TestDatasetsUtils:
             utils._detect_file_type(file)

     @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
-    def test_decompress(self, extension, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress(self, extension, tmpdir, use_pathlib):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ class TestDatasetsUtils:
             return compressed, file, content

         compressed, file, content = create_compressed(tmpdir)
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)

         utils._decompress(compressed)
@@ -140,7 +146,8 @@ class TestDatasetsUtils:
         with pytest.raises(RuntimeError):
             utils._decompress("foo.tar")

-    def test_decompress_remove_finished(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_decompress_remove_finished(self, tmpdir, use_pathlib):
         def create_compressed(root, content="this is the content"):
             file = os.path.join(root, "file")
             compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ class TestDatasetsUtils:
             return compressed, file, content

         compressed, file, content = create_compressed(tmpdir)
+        print(f"{type(compressed)=}")
+        if use_pathlib:
+            compressed = pathlib.Path(compressed)
+            tmpdir = pathlib.Path(tmpdir)

-        utils.extract_archive(compressed, tmpdir, remove_finished=True)
+        extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)
         assert not os.path.exists(compressed)
+        if use_pathlib:
+            assert isinstance(extracted_dir, pathlib.Path)
+            assert isinstance(compressed, pathlib.Path)
+        else:
+            assert isinstance(extracted_dir, str)
+            assert isinstance(compressed, str)

     @pytest.mark.parametrize("extension", [".gz", ".xz"])
     @pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ class TestDatasetsUtils:
             mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

-    def test_extract_zip(self, tmpdir):
+    @pytest.mark.parametrize("use_pathlib", (True, False))
+    def test_extract_zip(self, tmpdir, use_pathlib):
         def create_archive(root, content="this is the content"):
             file = os.path.join(root, "dst.txt")
             archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ class TestDatasetsUtils:
             return archive, file, content

+        if use_pathlib:
+            tmpdir = pathlib.Path(tmpdir)
         archive, file, content = create_archive(tmpdir)
         utils.extract_archive(archive, tmpdir)
...@@ -189,7 +209,8 @@ class TestDatasetsUtils: ...@@ -189,7 +209,8 @@ class TestDatasetsUtils:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")] "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
) )
def test_extract_tar(self, extension, mode, tmpdir): @pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
def create_archive(root, extension, mode, content="this is the content"): def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt") src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt") dst = os.path.join(root, "dst.txt")
...@@ -203,6 +224,8 @@ class TestDatasetsUtils: ...@@ -203,6 +224,8 @@ class TestDatasetsUtils:
return archive, dst, content return archive, dst, content
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir, extension, mode) archive, file, content = create_archive(tmpdir, extension, mode)
utils.extract_archive(archive, tmpdir) utils.extract_archive(archive, tmpdir)
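
The test-side changes above all follow one parametrization idiom: each affected test gains a use_pathlib flag and converts its input paths with pathlib.Path when the flag is set. A condensed, self-contained sketch of the pattern; the test name and body are hypothetical:

import pathlib

import pytest

@pytest.mark.parametrize("use_pathlib", (True, False))
def test_path_type_roundtrip(tmpdir, use_pathlib):
    # Run identical assertions once with a str path and once with a pathlib.Path.
    root = pathlib.Path(tmpdir) if use_pathlib else str(tmpdir)
    assert isinstance(root, (str, pathlib.Path))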
--- a/torchvision/datasets/utils.py
+++ b/torchvision/datasets/utils.py
@@ -30,7 +30,7 @@ USER_AGENT = "pytorch/vision"

 def _save_response_content(
     content: Iterator[bytes],
-    destination: str,
+    destination: Union[str, pathlib.Path],
     length: Optional[int] = None,
 ) -> None:
     with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@ def _save_response_content(
             pbar.update(len(chunk))


-def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
     with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
         _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)


-def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
+def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
     # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
     # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
     # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
     return md5.hexdigest()


-def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
+def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
     return md5 == calculate_md5(fpath, **kwargs)


-def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
+def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
     if not os.path.isfile(fpath):
         return False
     if md5 is None:
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
 def download_url(
     url: str,
     root: Union[str, pathlib.Path],
-    filename: Optional[str] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
     md5: Optional[str] = None,
     max_redirect_hops: int = 3,
 ) -> None:
@@ -159,7 +159,7 @@ def download_url(
     raise RuntimeError("File not found or corrupted.")


-def list_dir(root: str, prefix: bool = False) -> List[str]:
+def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
     """List all directories at a given root

     Args:
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
     return directories


-def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
+def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
     """List all files ending with a suffix at a given root

     Args:
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple

 def download_file_from_google_drive(
-    file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
+    file_id: str,
+    root: Union[str, pathlib.Path],
+    filename: Optional[Union[str, pathlib.Path]] = None,
+    md5: Optional[str] = None,
 ):
     """Download a Google Drive file and place it in root.
@@ -278,7 +281,9 @@ def download_file_from_google_drive(
     )


-def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_tar(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
     with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
         tar.extractall(to_path)
@@ -289,14 +294,16 @@ _ZIP_COMPRESSION_MAP: Dict[str, int] = {
 }


-def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+def _extract_zip(
+    from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
+) -> None:
     with zipfile.ZipFile(
         from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
     ) as zip:
         zip.extractall(to_path)


-_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
     ".tar": _extract_tar,
     ".zip": _extract_zip,
 }
@@ -312,7 +319,7 @@ _FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
 }


-def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
     """Detect the archive type and/or compression of a file.

     Args:
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
     raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


-def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def _decompress(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> pathlib.Path:
     r"""Decompress a file.

     The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
         raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")

     if to_path is None:
-        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
+        to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))

     # We don't need to check for a missing key here, since this was already done in _detect_file_type()
     compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
     if remove_finished:
         os.remove(from_path)

-    return to_path
+    return pathlib.Path(to_path)


-def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+def extract_archive(
+    from_path: Union[str, pathlib.Path],
+    to_path: Optional[Union[str, pathlib.Path]] = None,
+    remove_finished: bool = False,
+) -> Union[str, pathlib.Path]:
     """Extract an archive.

     The archive type and a possible compression are automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
     Returns:
         (str): Path to the directory the file was extracted to.
     """
+
+    def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
+        if isinstance(from_path, str):
+            return os.fspath(ret_path)
+        else:
+            return ret_path
+
     if to_path is None:
         to_path = os.path.dirname(from_path)

     suffix, archive_type, compression = _detect_file_type(from_path)
     if not archive_type:
-        return _decompress(
+        ret_path = _decompress(
             from_path,
             os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
             remove_finished=remove_finished,
         )
+        return path_or_str(ret_path)

     # We don't need to check for a missing key here, since this was already done in _detect_file_type()
     extractor = _ARCHIVE_EXTRACTORS[archive_type]
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
     if remove_finished:
         os.remove(from_path)

-    return to_path
+    return path_or_str(pathlib.Path(to_path))


 def download_and_extract_archive(
     url: str,
-    download_root: str,
-    extract_root: Optional[str] = None,
-    filename: Optional[str] = None,
+    download_root: Union[str, pathlib.Path],
+    extract_root: Optional[Union[str, pathlib.Path]] = None,
+    filename: Optional[Union[str, pathlib.Path]] = None,
     md5: Optional[str] = None,
     remove_finished: bool = False,
 ) -> None:
@@ -479,7 +502,7 @@ def verify_str_arg(
     return value


-def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
+def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
     """Read file in .pfm format. Might contain either 1 or 3 channels of data.

     Args:
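
A closing note on the mechanism: both the path_or_str helper and the _decompress change lean on os.fspath, which converts any os.PathLike object to its str form and returns str input unchanged, so no explicit type branching is needed when normalizing a path. A quick illustration:

import os
import pathlib

p = pathlib.Path("data") / "file.tar.gz"
print(os.fspath(p))       # 'data/file.tar.gz' (str; separator is platform-specific)

s = "data/file.tar.gz"
print(os.fspath(s) is s)  # True: str inputs pass through untouched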