Unverified Commit 1de7a74a authored by ahmadsharif1, committed by GitHub

Added pathlib support to datasets/utils.py (#8200)

parent a00a72b1
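
As a quick illustration of what this change enables (a minimal sketch, not part of the commit; file names and content are borrowed from the tests below), the helpers in torchvision.datasets.utils now accept pathlib.Path wherever they previously required str:

import pathlib
import tempfile
import zipfile

from torchvision.datasets import utils

root = pathlib.Path(tempfile.mkdtemp())

# Create a small zip archive to exercise the helpers.
member = root / "dst.txt"
member.write_text("this is the content")
archive = root / "archive.zip"
with zipfile.ZipFile(archive, "w") as zf:
    zf.write(member, arcname=member.name)

# calculate_md5 and check_integrity accept a pathlib.Path (or str, as before).
assert utils.check_integrity(archive, md5=utils.calculate_md5(archive))

# extract_archive mirrors the input type: Path in, Path out.
extracted = utils.extract_archive(archive, root / "out")
assert isinstance(extracted, pathlib.Path)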
@@ -58,8 +58,11 @@ class TestDatasetsUtils:
assert mock.call_count == 1
assert mock.call_args[0][0].full_url == url
def test_check_md5(self):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_check_md5(self, use_pathlib):
fpath = TEST_FILE
if use_pathlib:
fpath = pathlib.Path(fpath)
correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
false_md5 = ""
assert utils.check_md5(fpath, correct_md5)
@@ -116,7 +119,8 @@ class TestDatasetsUtils:
utils._detect_file_type(file)
@pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
def test_decompress(self, extension, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_decompress(self, extension, tmpdir, use_pathlib):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}{extension}"
@@ -128,6 +132,8 @@ class TestDatasetsUtils:
return compressed, file, content
compressed, file, content = create_compressed(tmpdir)
if use_pathlib:
compressed = pathlib.Path(compressed)
utils._decompress(compressed)
@@ -140,7 +146,8 @@ class TestDatasetsUtils:
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")
def test_decompress_remove_finished(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_decompress_remove_finished(self, tmpdir, use_pathlib):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
@@ -151,10 +158,20 @@ class TestDatasetsUtils:
return compressed, file, content
compressed, file, content = create_compressed(tmpdir)
print(f"{type(compressed)=}")
if use_pathlib:
compressed = pathlib.Path(compressed)
tmpdir = pathlib.Path(tmpdir)
utils.extract_archive(compressed, tmpdir, remove_finished=True)
extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)
assert not os.path.exists(compressed)
if use_pathlib:
assert isinstance(extracted_dir, pathlib.Path)
assert isinstance(compressed, pathlib.Path)
else:
assert isinstance(extracted_dir, str)
assert isinstance(compressed, str)
@pytest.mark.parametrize("extension", [".gz", ".xz"])
@pytest.mark.parametrize("remove_finished", [True, False])
@@ -167,7 +184,8 @@ class TestDatasetsUtils:
mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
def test_extract_zip(self, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_zip(self, tmpdir, use_pathlib):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")
@@ -177,6 +195,8 @@ class TestDatasetsUtils:
return archive, file, content
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir)
utils.extract_archive(archive, tmpdir)
@@ -189,7 +209,8 @@ class TestDatasetsUtils:
@pytest.mark.parametrize(
"extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
)
def test_extract_tar(self, extension, mode, tmpdir):
@pytest.mark.parametrize("use_pathlib", (True, False))
def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
@@ -203,6 +224,8 @@ class TestDatasetsUtils:
return archive, dst, content
if use_pathlib:
tmpdir = pathlib.Path(tmpdir)
archive, file, content = create_archive(tmpdir, extension, mode)
utils.extract_archive(archive, tmpdir)
......
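
The return-type assertions in the tests above correspond to the implementation changes below. As a hedged sketch of the new contract (file layout borrowed from create_compressed above): the private _decompress helper now always returns a pathlib.Path, even for a str input, while the public extract_archive mirrors the caller's type.

import gzip
import pathlib
import tempfile

from torchvision.datasets import utils

root = pathlib.Path(tempfile.mkdtemp())
compressed = root / "file.gz"
with gzip.open(compressed, "wb") as fh:
    fh.write(b"this is the content")

# _decompress derives its default to_path by stripping the compression
# suffix and, per this diff, returns a pathlib.Path even for str input.
out = utils._decompress(str(compressed), remove_finished=True)
assert isinstance(out, pathlib.Path)
assert out == root / "file"
assert not compressed.exists()  # removed because remove_finished=True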
@@ -30,7 +30,7 @@ USER_AGENT = "pytorch/vision"
def _save_response_content(
content: Iterator[bytes],
destination: str,
destination: Union[str, pathlib.Path],
length: Optional[int] = None,
) -> None:
with open(destination, "wb") as fh, tqdm(total=length) as pbar:
@@ -43,12 +43,12 @@ def _save_response_content(
pbar.update(len(chunk))
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
def calculate_md5(fpath: Union[str, pathlib.Path], chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
# it torchvision.datasets is unusable in these environments since we perform an MD5 check everywhere.
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
return md5.hexdigest()
def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
return md5 == calculate_md5(fpath, **kwargs)
def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
def check_integrity(fpath: Union[str, pathlib.Path], md5: Optional[str] = None) -> bool:
if not os.path.isfile(fpath):
return False
if md5 is None:
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
def download_url(
url: str,
root: Union[str, pathlib.Path],
filename: Optional[str] = None,
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
max_redirect_hops: int = 3,
) -> None:
@@ -159,7 +159,7 @@ download_url(
raise RuntimeError("File not found or corrupted.")
def list_dir(root: str, prefix: bool = False) -> List[str]:
def list_dir(root: Union[str, pathlib.Path], prefix: bool = False) -> List[str]:
"""List all directories at a given root
Args:
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
return directories
def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False) -> List[str]:
"""List all files ending with a suffix at a given root
Args:
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
def download_file_from_google_drive(
file_id: str, root: Union[str, pathlib.Path], filename: Optional[str] = None, md5: Optional[str] = None
file_id: str,
root: Union[str, pathlib.Path],
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
):
"""Download a Google Drive file from and place it in root.
@@ -278,7 +281,9 @@ def download_file_from_google_drive(
)
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
def _extract_tar(
from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
tar.extractall(to_path)
@@ -289,14 +294,16 @@ _ZIP_COMPRESSION_MAP: Dict[str, int] = {
}
def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
def _extract_zip(
from_path: Union[str, pathlib.Path], to_path: Union[str, pathlib.Path], compression: Optional[str]
) -> None:
with zipfile.ZipFile(
from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
) as zip:
zip.extractall(to_path)
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
".tar": _extract_tar,
".zip": _extract_zip,
}
@@ -312,7 +319,7 @@ _FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
}
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
def _detect_file_type(file: Union[str, pathlib.Path]) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.
Args:
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def _decompress(
from_path: Union[str, pathlib.Path],
to_path: Optional[Union[str, pathlib.Path]] = None,
remove_finished: bool = False,
) -> pathlib.Path:
r"""Decompress a file.
The compression is automatically detected from the file name.
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
if to_path is None:
to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
to_path = pathlib.Path(os.fspath(from_path).replace(suffix, archive_type if archive_type is not None else ""))
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
if remove_finished:
os.remove(from_path)
return to_path
return pathlib.Path(to_path)
def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
def extract_archive(
from_path: Union[str, pathlib.Path],
to_path: Optional[Union[str, pathlib.Path]] = None,
remove_finished: bool = False,
) -> Union[str, pathlib.Path]:
"""Extract an archive.
The archive type and a possible compression are automatically detected from the file name. If the file is compressed
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
Returns:
(str): Path to the directory the file was extracted to.
"""
def path_or_str(ret_path: pathlib.Path) -> Union[str, pathlib.Path]:
if isinstance(from_path, str):
return os.fspath(ret_path)
else:
return ret_path
if to_path is None:
to_path = os.path.dirname(from_path)
suffix, archive_type, compression = _detect_file_type(from_path)
if not archive_type:
return _decompress(
ret_path = _decompress(
from_path,
os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
remove_finished=remove_finished,
)
return path_or_str(ret_path)
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor = _ARCHIVE_EXTRACTORS[archive_type]
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
if remove_finished:
os.remove(from_path)
return to_path
return path_or_str(pathlib.Path(to_path))
def download_and_extract_archive(
url: str,
download_root: str,
extract_root: Optional[str] = None,
filename: Optional[str] = None,
download_root: Union[str, pathlib.Path],
extract_root: Optional[Union[str, pathlib.Path]] = None,
filename: Optional[Union[str, pathlib.Path]] = None,
md5: Optional[str] = None,
remove_finished: bool = False,
) -> None:
@@ -479,7 +502,7 @@ def verify_str_arg(
return value
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
def _read_pfm(file_name: Union[str, pathlib.Path], slice_channels: int = 2) -> np.ndarray:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
Args:
......
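
The path_or_str helper above is what lets extract_archive mirror the caller's type. A short usage sketch (archive layout borrowed from the tar test above, not part of the commit):

import pathlib
import tarfile
import tempfile

from torchvision.datasets import utils

root = pathlib.Path(tempfile.mkdtemp())
src = root / "src.txt"
src.write_text("this is the content")
archive = root / "archive.tar"
with tarfile.open(archive, "w") as tar:
    tar.add(src, arcname=src.name)

# str in -> str out; pathlib.Path in -> pathlib.Path out.
assert isinstance(utils.extract_archive(str(archive), str(root)), str)
assert isinstance(utils.extract_archive(archive, root), pathlib.Path)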