Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1cbcb2b1
Unverified
Commit
1cbcb2b1
authored
May 27, 2021
by
Zhiqiang Wang
Committed by
GitHub
May 27, 2021
Browse files
Port test/test_image.py to pytest (#3930)
parent
4c563846
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
245 additions
and
207 deletions
+245
-207
test/test_image.py
test/test_image.py
+245
-207
No files found.
test/test_image.py
View file @
1cbcb2b1
...
@@ -2,7 +2,6 @@ import glob
...
@@ -2,7 +2,6 @@ import glob
import
io
import
io
import
os
import
os
import
sys
import
sys
import
unittest
from
pathlib
import
Path
from
pathlib
import
Path
import
pytest
import
pytest
...
@@ -54,17 +53,23 @@ def normalize_dimensions(img_pil):
...
@@ -54,17 +53,23 @@ def normalize_dimensions(img_pil):
return
img_pil
return
img_pil
class
ImageTester
(
unittest
.
TestCase
):
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
def
test_decode_jpeg
(
self
):
pytest
.
param
(
jpeg_path
,
id
=
_get_safe_image_name
(
jpeg_path
))
conversion
=
[(
None
,
ImageReadMode
.
UNCHANGED
),
(
"L"
,
ImageReadMode
.
GRAY
),
(
"RGB"
,
ImageReadMode
.
RGB
)]
for
jpeg_path
in
get_images
(
IMAGE_ROOT
,
".jpg"
)
for
img_path
in
get_images
(
IMAGE_ROOT
,
".jpg"
):
])
for
pil_mode
,
mode
in
conversion
:
@
pytest
.
mark
.
parametrize
(
'pil_mode, mode'
,
[
(
None
,
ImageReadMode
.
UNCHANGED
),
(
"L"
,
ImageReadMode
.
GRAY
),
(
"RGB"
,
ImageReadMode
.
RGB
),
])
def
test_decode_jpeg
(
img_path
,
pil_mode
,
mode
):
with
Image
.
open
(
img_path
)
as
img
:
with
Image
.
open
(
img_path
)
as
img
:
is_cmyk
=
img
.
mode
==
"CMYK"
is_cmyk
=
img
.
mode
==
"CMYK"
if
pil_mode
is
not
None
:
if
pil_mode
is
not
None
:
if
is_cmyk
:
if
is_cmyk
:
# libjpeg does not support the conversion
# libjpeg does not support the conversion
continue
pytest
.
xfail
(
"Decoding a CMYK jpeg isn't supported"
)
img
=
img
.
convert
(
pil_mode
)
img
=
img
.
convert
(
pil_mode
)
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
if
is_cmyk
:
if
is_cmyk
:
...
@@ -78,38 +83,54 @@ class ImageTester(unittest.TestCase):
...
@@ -78,38 +83,54 @@ class ImageTester(unittest.TestCase):
# Permit a small variation on pixel values to account for implementation
# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
# differences between Pillow and LibJPEG.
abs_mean_diff
=
(
img_ljpeg
.
type
(
torch
.
float32
)
-
img_pil
).
abs
().
mean
().
item
()
abs_mean_diff
=
(
img_ljpeg
.
type
(
torch
.
float32
)
-
img_pil
).
abs
().
mean
().
item
()
self
.
assert
True
(
abs_mean_diff
<
2
)
assert
abs_mean_diff
<
2
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Expected a non empty 1-dimensional tensor"
):
def
test_decode_jpeg_errors
():
with
pytest
.
raises
(
RuntimeError
,
match
=
"Expected a non empty 1-dimensional tensor"
):
decode_jpeg
(
torch
.
empty
((
100
,
1
),
dtype
=
torch
.
uint8
))
decode_jpeg
(
torch
.
empty
((
100
,
1
),
dtype
=
torch
.
uint8
))
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Expected a torch.uint8 tensor"
):
with
pytest
.
raises
(
RuntimeError
,
match
=
"Expected a torch.uint8 tensor"
):
decode_jpeg
(
torch
.
empty
((
100
,),
dtype
=
torch
.
float16
))
decode_jpeg
(
torch
.
empty
((
100
,),
dtype
=
torch
.
float16
))
with
self
.
assertR
aises
(
RuntimeError
):
with
pytest
.
r
aises
(
RuntimeError
,
match
=
"Not a JPEG file"
):
decode_jpeg
(
torch
.
empty
((
100
),
dtype
=
torch
.
uint8
))
decode_jpeg
(
torch
.
empty
((
100
),
dtype
=
torch
.
uint8
))
def
test_damaged_images
(
self
):
# Test image with bad Huffman encoding (should not raise)
def
test_decode_bad_huffman_images
():
# sanity check: make sure we can decode the bad Huffman encoding
bad_huff
=
read_file
(
os
.
path
.
join
(
DAMAGED_JPEG
,
'bad_huffman.jpg'
))
bad_huff
=
read_file
(
os
.
path
.
join
(
DAMAGED_JPEG
,
'bad_huffman.jpg'
))
try
:
decode_jpeg
(
bad_huff
)
_
=
decode_jpeg
(
bad_huff
)
except
RuntimeError
:
self
.
assertTrue
(
False
)
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
truncated_image
,
id
=
_get_safe_image_name
(
truncated_image
))
for
truncated_image
in
glob
.
glob
(
os
.
path
.
join
(
DAMAGED_JPEG
,
'corrupt*.jpg'
))
])
def
test_damaged_corrupt_images
(
img_path
):
# Truncated images should raise an exception
# Truncated images should raise an exception
truncated_images
=
glob
.
glob
(
data
=
read_file
(
img_path
)
os
.
path
.
join
(
DAMAGED_JPEG
,
'corrupt*.jpg'
))
if
'corrupt34'
in
img_path
:
for
image_path
in
truncated_images
:
match_message
=
"Image is incomplete or truncated"
data
=
read_file
(
image_path
)
else
:
with
self
.
assertRaises
(
RuntimeError
):
match_message
=
"Unsupported marker type"
with
pytest
.
raises
(
RuntimeError
,
match
=
match_message
):
decode_jpeg
(
data
)
decode_jpeg
(
data
)
def
test_decode_png
(
self
):
conversion
=
[(
None
,
ImageReadMode
.
UNCHANGED
),
(
"L"
,
ImageReadMode
.
GRAY
),
(
"LA"
,
ImageReadMode
.
GRAY_ALPHA
),
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
(
"RGB"
,
ImageReadMode
.
RGB
),
(
"RGBA"
,
ImageReadMode
.
RGB_ALPHA
)]
pytest
.
param
(
png_path
,
id
=
_get_safe_image_name
(
png_path
))
for
img_path
in
get_images
(
FAKEDATA_DIR
,
".png"
):
for
png_path
in
get_images
(
FAKEDATA_DIR
,
".png"
)
for
pil_mode
,
mode
in
conversion
:
])
@
pytest
.
mark
.
parametrize
(
'pil_mode, mode'
,
[
(
None
,
ImageReadMode
.
UNCHANGED
),
(
"L"
,
ImageReadMode
.
GRAY
),
(
"LA"
,
ImageReadMode
.
GRAY_ALPHA
),
(
"RGB"
,
ImageReadMode
.
RGB
),
(
"RGBA"
,
ImageReadMode
.
RGB_ALPHA
),
])
def
test_decode_png
(
img_path
,
pil_mode
,
mode
):
with
Image
.
open
(
img_path
)
as
img
:
with
Image
.
open
(
img_path
)
as
img
:
if
pil_mode
is
not
None
:
if
pil_mode
is
not
None
:
img
=
img
.
convert
(
pil_mode
)
img
=
img
.
convert
(
pil_mode
)
...
@@ -119,16 +140,22 @@ class ImageTester(unittest.TestCase):
...
@@ -119,16 +140,22 @@ class ImageTester(unittest.TestCase):
data
=
read_file
(
img_path
)
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
tol
=
0
if
conversion
is
None
else
1
tol
=
0
if
pil_mode
is
None
else
1
self
.
assertTrue
(
img_lpng
.
allclose
(
img_pil
,
atol
=
tol
))
assert
img_lpng
.
allclose
(
img_pil
,
atol
=
tol
)
with
self
.
assertRaises
(
RuntimeError
):
def
test_decode_png_errors
():
with
pytest
.
raises
(
RuntimeError
,
match
=
"Expected a non empty 1-dimensional tensor"
):
decode_png
(
torch
.
empty
((),
dtype
=
torch
.
uint8
))
decode_png
(
torch
.
empty
((),
dtype
=
torch
.
uint8
))
with
self
.
assertR
aises
(
RuntimeError
):
with
pytest
.
r
aises
(
RuntimeError
,
match
=
"Content is not png"
):
decode_png
(
torch
.
randint
(
3
,
5
,
(
300
,),
dtype
=
torch
.
uint8
))
decode_png
(
torch
.
randint
(
3
,
5
,
(
300
,),
dtype
=
torch
.
uint8
))
def
test_encode_png
(
self
):
for
img_path
in
get_images
(
IMAGE_DIR
,
'.png'
):
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
png_path
,
id
=
_get_safe_image_name
(
png_path
))
for
png_path
in
get_images
(
IMAGE_DIR
,
".png"
)
])
def
test_encode_png
(
img_path
):
pil_image
=
Image
.
open
(
img_path
)
pil_image
=
Image
.
open
(
img_path
)
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
...
@@ -140,27 +167,29 @@ class ImageTester(unittest.TestCase):
...
@@ -140,27 +167,29 @@ class ImageTester(unittest.TestCase):
assert_equal
(
img_pil
,
rec_img
)
assert_equal
(
img_pil
,
rec_img
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Input tensor dtype should be uint8"
):
def
test_encode_png_errors
():
with
pytest
.
raises
(
RuntimeError
,
match
=
"Input tensor dtype should be uint8"
):
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
float32
))
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
float32
))
with
self
.
assertRaisesRegex
(
with
pytest
.
raises
(
RuntimeError
,
match
=
"Compression level should be between 0 and 9"
):
RuntimeError
,
"Compression level should be between 0 and 9"
):
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
compression_level
=-
1
)
compression_level
=-
1
)
with
self
.
assertRaisesRegex
(
with
pytest
.
raises
(
RuntimeError
,
match
=
"Compression level should be between 0 and 9"
):
RuntimeError
,
"Compression level should be between 0 and 9"
):
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
compression_level
=
10
)
compression_level
=
10
)
with
self
.
assertRaisesRegex
(
with
pytest
.
raises
(
RuntimeError
,
match
=
"The number of channels should be 1 or 3, got: 5"
):
RuntimeError
,
"The number of channels should be 1 or 3, got: 5"
):
encode_png
(
torch
.
empty
((
5
,
100
,
100
),
dtype
=
torch
.
uint8
))
encode_png
(
torch
.
empty
((
5
,
100
,
100
),
dtype
=
torch
.
uint8
))
def
test_write_png
(
self
):
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
png_path
,
id
=
_get_safe_image_name
(
png_path
))
for
png_path
in
get_images
(
IMAGE_DIR
,
".png"
)
])
def
test_write_png
(
img_path
):
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
for
img_path
in
get_images
(
IMAGE_DIR
,
'.png'
):
pil_image
=
Image
.
open
(
img_path
)
pil_image
=
Image
.
open
(
img_path
)
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
...
@@ -173,7 +202,8 @@ class ImageTester(unittest.TestCase):
...
@@ -173,7 +202,8 @@ class ImageTester(unittest.TestCase):
assert_equal
(
img_pil
,
saved_image
)
assert_equal
(
img_pil
,
saved_image
)
def
test_read_file
(
self
):
def
test_read_file
():
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
...
@@ -182,14 +212,14 @@ class ImageTester(unittest.TestCase):
...
@@ -182,14 +212,14 @@ class ImageTester(unittest.TestCase):
data
=
read_file
(
fpath
)
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
assert_equal
(
data
,
expected
)
os
.
unlink
(
fpath
)
os
.
unlink
(
fpath
)
assert_equal
(
data
,
expected
)
with
self
.
assertRaisesRegex
(
with
pytest
.
raises
(
RuntimeError
,
match
=
"No such file or directory: 'tst'"
):
RuntimeError
,
"No such file or directory: 'tst'"
):
read_file
(
'tst'
)
read_file
(
'tst'
)
def
test_read_file_non_ascii
(
self
):
def
test_read_file_non_ascii
():
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
...
@@ -198,10 +228,11 @@ class ImageTester(unittest.TestCase):
...
@@ -198,10 +228,11 @@ class ImageTester(unittest.TestCase):
data
=
read_file
(
fpath
)
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
assert_equal
(
data
,
expected
)
os
.
unlink
(
fpath
)
os
.
unlink
(
fpath
)
assert_equal
(
data
,
expected
)
def
test_write_file
(
self
):
def
test_write_file
():
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
...
@@ -210,10 +241,11 @@ class ImageTester(unittest.TestCase):
...
@@ -210,10 +241,11 @@ class ImageTester(unittest.TestCase):
with
open
(
fpath
,
'rb'
)
as
f
:
with
open
(
fpath
,
'rb'
)
as
f
:
saved_content
=
f
.
read
()
saved_content
=
f
.
read
()
self
.
assertEqual
(
content
,
saved_content
)
os
.
unlink
(
fpath
)
os
.
unlink
(
fpath
)
assert
content
==
saved_content
def
test_write_file_non_ascii
(
self
):
def
test_write_file_non_ascii
():
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
...
@@ -222,8 +254,8 @@ class ImageTester(unittest.TestCase):
...
@@ -222,8 +254,8 @@ class ImageTester(unittest.TestCase):
with
open
(
fpath
,
'rb'
)
as
f
:
with
open
(
fpath
,
'rb'
)
as
f
:
saved_content
=
f
.
read
()
saved_content
=
f
.
read
()
self
.
assertEqual
(
content
,
saved_content
)
os
.
unlink
(
fpath
)
os
.
unlink
(
fpath
)
assert
content
==
saved_content
@
needs_cuda
@
needs_cuda
...
@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
...
@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
def
test_decode_jpeg_cuda
(
mode
,
img_path
,
scripted
):
def
test_decode_jpeg_cuda
(
mode
,
img_path
,
scripted
):
if
'cmyk'
in
img_path
:
if
'cmyk'
in
img_path
:
pytest
.
xfail
(
"Decoding a CMYK jpeg isn't supported"
)
pytest
.
xfail
(
"Decoding a CMYK jpeg isn't supported"
)
tester
=
ImageTester
()
data
=
read_file
(
img_path
)
data
=
read_file
(
img_path
)
img
=
decode_image
(
data
,
mode
=
mode
)
img
=
decode_image
(
data
,
mode
=
mode
)
f
=
torch
.
jit
.
script
(
decode_jpeg
)
if
scripted
else
decode_jpeg
f
=
torch
.
jit
.
script
(
decode_jpeg
)
if
scripted
else
decode_jpeg
img_nvjpeg
=
f
(
data
,
mode
=
mode
,
device
=
'cuda'
)
img_nvjpeg
=
f
(
data
,
mode
=
mode
,
device
=
'cuda'
)
# Some difference expected between jpeg implementations
# Some difference expected between jpeg implementations
tester
.
assert
True
(
(
img
.
float
()
-
img_nvjpeg
.
cpu
().
float
()).
abs
().
mean
()
<
2
)
assert
(
img
.
float
()
-
img_nvjpeg
.
cpu
().
float
()).
abs
().
mean
()
<
2
@
needs_cuda
@
needs_cuda
...
@@ -304,7 +336,11 @@ def _collect_if(cond):
...
@@ -304,7 +336,11 @@ def _collect_if(cond):
@
cpu_only
@
cpu_only
@
_collect_if
(
cond
=
IS_WINDOWS
)
@
_collect_if
(
cond
=
IS_WINDOWS
)
def
test_encode_jpeg_windows
():
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
jpeg_path
,
id
=
_get_safe_image_name
(
jpeg_path
))
for
jpeg_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
)
])
def
test_encode_jpeg_windows
(
img_path
):
# This test is *wrong*.
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from
# starts encoding the torchvision version from an image that comes from
...
@@ -315,7 +351,6 @@ def test_encode_jpeg_windows():
...
@@ -315,7 +351,6 @@ def test_encode_jpeg_windows():
# these more correct tests fail on windows (probably because of a difference
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
# FIXME: make the correct tests pass on windows and remove this.
for
img_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
):
dirname
=
os
.
path
.
dirname
(
img_path
)
dirname
=
os
.
path
.
dirname
(
img_path
)
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
write_folder
=
os
.
path
.
join
(
dirname
,
'jpeg_write'
)
write_folder
=
os
.
path
.
join
(
dirname
,
'jpeg_write'
)
...
@@ -334,10 +369,13 @@ def test_encode_jpeg_windows():
...
@@ -334,10 +369,13 @@ def test_encode_jpeg_windows():
@
cpu_only
@
cpu_only
@
_collect_if
(
cond
=
IS_WINDOWS
)
@
_collect_if
(
cond
=
IS_WINDOWS
)
def
test_write_jpeg_windows
():
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
jpeg_path
,
id
=
_get_safe_image_name
(
jpeg_path
))
for
jpeg_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
)
])
def
test_write_jpeg_windows
(
img_path
):
# FIXME: Remove this eventually, see test_encode_jpeg_windows
# FIXME: Remove this eventually, see test_encode_jpeg_windows
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
for
img_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
):
data
=
read_file
(
img_path
)
data
=
read_file
(
img_path
)
img
=
decode_jpeg
(
data
)
img
=
decode_jpeg
(
data
)
...
@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
...
@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
assert_equal
(
torch_bytes
,
pil_bytes
)
assert_equal
(
torch_bytes
,
pil_bytes
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unit
test
.
main
()
py
test
.
main
(
[
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment