Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1de7a74a
Unverified
Commit
1de7a74a
authored
Jan 16, 2024
by
ahmadsharif1
Committed by
GitHub
Jan 16, 2024
Browse files
Added pathlib support to datasets/utils.py (#8200)
parent
a00a72b1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
29 deletions
+75
-29
test/test_datasets_utils.py
test/test_datasets_utils.py
+29
-6
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+46
-23
No files found.
test/test_datasets_utils.py
View file @
1de7a74a
...
...
@@ -58,8 +58,11 @@ class TestDatasetsUtils:
assert
mock
.
call_count
==
1
assert
mock
.
call_args
[
0
][
0
].
full_url
==
url
def
test_check_md5
(
self
):
@
pytest
.
mark
.
parametrize
(
"use_pathlib"
,
(
True
,
False
))
def
test_check_md5
(
self
,
use_pathlib
):
fpath
=
TEST_FILE
if
use_pathlib
:
fpath
=
pathlib
.
Path
(
fpath
)
correct_md5
=
"9c0bb82894bb3af7f7675ef2b3b6dcdc"
false_md5
=
""
assert
utils
.
check_md5
(
fpath
,
correct_md5
)
...
...
@@ -116,7 +119,8 @@ class TestDatasetsUtils:
utils
.
_detect_file_type
(
file
)
@
pytest
.
mark
.
parametrize
(
"extension"
,
[
".bz2"
,
".gz"
,
".xz"
])
def
test_decompress
(
self
,
extension
,
tmpdir
):
@
pytest
.
mark
.
parametrize
(
"use_pathlib"
,
(
True
,
False
))
def
test_decompress
(
self
,
extension
,
tmpdir
,
use_pathlib
):
def
create_compressed
(
root
,
content
=
"this is the content"
):
file
=
os
.
path
.
join
(
root
,
"file"
)
compressed
=
f
"
{
file
}{
extension
}
"
...
...
@@ -128,6 +132,8 @@ class TestDatasetsUtils:
return
compressed
,
file
,
content
compressed
,
file
,
content
=
create_compressed
(
tmpdir
)
if
use_pathlib
:
compressed
=
pathlib
.
Path
(
compressed
)
utils
.
_decompress
(
compressed
)
...
...
@@ -140,7 +146,8 @@ class TestDatasetsUtils:
with
pytest
.
raises
(
RuntimeError
):
utils
.
_decompress
(
"foo.tar"
)
def
test_decompress_remove_finished
(
self
,
tmpdir
):
@
pytest
.
mark
.
parametrize
(
"use_pathlib"
,
(
True
,
False
))
def
test_decompress_remove_finished
(
self
,
tmpdir
,
use_pathlib
):
def
create_compressed
(
root
,
content
=
"this is the content"
):
file
=
os
.
path
.
join
(
root
,
"file"
)
compressed
=
f
"
{
file
}
.gz"
...
...
@@ -151,10 +158,20 @@ class TestDatasetsUtils:
return
compressed
,
file
,
content
compressed
,
file
,
content
=
create_compressed
(
tmpdir
)
print
(
f
"
{
type
(
compressed
)
=
}
"
)
if
use_pathlib
:
compressed
=
pathlib
.
Path
(
compressed
)
tmpdir
=
pathlib
.
Path
(
tmpdir
)
utils
.
extract_archive
(
compressed
,
tmpdir
,
remove_finished
=
True
)
extracted_dir
=
utils
.
extract_archive
(
compressed
,
tmpdir
,
remove_finished
=
True
)
assert
not
os
.
path
.
exists
(
compressed
)
if
use_pathlib
:
assert
isinstance
(
extracted_dir
,
pathlib
.
Path
)
assert
isinstance
(
compressed
,
pathlib
.
Path
)
else
:
assert
isinstance
(
extracted_dir
,
str
)
assert
isinstance
(
compressed
,
str
)
@
pytest
.
mark
.
parametrize
(
"extension"
,
[
".gz"
,
".xz"
])
@
pytest
.
mark
.
parametrize
(
"remove_finished"
,
[
True
,
False
])
...
...
@@ -167,7 +184,8 @@ class TestDatasetsUtils:
mocked
.
assert_called_once_with
(
file
,
filename
,
remove_finished
=
remove_finished
)
def
test_extract_zip
(
self
,
tmpdir
):
@
pytest
.
mark
.
parametrize
(
"use_pathlib"
,
(
True
,
False
))
def
test_extract_zip
(
self
,
tmpdir
,
use_pathlib
):
def
create_archive
(
root
,
content
=
"this is the content"
):
file
=
os
.
path
.
join
(
root
,
"dst.txt"
)
archive
=
os
.
path
.
join
(
root
,
"archive.zip"
)
...
...
@@ -177,6 +195,8 @@ class TestDatasetsUtils:
return
archive
,
file
,
content
if
use_pathlib
:
tmpdir
=
pathlib
.
Path
(
tmpdir
)
archive
,
file
,
content
=
create_archive
(
tmpdir
)
utils
.
extract_archive
(
archive
,
tmpdir
)
...
...
@@ -189,7 +209,8 @@ class TestDatasetsUtils:
@
pytest
.
mark
.
parametrize
(
"extension, mode"
,
[(
".tar"
,
"w"
),
(
".tar.gz"
,
"w:gz"
),
(
".tgz"
,
"w:gz"
),
(
".tar.xz"
,
"w:xz"
)]
)
def
test_extract_tar
(
self
,
extension
,
mode
,
tmpdir
):
@
pytest
.
mark
.
parametrize
(
"use_pathlib"
,
(
True
,
False
))
def
test_extract_tar
(
self
,
extension
,
mode
,
tmpdir
,
use_pathlib
):
def
create_archive
(
root
,
extension
,
mode
,
content
=
"this is the content"
):
src
=
os
.
path
.
join
(
root
,
"src.txt"
)
dst
=
os
.
path
.
join
(
root
,
"dst.txt"
)
...
...
@@ -203,6 +224,8 @@ class TestDatasetsUtils:
return
archive
,
dst
,
content
if
use_pathlib
:
tmpdir
=
pathlib
.
Path
(
tmpdir
)
archive
,
file
,
content
=
create_archive
(
tmpdir
,
extension
,
mode
)
utils
.
extract_archive
(
archive
,
tmpdir
)
...
...
torchvision/datasets/utils.py
View file @
1de7a74a
...
...
@@ -30,7 +30,7 @@ USER_AGENT = "pytorch/vision"
def
_save_response_content
(
content
:
Iterator
[
bytes
],
destination
:
str
,
destination
:
Union
[
str
,
pathlib
.
Path
]
,
length
:
Optional
[
int
]
=
None
,
)
->
None
:
with
open
(
destination
,
"wb"
)
as
fh
,
tqdm
(
total
=
length
)
as
pbar
:
...
...
@@ -43,12 +43,12 @@ def _save_response_content(
pbar
.
update
(
len
(
chunk
))
def _urlretrieve(url: str, filename: Union[str, pathlib.Path], chunk_size: int = 1024 * 32) -> None:
    """Stream *url* to *filename* in fixed-size chunks.

    Args:
        url: URL to download.
        filename: Destination path (``str`` or ``pathlib.Path``).
        chunk_size: Number of bytes read from the response per iteration.
    """
    request = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
    with urllib.request.urlopen(request) as response:
        # iter() with a sentinel yields chunks until read() returns b"" (EOF).
        chunks = iter(lambda: response.read(chunk_size), b"")
        _save_response_content(chunks, filename, length=response.length)
def
calculate_md5
(
fpath
:
str
,
chunk_size
:
int
=
1024
*
1024
)
->
str
:
def
calculate_md5
(
fpath
:
Union
[
str
,
pathlib
.
Path
]
,
chunk_size
:
int
=
1024
*
1024
)
->
str
:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
# it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
...
...
@@ -62,11 +62,11 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
return
md5
.
hexdigest
()
def check_md5(fpath: Union[str, pathlib.Path], md5: str, **kwargs: Any) -> bool:
    """Return ``True`` if the MD5 digest of the file at *fpath* matches *md5*.

    Args:
        fpath: Path to the file to check (``str`` or ``pathlib.Path``).
        md5: Expected hex digest.
        **kwargs: Forwarded to ``calculate_md5`` (e.g. ``chunk_size``).
    """
    actual = calculate_md5(fpath, **kwargs)
    return actual == md5
def
check_integrity
(
fpath
:
str
,
md5
:
Optional
[
str
]
=
None
)
->
bool
:
def
check_integrity
(
fpath
:
Union
[
str
,
pathlib
.
Path
]
,
md5
:
Optional
[
str
]
=
None
)
->
bool
:
if
not
os
.
path
.
isfile
(
fpath
):
return
False
if
md5
is
None
:
...
...
@@ -106,7 +106,7 @@ def _get_google_drive_file_id(url: str) -> Optional[str]:
def
download_url
(
url
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
],
filename
:
Optional
[
str
]
=
None
,
filename
:
Optional
[
Union
[
str
,
pathlib
.
Path
]
]
=
None
,
md5
:
Optional
[
str
]
=
None
,
max_redirect_hops
:
int
=
3
,
)
->
None
:
...
...
@@ -159,7 +159,7 @@ def download_url(
raise
RuntimeError
(
"File not found or corrupted."
)
def
list_dir
(
root
:
str
,
prefix
:
bool
=
False
)
->
List
[
str
]:
def
list_dir
(
root
:
Union
[
str
,
pathlib
.
Path
]
,
prefix
:
bool
=
False
)
->
List
[
str
]:
"""List all directories at a given root
Args:
...
...
@@ -174,7 +174,7 @@ def list_dir(root: str, prefix: bool = False) -> List[str]:
return
directories
def
list_files
(
root
:
str
,
suffix
:
str
,
prefix
:
bool
=
False
)
->
List
[
str
]:
def
list_files
(
root
:
Union
[
str
,
pathlib
.
Path
]
,
suffix
:
str
,
prefix
:
bool
=
False
)
->
List
[
str
]:
"""List all files ending with a suffix at a given root
Args:
...
...
@@ -208,7 +208,10 @@ def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple
def
download_file_from_google_drive
(
file_id
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
],
filename
:
Optional
[
str
]
=
None
,
md5
:
Optional
[
str
]
=
None
file_id
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
],
filename
:
Optional
[
Union
[
str
,
pathlib
.
Path
]]
=
None
,
md5
:
Optional
[
str
]
=
None
,
):
"""Download a Google Drive file from and place it in root.
...
...
@@ -278,7 +281,9 @@ def download_file_from_google_drive(
)
def
_extract_tar
(
from_path
:
str
,
to_path
:
str
,
compression
:
Optional
[
str
])
->
None
:
def
_extract_tar
(
from_path
:
Union
[
str
,
pathlib
.
Path
],
to_path
:
Union
[
str
,
pathlib
.
Path
],
compression
:
Optional
[
str
]
)
->
None
:
with
tarfile
.
open
(
from_path
,
f
"r:
{
compression
[
1
:]
}
"
if
compression
else
"r"
)
as
tar
:
tar
.
extractall
(
to_path
)
...
...
@@ -289,14 +294,16 @@ _ZIP_COMPRESSION_MAP: Dict[str, int] = {
}
def
_extract_zip
(
from_path
:
str
,
to_path
:
str
,
compression
:
Optional
[
str
])
->
None
:
def
_extract_zip
(
from_path
:
Union
[
str
,
pathlib
.
Path
],
to_path
:
Union
[
str
,
pathlib
.
Path
],
compression
:
Optional
[
str
]
)
->
None
:
with
zipfile
.
ZipFile
(
from_path
,
"r"
,
compression
=
_ZIP_COMPRESSION_MAP
[
compression
]
if
compression
else
zipfile
.
ZIP_STORED
)
as
zip
:
zip
.
extractall
(
to_path
)
# Maps an archive suffix to the extractor callable for that format; each
# callable takes (from_path, to_path, compression) and returns None.
_ARCHIVE_EXTRACTORS: Dict[str, Callable[[Union[str, pathlib.Path], Union[str, pathlib.Path], Optional[str]], None]] = {
    ".tar": _extract_tar,
    ".zip": _extract_zip,
}
...
...
@@ -312,7 +319,7 @@ _FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
}
def
_detect_file_type
(
file
:
str
)
->
Tuple
[
str
,
Optional
[
str
],
Optional
[
str
]]:
def
_detect_file_type
(
file
:
Union
[
str
,
pathlib
.
Path
]
)
->
Tuple
[
str
,
Optional
[
str
],
Optional
[
str
]]:
"""Detect the archive type and/or compression of a file.
Args:
...
...
@@ -355,7 +362,11 @@ def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
raise
RuntimeError
(
f
"Unknown compression or archive type: '
{
suffix
}
'.
\n
Known suffixes are: '
{
valid_suffixes
}
'."
)
def
_decompress
(
from_path
:
str
,
to_path
:
Optional
[
str
]
=
None
,
remove_finished
:
bool
=
False
)
->
str
:
def
_decompress
(
from_path
:
Union
[
str
,
pathlib
.
Path
],
to_path
:
Optional
[
Union
[
str
,
pathlib
.
Path
]]
=
None
,
remove_finished
:
bool
=
False
,
)
->
pathlib
.
Path
:
r
"""Decompress a file.
The compression is automatically detected from the file name.
...
...
@@ -373,7 +384,7 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
raise
RuntimeError
(
f
"Couldn't detect a compression from suffix
{
suffix
}
."
)
if
to_path
is
None
:
to_path
=
from_path
.
replace
(
suffix
,
archive_type
if
archive_type
is
not
None
else
""
)
to_path
=
pathlib
.
Path
(
os
.
fspath
(
from_path
)
.
replace
(
suffix
,
archive_type
if
archive_type
is
not
None
else
""
)
)
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
compressed_file_opener
=
_COMPRESSED_FILE_OPENERS
[
compression
]
...
...
@@ -384,10 +395,14 @@ def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished:
if
remove_finished
:
os
.
remove
(
from_path
)
return
to_path
return
pathlib
.
Path
(
to_path
)
def
extract_archive
(
from_path
:
str
,
to_path
:
Optional
[
str
]
=
None
,
remove_finished
:
bool
=
False
)
->
str
:
def
extract_archive
(
from_path
:
Union
[
str
,
pathlib
.
Path
],
to_path
:
Optional
[
Union
[
str
,
pathlib
.
Path
]]
=
None
,
remove_finished
:
bool
=
False
,
)
->
Union
[
str
,
pathlib
.
Path
]:
"""Extract an archive.
The archive type and a possible compression is automatically detected from the file name. If the file is compressed
...
...
@@ -402,16 +417,24 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
Returns:
(str): Path to the directory the file was extracted to.
"""
def
path_or_str
(
ret_path
:
pathlib
.
Path
)
->
Union
[
str
,
pathlib
.
Path
]:
if
isinstance
(
from_path
,
str
):
return
os
.
fspath
(
ret_path
)
else
:
return
ret_path
if
to_path
is
None
:
to_path
=
os
.
path
.
dirname
(
from_path
)
suffix
,
archive_type
,
compression
=
_detect_file_type
(
from_path
)
if
not
archive_type
:
ret
urn
_decompress
(
ret
_path
=
_decompress
(
from_path
,
os
.
path
.
join
(
to_path
,
os
.
path
.
basename
(
from_path
).
replace
(
suffix
,
""
)),
remove_finished
=
remove_finished
,
)
return
path_or_str
(
ret_path
)
# We don't need to check for a missing key here, since this was already done in _detect_file_type()
extractor
=
_ARCHIVE_EXTRACTORS
[
archive_type
]
...
...
@@ -420,14 +443,14 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finish
if
remove_finished
:
os
.
remove
(
from_path
)
return
to_path
return
path_or_str
(
pathlib
.
Path
(
to_path
))
def
download_and_extract_archive
(
url
:
str
,
download_root
:
str
,
extract_root
:
Optional
[
str
]
=
None
,
filename
:
Optional
[
str
]
=
None
,
download_root
:
Union
[
str
,
pathlib
.
Path
]
,
extract_root
:
Optional
[
Union
[
str
,
pathlib
.
Path
]
]
=
None
,
filename
:
Optional
[
Union
[
str
,
pathlib
.
Path
]
]
=
None
,
md5
:
Optional
[
str
]
=
None
,
remove_finished
:
bool
=
False
,
)
->
None
:
...
...
@@ -479,7 +502,7 @@ def verify_str_arg(
return
value
def
_read_pfm
(
file_name
:
str
,
slice_channels
:
int
=
2
)
->
np
.
ndarray
:
def
_read_pfm
(
file_name
:
Union
[
str
,
pathlib
.
Path
]
,
slice_channels
:
int
=
2
)
->
np
.
ndarray
:
"""Read file in .pfm format. Might contain either 1 or 3 channels of data.
Args:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment