Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b3b51377
Unverified
Commit
b3b51377
authored
Jun 25, 2021
by
Anirudh
Committed by
GitHub
Jun 25, 2021
Browse files
Port test_datasets_utils to pytest (#4114)
parent
ab60e538
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
159 deletions
+69
-159
test/common_utils.py
test/common_utils.py
+0
-17
test/test_datasets_utils.py
test/test_datasets_utils.py
+69
-142
No files found.
test/common_utils.py
View file @
b3b51377
...
@@ -240,23 +240,6 @@ def disable_console_output():
...
@@ -240,23 +240,6 @@ def disable_console_output():
yield
yield
def
call_args_to_kwargs_only
(
call_args
,
*
callable_or_arg_names
):
callable_or_arg_name
=
callable_or_arg_names
[
0
]
if
callable
(
callable_or_arg_name
):
argspec
=
inspect
.
getfullargspec
(
callable_or_arg_name
)
arg_names
=
argspec
.
args
if
isinstance
(
callable_or_arg_name
,
type
):
# remove self
arg_names
.
pop
(
0
)
else
:
arg_names
=
callable_or_arg_names
args
,
kwargs
=
call_args
kwargs_only
=
kwargs
.
copy
()
kwargs_only
.
update
(
dict
(
zip
(
arg_names
,
args
)))
return
kwargs_only
def
cpu_and_gpu
():
def
cpu_and_gpu
():
import
pytest
# noqa
import
pytest
# noqa
return
(
'cpu'
,
pytest
.
param
(
'cuda'
,
marks
=
pytest
.
mark
.
needs_cuda
))
return
(
'cpu'
,
pytest
.
param
(
'cuda'
,
marks
=
pytest
.
mark
.
needs_cuda
))
...
...
test/test_datasets_utils.py
View file @
b3b51377
import
bz2
import
bz2
import
os
import
os
import
torchvision.datasets.utils
as
utils
import
torchvision.datasets.utils
as
utils
import
unittest
import
pytest
import
unittest.mock
import
zipfile
import
zipfile
import
tarfile
import
tarfile
import
gzip
import
gzip
...
@@ -12,31 +11,32 @@ from urllib.error import URLError
...
@@ -12,31 +11,32 @@ from urllib.error import URLError
import
itertools
import
itertools
import
lzma
import
lzma
from
common_utils
import
get_tmp_dir
,
call_args_to_kwargs_only
from
common_utils
import
get_tmp_dir
from
torchvision.datasets.utils
import
_COMPRESSED_FILE_OPENERS
TEST_FILE
=
get_file_path_2
(
TEST_FILE
=
get_file_path_2
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'assets'
,
'encode_jpeg'
,
'grace_hopper_517x606.jpg'
)
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'assets'
,
'encode_jpeg'
,
'grace_hopper_517x606.jpg'
)
class
Test
er
(
unittest
.
TestCase
)
:
class
Test
DatasetsUtils
:
def
test_check_md5
(
self
):
def
test_check_md5
(
self
):
fpath
=
TEST_FILE
fpath
=
TEST_FILE
correct_md5
=
'9c0bb82894bb3af7f7675ef2b3b6dcdc'
correct_md5
=
'9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5
=
''
false_md5
=
''
self
.
assert
True
(
utils
.
check_md5
(
fpath
,
correct_md5
)
)
assert
utils
.
check_md5
(
fpath
,
correct_md5
)
self
.
assert
False
(
utils
.
check_md5
(
fpath
,
false_md5
)
)
assert
not
utils
.
check_md5
(
fpath
,
false_md5
)
def
test_check_integrity
(
self
):
def
test_check_integrity
(
self
):
existing_fpath
=
TEST_FILE
existing_fpath
=
TEST_FILE
nonexisting_fpath
=
''
nonexisting_fpath
=
''
correct_md5
=
'9c0bb82894bb3af7f7675ef2b3b6dcdc'
correct_md5
=
'9c0bb82894bb3af7f7675ef2b3b6dcdc'
false_md5
=
''
false_md5
=
''
self
.
assert
True
(
utils
.
check_integrity
(
existing_fpath
,
correct_md5
)
)
assert
utils
.
check_integrity
(
existing_fpath
,
correct_md5
)
self
.
assert
False
(
utils
.
check_integrity
(
existing_fpath
,
false_md5
)
)
assert
not
utils
.
check_integrity
(
existing_fpath
,
false_md5
)
self
.
assert
True
(
utils
.
check_integrity
(
existing_fpath
)
)
assert
utils
.
check_integrity
(
existing_fpath
)
self
.
assert
False
(
utils
.
check_integrity
(
nonexisting_fpath
)
)
assert
not
utils
.
check_integrity
(
nonexisting_fpath
)
def
test_get_google_drive_file_id
(
self
):
def
test_get_google_drive_file_id
(
self
):
url
=
"https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
url
=
"https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
...
@@ -50,8 +50,7 @@ class Tester(unittest.TestCase):
...
@@ -50,8 +50,7 @@ class Tester(unittest.TestCase):
assert
utils
.
_get_google_drive_file_id
(
url
)
is
None
assert
utils
.
_get_google_drive_file_id
(
url
)
is
None
def
test_detect_file_type
(
self
):
@
pytest
.
mark
.
parametrize
(
'file, expected'
,
[
for
file
,
expected
in
[
(
"foo.tar.bz2"
,
(
".tar.bz2"
,
".tar"
,
".bz2"
)),
(
"foo.tar.bz2"
,
(
".tar.bz2"
,
".tar"
,
".bz2"
)),
(
"foo.tar.xz"
,
(
".tar.xz"
,
".tar"
,
".xz"
)),
(
"foo.tar.xz"
,
(
".tar.xz"
,
".tar"
,
".xz"
)),
(
"foo.tar"
,
(
".tar"
,
".tar"
,
None
)),
(
"foo.tar"
,
(
".tar"
,
".tar"
,
None
)),
...
@@ -65,29 +64,24 @@ class Tester(unittest.TestCase):
...
@@ -65,29 +64,24 @@ class Tester(unittest.TestCase):
(
"foo.xz"
,
(
".xz"
,
None
,
".xz"
)),
(
"foo.xz"
,
(
".xz"
,
None
,
".xz"
)),
(
"foo.bar.tar.gz"
,
(
".tar.gz"
,
".tar"
,
".gz"
)),
(
"foo.bar.tar.gz"
,
(
".tar.gz"
,
".tar"
,
".gz"
)),
(
"foo.bar.gz"
,
(
".gz"
,
None
,
".gz"
)),
(
"foo.bar.gz"
,
(
".gz"
,
None
,
".gz"
)),
(
"foo.bar.zip"
,
(
".zip"
,
".zip"
,
None
)),
(
"foo.bar.zip"
,
(
".zip"
,
".zip"
,
None
))])
]:
def
test_detect_file_type
(
self
,
file
,
expected
):
with
self
.
subTest
(
file
=
file
):
assert
utils
.
_detect_file_type
(
file
)
==
expected
self
.
assertSequenceEqual
(
utils
.
_detect_file_type
(
file
),
expected
)
@
pytest
.
mark
.
parametrize
(
'file'
,
[
"foo"
,
"foo.tar.baz"
,
"foo.bar"
])
def
test_detect_file_type_no_ext
(
self
):
def
test_detect_file_type_incompatible
(
self
,
file
):
with
self
.
assertRaises
(
RuntimeError
):
# tests detect file type for no extension, unknown compression and unknown partial extension
utils
.
_detect_file_type
(
"foo"
)
with
pytest
.
raises
(
RuntimeError
):
utils
.
_detect_file_type
(
file
)
def
test_detect_file_type_unknown_compression
(
self
):
with
self
.
assertRaises
(
RuntimeError
):
@
pytest
.
mark
.
parametrize
(
'extension'
,
[
".bz2"
,
".gz"
,
".xz"
])
utils
.
_detect_file_type
(
"foo.tar.baz"
)
def
test_decompress
(
self
,
extension
):
def
test_detect_file_type_unknown_partial_ext
(
self
):
with
self
.
assertRaises
(
RuntimeError
):
utils
.
_detect_file_type
(
"foo.bar"
)
def
test_decompress_bz2
(
self
):
def
create_compressed
(
root
,
content
=
"this is the content"
):
def
create_compressed
(
root
,
content
=
"this is the content"
):
file
=
os
.
path
.
join
(
root
,
"file"
)
file
=
os
.
path
.
join
(
root
,
"file"
)
compressed
=
f
"
{
file
}
.bz2"
compressed
=
f
"
{
file
}{
extension
}
"
compressed_file_opener
=
_COMPRESSED_FILE_OPENERS
[
extension
]
with
bz2
.
open
(
compressed
,
"wb"
)
as
fh
:
with
compressed_file_
open
er
(
compressed
,
"wb"
)
as
fh
:
fh
.
write
(
content
.
encode
())
fh
.
write
(
content
.
encode
())
return
compressed
,
file
,
content
return
compressed
,
file
,
content
...
@@ -97,53 +91,13 @@ class Tester(unittest.TestCase):
...
@@ -97,53 +91,13 @@ class Tester(unittest.TestCase):
utils
.
_decompress
(
compressed
)
utils
.
_decompress
(
compressed
)
self
.
assertTrue
(
os
.
path
.
exists
(
file
))
assert
os
.
path
.
exists
(
file
)
with
open
(
file
,
"r"
)
as
fh
:
self
.
assertEqual
(
fh
.
read
(),
content
)
def
test_decompress_gzip
(
self
):
def
create_compressed
(
root
,
content
=
"this is the content"
):
file
=
os
.
path
.
join
(
root
,
"file"
)
compressed
=
f
"
{
file
}
.gz"
with
gzip
.
open
(
compressed
,
"wb"
)
as
fh
:
fh
.
write
(
content
.
encode
())
return
compressed
,
file
,
content
with
get_tmp_dir
()
as
temp_dir
:
compressed
,
file
,
content
=
create_compressed
(
temp_dir
)
utils
.
_decompress
(
compressed
)
self
.
assertTrue
(
os
.
path
.
exists
(
file
))
with
open
(
file
,
"r"
)
as
fh
:
self
.
assertEqual
(
fh
.
read
(),
content
)
def
test_decompress_lzma
(
self
):
def
create_compressed
(
root
,
content
=
"this is the content"
):
file
=
os
.
path
.
join
(
root
,
"file"
)
compressed
=
f
"
{
file
}
.xz"
with
lzma
.
open
(
compressed
,
"wb"
)
as
fh
:
fh
.
write
(
content
.
encode
())
return
compressed
,
file
,
content
with
get_tmp_dir
()
as
temp_dir
:
compressed
,
file
,
content
=
create_compressed
(
temp_dir
)
utils
.
extract_archive
(
compressed
,
temp_dir
)
self
.
assertTrue
(
os
.
path
.
exists
(
file
))
with
open
(
file
,
"r"
)
as
fh
:
with
open
(
file
,
"r"
)
as
fh
:
self
.
assert
Equal
(
fh
.
read
()
,
content
)
assert
fh
.
read
()
==
content
def
test_decompress_no_compression
(
self
):
def
test_decompress_no_compression
(
self
):
with
self
.
assertR
aises
(
RuntimeError
):
with
pytest
.
r
aises
(
RuntimeError
):
utils
.
_decompress
(
"foo.tar"
)
utils
.
_decompress
(
"foo.tar"
)
def
test_decompress_remove_finished
(
self
):
def
test_decompress_remove_finished
(
self
):
...
@@ -161,21 +115,18 @@ class Tester(unittest.TestCase):
...
@@ -161,21 +115,18 @@ class Tester(unittest.TestCase):
utils
.
extract_archive
(
compressed
,
temp_dir
,
remove_finished
=
True
)
utils
.
extract_archive
(
compressed
,
temp_dir
,
remove_finished
=
True
)
self
.
assert
False
(
os
.
path
.
exists
(
compressed
)
)
assert
not
os
.
path
.
exists
(
compressed
)
def
test_extract_archive_defer_to_decompress
(
self
):
@
pytest
.
mark
.
parametrize
(
'extension'
,
[
".gz"
,
".xz"
])
@
pytest
.
mark
.
parametrize
(
'remove_finished'
,
[
True
,
False
])
def
test_extract_archive_defer_to_decompress
(
self
,
extension
,
remove_finished
,
mocker
):
filename
=
"foo"
filename
=
"foo"
for
ext
,
remove_finished
in
itertools
.
product
((
".gz"
,
".xz"
),
(
True
,
False
)):
file
=
f
"
{
filename
}{
extension
}
"
with
self
.
subTest
(
ext
=
ext
,
remove_finished
=
remove_finished
):
with
unittest
.
mock
.
patch
(
"torchvision.datasets.utils._decompress"
)
as
mock
:
mocked
=
mocker
.
patch
(
"torchvision.datasets.utils._decompress"
)
file
=
f
"
{
filename
}{
ext
}
"
utils
.
extract_archive
(
file
,
remove_finished
=
remove_finished
)
utils
.
extract_archive
(
file
,
remove_finished
=
remove_finished
)
mock
.
assert_called_once
()
mocked
.
assert_called_once_with
(
file
,
filename
,
remove_finished
=
remove_finished
)
self
.
assertEqual
(
call_args_to_kwargs_only
(
mock
.
call_args
,
utils
.
_decompress
),
dict
(
from_path
=
file
,
to_path
=
filename
,
remove_finished
=
remove_finished
),
)
def
test_extract_zip
(
self
):
def
test_extract_zip
(
self
):
def
create_archive
(
root
,
content
=
"this is the content"
):
def
create_archive
(
root
,
content
=
"this is the content"
):
...
@@ -192,41 +143,18 @@ class Tester(unittest.TestCase):
...
@@ -192,41 +143,18 @@ class Tester(unittest.TestCase):
utils
.
extract_archive
(
archive
,
temp_dir
)
utils
.
extract_archive
(
archive
,
temp_dir
)
self
.
assertTrue
(
os
.
path
.
exists
(
file
))
assert
os
.
path
.
exists
(
file
)
with
open
(
file
,
"r"
)
as
fh
:
self
.
assertEqual
(
fh
.
read
(),
content
)
def
test_extract_tar
(
self
):
def
create_archive
(
root
,
ext
,
mode
,
content
=
"this is the content"
):
src
=
os
.
path
.
join
(
root
,
"src.txt"
)
dst
=
os
.
path
.
join
(
root
,
"dst.txt"
)
archive
=
os
.
path
.
join
(
root
,
f
"archive
{
ext
}
"
)
with
open
(
src
,
"w"
)
as
fh
:
fh
.
write
(
content
)
with
tarfile
.
open
(
archive
,
mode
=
mode
)
as
fh
:
fh
.
add
(
src
,
arcname
=
os
.
path
.
basename
(
dst
))
return
archive
,
dst
,
content
for
ext
,
mode
in
zip
([
'.tar'
,
'.tar.gz'
,
'.tgz'
],
[
'w'
,
'w:gz'
,
'w:gz'
]):
with
get_tmp_dir
()
as
temp_dir
:
archive
,
file
,
content
=
create_archive
(
temp_dir
,
ext
,
mode
)
utils
.
extract_archive
(
archive
,
temp_dir
)
self
.
assertTrue
(
os
.
path
.
exists
(
file
))
with
open
(
file
,
"r"
)
as
fh
:
with
open
(
file
,
"r"
)
as
fh
:
self
.
assertEqual
(
fh
.
read
()
,
content
)
assert
fh
.
read
()
==
content
def
test_extract_tar_xz
(
self
):
@
pytest
.
mark
.
parametrize
(
'extension, mode'
,
[
def
create_archive
(
root
,
ext
,
mode
,
content
=
"this is the content"
):
(
'.tar'
,
'w'
),
(
'.tar.gz'
,
'w:gz'
),
(
'.tgz'
,
'w:gz'
),
(
'.tar.xz'
,
'w:xz'
)])
def
test_extract_tar
(
self
,
extension
,
mode
):
def
create_archive
(
root
,
extension
,
mode
,
content
=
"this is the content"
):
src
=
os
.
path
.
join
(
root
,
"src.txt"
)
src
=
os
.
path
.
join
(
root
,
"src.txt"
)
dst
=
os
.
path
.
join
(
root
,
"dst.txt"
)
dst
=
os
.
path
.
join
(
root
,
"dst.txt"
)
archive
=
os
.
path
.
join
(
root
,
f
"archive
{
ext
}
"
)
archive
=
os
.
path
.
join
(
root
,
f
"archive
{
ext
ension
}
"
)
with
open
(
src
,
"w"
)
as
fh
:
with
open
(
src
,
"w"
)
as
fh
:
fh
.
write
(
content
)
fh
.
write
(
content
)
...
@@ -236,22 +164,21 @@ class Tester(unittest.TestCase):
...
@@ -236,22 +164,21 @@ class Tester(unittest.TestCase):
return
archive
,
dst
,
content
return
archive
,
dst
,
content
for
ext
,
mode
in
zip
([
'.tar.xz'
],
[
'w:xz'
]):
with
get_tmp_dir
()
as
temp_dir
:
with
get_tmp_dir
()
as
temp_dir
:
archive
,
file
,
content
=
create_archive
(
temp_dir
,
ext
,
mode
)
archive
,
file
,
content
=
create_archive
(
temp_dir
,
ext
ension
,
mode
)
utils
.
extract_archive
(
archive
,
temp_dir
)
utils
.
extract_archive
(
archive
,
temp_dir
)
self
.
assert
True
(
os
.
path
.
exists
(
file
)
)
assert
os
.
path
.
exists
(
file
)
with
open
(
file
,
"r"
)
as
fh
:
with
open
(
file
,
"r"
)
as
fh
:
self
.
assertEqual
(
fh
.
read
()
,
content
)
assert
fh
.
read
()
==
content
def
test_verify_str_arg
(
self
):
def
test_verify_str_arg
(
self
):
self
.
assert
Equal
(
"a"
,
utils
.
verify_str_arg
(
"a"
,
"arg"
,
(
"a"
,))
)
assert
"a"
==
utils
.
verify_str_arg
(
"a"
,
"arg"
,
(
"a"
,))
self
.
assertR
aises
(
ValueError
,
utils
.
verify_str_arg
,
0
,
(
"a"
,),
"arg"
)
pytest
.
r
aises
(
ValueError
,
utils
.
verify_str_arg
,
0
,
(
"a"
,),
"arg"
)
self
.
assertR
aises
(
ValueError
,
utils
.
verify_str_arg
,
"b"
,
(
"a"
,),
"arg"
)
pytest
.
r
aises
(
ValueError
,
utils
.
verify_str_arg
,
"b"
,
(
"a"
,),
"arg"
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unit
test
.
main
()
py
test
.
main
(
[
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment