Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1cbcb2b1
Unverified
Commit
1cbcb2b1
authored
May 27, 2021
by
Zhiqiang Wang
Committed by
GitHub
May 27, 2021
Browse files
Port test/test_image.py to pytest (#3930)
parent
4c563846
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
245 additions
and
207 deletions
+245
-207
test/test_image.py
test/test_image.py
+245
-207
No files found.
test/test_image.py
View file @
1cbcb2b1
...
@@ -2,7 +2,6 @@ import glob
...
@@ -2,7 +2,6 @@ import glob
import
io
import
io
import
os
import
os
import
sys
import
sys
import
unittest
from
pathlib
import
Path
from
pathlib
import
Path
import
pytest
import
pytest
...
@@ -54,176 +53,209 @@ def normalize_dimensions(img_pil):
...
@@ -54,176 +53,209 @@ def normalize_dimensions(img_pil):
return
img_pil
return
img_pil
class
ImageTester
(
unittest
.
TestCase
):
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
def
test_decode_jpeg
(
self
):
pytest
.
param
(
jpeg_path
,
id
=
_get_safe_image_name
(
jpeg_path
))
conversion
=
[(
None
,
ImageReadMode
.
UNCHANGED
),
(
"L"
,
ImageReadMode
.
GRAY
),
(
"RGB"
,
ImageReadMode
.
RGB
)]
for
jpeg_path
in
get_images
(
IMAGE_ROOT
,
".jpg"
)
for
img_path
in
get_images
(
IMAGE_ROOT
,
".jpg"
):
])
for
pil_mode
,
mode
in
conversion
:
@
pytest
.
mark
.
parametrize
(
'pil_mode, mode'
,
[
with
Image
.
open
(
img_path
)
as
img
:
(
None
,
ImageReadMode
.
UNCHANGED
),
is_cmyk
=
img
.
mode
==
"CMYK"
(
"L"
,
ImageReadMode
.
GRAY
),
if
pil_mode
is
not
None
:
(
"RGB"
,
ImageReadMode
.
RGB
),
if
is_cmyk
:
])
# libjpeg does not support the conversion
def
test_decode_jpeg
(
img_path
,
pil_mode
,
mode
):
continue
img
=
img
.
convert
(
pil_mode
)
with
Image
.
open
(
img_path
)
as
img
:
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
is_cmyk
=
img
.
mode
==
"CMYK"
if
is_cmyk
:
if
pil_mode
is
not
None
:
# flip the colors to match libjpeg
if
is_cmyk
:
img_pil
=
255
-
img_pil
# libjpeg does not support the conversion
pytest
.
xfail
(
"Decoding a CMYK jpeg isn't supported"
)
img_pil
=
normalize_dimensions
(
img_pil
)
img
=
img
.
convert
(
pil_mode
)
data
=
read_file
(
img_path
)
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
img_ljpeg
=
decode_image
(
data
,
mode
=
mode
)
if
is_cmyk
:
# flip the colors to match libjpeg
# Permit a small variation on pixel values to account for implementation
img_pil
=
255
-
img_pil
# differences between Pillow and LibJPEG.
abs_mean_diff
=
(
img_ljpeg
.
type
(
torch
.
float32
)
-
img_pil
).
abs
().
mean
().
item
()
img_pil
=
normalize_dimensions
(
img_pil
)
self
.
assertTrue
(
abs_mean_diff
<
2
)
data
=
read_file
(
img_path
)
img_ljpeg
=
decode_image
(
data
,
mode
=
mode
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Expected a non empty 1-dimensional tensor"
):
decode_jpeg
(
torch
.
empty
((
100
,
1
),
dtype
=
torch
.
uint8
))
# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Expected a torch.uint8 tensor"
):
abs_mean_diff
=
(
img_ljpeg
.
type
(
torch
.
float32
)
-
img_pil
).
abs
().
mean
().
item
()
decode_jpeg
(
torch
.
empty
((
100
,),
dtype
=
torch
.
float16
))
assert
abs_mean_diff
<
2
with
self
.
assertRaises
(
RuntimeError
):
decode_jpeg
(
torch
.
empty
((
100
),
dtype
=
torch
.
uint8
))
def
test_decode_jpeg_errors
():
with
pytest
.
raises
(
RuntimeError
,
match
=
"Expected a non empty 1-dimensional tensor"
):
def
test_damaged_images
(
self
):
decode_jpeg
(
torch
.
empty
((
100
,
1
),
dtype
=
torch
.
uint8
))
# Test image with bad Huffman encoding (should not raise)
bad_huff
=
read_file
(
os
.
path
.
join
(
DAMAGED_JPEG
,
'bad_huffman.jpg'
))
with
pytest
.
raises
(
RuntimeError
,
match
=
"Expected a torch.uint8 tensor"
):
try
:
decode_jpeg
(
torch
.
empty
((
100
,),
dtype
=
torch
.
float16
))
_
=
decode_jpeg
(
bad_huff
)
except
RuntimeError
:
with
pytest
.
raises
(
RuntimeError
,
match
=
"Not a JPEG file"
):
self
.
assertTrue
(
False
)
decode_jpeg
(
torch
.
empty
((
100
),
dtype
=
torch
.
uint8
))
# Truncated images should raise an exception
truncated_images
=
glob
.
glob
(
def
test_decode_bad_huffman_images
():
os
.
path
.
join
(
DAMAGED_JPEG
,
'corrupt*.jpg'
))
# sanity check: make sure we can decode the bad Huffman encoding
for
image_path
in
truncated_images
:
bad_huff
=
read_file
(
os
.
path
.
join
(
DAMAGED_JPEG
,
'bad_huffman.jpg'
))
data
=
read_file
(
image_path
)
decode_jpeg
(
bad_huff
)
with
self
.
assertRaises
(
RuntimeError
):
decode_jpeg
(
data
)
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
def
test_decode_png
(
self
):
pytest
.
param
(
truncated_image
,
id
=
_get_safe_image_name
(
truncated_image
))
conversion
=
[(
None
,
ImageReadMode
.
UNCHANGED
),
(
"L"
,
ImageReadMode
.
GRAY
),
(
"LA"
,
ImageReadMode
.
GRAY_ALPHA
),
for
truncated_image
in
glob
.
glob
(
os
.
path
.
join
(
DAMAGED_JPEG
,
'corrupt*.jpg'
))
(
"RGB"
,
ImageReadMode
.
RGB
),
(
"RGBA"
,
ImageReadMode
.
RGB_ALPHA
)]
])
for
img_path
in
get_images
(
FAKEDATA_DIR
,
".png"
):
def
test_damaged_corrupt_images
(
img_path
):
for
pil_mode
,
mode
in
conversion
:
# Truncated images should raise an exception
with
Image
.
open
(
img_path
)
as
img
:
data
=
read_file
(
img_path
)
if
pil_mode
is
not
None
:
if
'corrupt34'
in
img_path
:
img
=
img
.
convert
(
pil_mode
)
match_message
=
"Image is incomplete or truncated"
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
else
:
match_message
=
"Unsupported marker type"
img_pil
=
normalize_dimensions
(
img_pil
)
with
pytest
.
raises
(
RuntimeError
,
match
=
match_message
):
data
=
read_file
(
img_path
)
decode_jpeg
(
data
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
tol
=
0
if
conversion
is
None
else
1
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
self
.
assertTrue
(
img_lpng
.
allclose
(
img_pil
,
atol
=
tol
))
pytest
.
param
(
png_path
,
id
=
_get_safe_image_name
(
png_path
))
for
png_path
in
get_images
(
FAKEDATA_DIR
,
".png"
)
with
self
.
assertRaises
(
RuntimeError
):
])
decode_png
(
torch
.
empty
((),
dtype
=
torch
.
uint8
))
@
pytest
.
mark
.
parametrize
(
'pil_mode, mode'
,
[
with
self
.
assertRaises
(
RuntimeError
):
(
None
,
ImageReadMode
.
UNCHANGED
),
decode_png
(
torch
.
randint
(
3
,
5
,
(
300
,),
dtype
=
torch
.
uint8
))
(
"L"
,
ImageReadMode
.
GRAY
),
(
"LA"
,
ImageReadMode
.
GRAY_ALPHA
),
def
test_encode_png
(
self
):
(
"RGB"
,
ImageReadMode
.
RGB
),
for
img_path
in
get_images
(
IMAGE_DIR
,
'.png'
):
(
"RGBA"
,
ImageReadMode
.
RGB_ALPHA
),
pil_image
=
Image
.
open
(
img_path
)
])
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
def
test_decode_png
(
img_path
,
pil_mode
,
mode
):
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
png_buf
=
encode_png
(
img_pil
,
compression_level
=
6
)
with
Image
.
open
(
img_path
)
as
img
:
if
pil_mode
is
not
None
:
rec_img
=
Image
.
open
(
io
.
BytesIO
(
bytes
(
png_buf
.
tolist
())))
img
=
img
.
convert
(
pil_mode
)
rec_img
=
torch
.
from_numpy
(
np
.
array
(
rec_img
))
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
rec_img
=
rec_img
.
permute
(
2
,
0
,
1
)
img_pil
=
normalize_dimensions
(
img_pil
)
assert_equal
(
img_pil
,
rec_img
)
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Input tensor dtype should be uint8"
):
tol
=
0
if
pil_mode
is
None
else
1
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
float32
))
assert
img_lpng
.
allclose
(
img_pil
,
atol
=
tol
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Compression level should be between 0 and 9"
):
def
test_decode_png_errors
():
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
with
pytest
.
raises
(
RuntimeError
,
match
=
"Expected a non empty 1-dimensional tensor"
):
compression_level
=-
1
)
decode_png
(
torch
.
empty
((),
dtype
=
torch
.
uint8
))
with
pytest
.
raises
(
RuntimeError
,
match
=
"Content is not png"
):
with
self
.
assertRaisesRegex
(
decode_png
(
torch
.
randint
(
3
,
5
,
(
300
,),
dtype
=
torch
.
uint8
))
RuntimeError
,
"Compression level should be between 0 and 9"
):
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
compression_level
=
10
)
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
png_path
,
id
=
_get_safe_image_name
(
png_path
))
with
self
.
assertRaisesRegex
(
for
png_path
in
get_images
(
IMAGE_DIR
,
".png"
)
RuntimeError
,
"The number of channels should be 1 or 3, got: 5"
):
])
encode_png
(
torch
.
empty
((
5
,
100
,
100
),
dtype
=
torch
.
uint8
))
def
test_encode_png
(
img_path
):
pil_image
=
Image
.
open
(
img_path
)
def
test_write_png
(
self
):
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
with
get_tmp_dir
()
as
d
:
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
for
img_path
in
get_images
(
IMAGE_DIR
,
'.png'
):
png_buf
=
encode_png
(
img_pil
,
compression_level
=
6
)
pil_image
=
Image
.
open
(
img_path
)
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
rec_img
=
Image
.
open
(
io
.
BytesIO
(
bytes
(
png_buf
.
tolist
())))
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
rec_img
=
torch
.
from_numpy
(
np
.
array
(
rec_img
))
rec_img
=
rec_img
.
permute
(
2
,
0
,
1
)
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
torch_png
=
os
.
path
.
join
(
d
,
'{0}_torch.png'
.
format
(
filename
))
assert_equal
(
img_pil
,
rec_img
)
write_png
(
img_pil
,
torch_png
,
compression_level
=
6
)
saved_image
=
torch
.
from_numpy
(
np
.
array
(
Image
.
open
(
torch_png
)))
saved_image
=
saved_image
.
permute
(
2
,
0
,
1
)
def
test_encode_png_errors
():
with
pytest
.
raises
(
RuntimeError
,
match
=
"Input tensor dtype should be uint8"
):
assert_equal
(
img_pil
,
saved_image
)
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
float32
))
def
test_read_file
(
self
):
with
pytest
.
raises
(
RuntimeError
,
match
=
"Compression level should be between 0 and 9"
):
with
get_tmp_dir
()
as
d
:
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
compression_level
=-
1
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
with
open
(
fpath
,
'wb'
)
as
f
:
with
pytest
.
raises
(
RuntimeError
,
match
=
"Compression level should be between 0 and 9"
):
f
.
write
(
content
)
encode_png
(
torch
.
empty
((
3
,
100
,
100
),
dtype
=
torch
.
uint8
),
compression_level
=
10
)
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
with
pytest
.
raises
(
RuntimeError
,
match
=
"The number of channels should be 1 or 3, got: 5"
):
assert_equal
(
data
,
expected
)
encode_png
(
torch
.
empty
((
5
,
100
,
100
),
dtype
=
torch
.
uint8
))
os
.
unlink
(
fpath
)
with
self
.
assertRaisesRegex
(
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
RuntimeError
,
"No such file or directory: 'tst'"
):
pytest
.
param
(
png_path
,
id
=
_get_safe_image_name
(
png_path
))
read_file
(
'tst'
)
for
png_path
in
get_images
(
IMAGE_DIR
,
".png"
)
])
def
test_read_file_non_ascii
(
self
):
def
test_write_png
(
img_path
):
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
pil_image
=
Image
.
open
(
img_path
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
img_pil
=
torch
.
from_numpy
(
np
.
array
(
pil_image
))
with
open
(
fpath
,
'wb'
)
as
f
:
img_pil
=
img_pil
.
permute
(
2
,
0
,
1
)
f
.
write
(
content
)
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
data
=
read_file
(
fpath
)
torch_png
=
os
.
path
.
join
(
d
,
'{0}_torch.png'
.
format
(
filename
))
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
write_png
(
img_pil
,
torch_png
,
compression_level
=
6
)
assert_equal
(
data
,
expected
)
saved_image
=
torch
.
from_numpy
(
np
.
array
(
Image
.
open
(
torch_png
)))
os
.
unlink
(
fpath
)
saved_image
=
saved_image
.
permute
(
2
,
0
,
1
)
def
test_write_file
(
self
):
assert_equal
(
img_pil
,
saved_image
)
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
def
test_read_file
():
content_tensor
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
with
get_tmp_dir
()
as
d
:
write_file
(
fpath
,
content_tensor
)
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
with
open
(
fpath
,
'rb'
)
as
f
:
with
open
(
fpath
,
'wb'
)
as
f
:
saved_content
=
f
.
read
()
f
.
write
(
content
)
self
.
assertEqual
(
content
,
saved_content
)
os
.
unlink
(
fpath
)
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
def
test_write_file_non_ascii
(
self
):
os
.
unlink
(
fpath
)
with
get_tmp_dir
()
as
d
:
assert_equal
(
data
,
expected
)
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
with
pytest
.
raises
(
RuntimeError
,
match
=
"No such file or directory: 'tst'"
):
content_tensor
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
read_file
(
'tst'
)
write_file
(
fpath
,
content_tensor
)
with
open
(
fpath
,
'rb'
)
as
f
:
def
test_read_file_non_ascii
():
saved_content
=
f
.
read
()
with
get_tmp_dir
()
as
d
:
self
.
assertEqual
(
content
,
saved_content
)
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
os
.
unlink
(
fpath
)
fpath
=
os
.
path
.
join
(
d
,
fname
)
with
open
(
fpath
,
'wb'
)
as
f
:
f
.
write
(
content
)
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
os
.
unlink
(
fpath
)
assert_equal
(
data
,
expected
)
def
test_write_file
():
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'test1.bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
content_tensor
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
write_file
(
fpath
,
content_tensor
)
with
open
(
fpath
,
'rb'
)
as
f
:
saved_content
=
f
.
read
()
os
.
unlink
(
fpath
)
assert
content
==
saved_content
def
test_write_file_non_ascii
():
with
get_tmp_dir
()
as
d
:
fname
,
content
=
'日本語(Japanese).bin'
,
b
'TorchVision
\211\n
'
fpath
=
os
.
path
.
join
(
d
,
fname
)
content_tensor
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
write_file
(
fpath
,
content_tensor
)
with
open
(
fpath
,
'rb'
)
as
f
:
saved_content
=
f
.
read
()
os
.
unlink
(
fpath
)
assert
content
==
saved_content
@
needs_cuda
@
needs_cuda
...
@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
...
@@ -236,14 +268,14 @@ class ImageTester(unittest.TestCase):
def
test_decode_jpeg_cuda
(
mode
,
img_path
,
scripted
):
def
test_decode_jpeg_cuda
(
mode
,
img_path
,
scripted
):
if
'cmyk'
in
img_path
:
if
'cmyk'
in
img_path
:
pytest
.
xfail
(
"Decoding a CMYK jpeg isn't supported"
)
pytest
.
xfail
(
"Decoding a CMYK jpeg isn't supported"
)
tester
=
ImageTester
()
data
=
read_file
(
img_path
)
data
=
read_file
(
img_path
)
img
=
decode_image
(
data
,
mode
=
mode
)
img
=
decode_image
(
data
,
mode
=
mode
)
f
=
torch
.
jit
.
script
(
decode_jpeg
)
if
scripted
else
decode_jpeg
f
=
torch
.
jit
.
script
(
decode_jpeg
)
if
scripted
else
decode_jpeg
img_nvjpeg
=
f
(
data
,
mode
=
mode
,
device
=
'cuda'
)
img_nvjpeg
=
f
(
data
,
mode
=
mode
,
device
=
'cuda'
)
# Some difference expected between jpeg implementations
# Some difference expected between jpeg implementations
tester
.
assert
True
(
(
img
.
float
()
-
img_nvjpeg
.
cpu
().
float
()).
abs
().
mean
()
<
2
)
assert
(
img
.
float
()
-
img_nvjpeg
.
cpu
().
float
()).
abs
().
mean
()
<
2
@
needs_cuda
@
needs_cuda
...
@@ -304,7 +336,11 @@ def _collect_if(cond):
...
@@ -304,7 +336,11 @@ def _collect_if(cond):
@
cpu_only
@
cpu_only
@
_collect_if
(
cond
=
IS_WINDOWS
)
@
_collect_if
(
cond
=
IS_WINDOWS
)
def
test_encode_jpeg_windows
():
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
jpeg_path
,
id
=
_get_safe_image_name
(
jpeg_path
))
for
jpeg_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
)
])
def
test_encode_jpeg_windows
(
img_path
):
# This test is *wrong*.
# This test is *wrong*.
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# It compares a torchvision-encoded jpeg with a PIL-encoded jpeg, but it
# starts encoding the torchvision version from an image that comes from
# starts encoding the torchvision version from an image that comes from
...
@@ -315,48 +351,50 @@ def test_encode_jpeg_windows():
...
@@ -315,48 +351,50 @@ def test_encode_jpeg_windows():
# these more correct tests fail on windows (probably because of a difference
# these more correct tests fail on windows (probably because of a difference
# in libjpeg) between torchvision and PIL.
# in libjpeg) between torchvision and PIL.
# FIXME: make the correct tests pass on windows and remove this.
# FIXME: make the correct tests pass on windows and remove this.
for
img_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
):
dirname
=
os
.
path
.
dirname
(
img_path
)
dirname
=
os
.
path
.
dirname
(
img_path
)
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
write_folder
=
os
.
path
.
join
(
dirname
,
'jpeg_write'
)
write_folder
=
os
.
path
.
join
(
dirname
,
'jpeg_write'
)
expected_file
=
os
.
path
.
join
(
expected_file
=
os
.
path
.
join
(
write_folder
,
'{0}_pil.jpg'
.
format
(
filename
))
write_folder
,
'{0}_pil.jpg'
.
format
(
filename
))
img
=
decode_jpeg
(
read_file
(
img_path
))
img
=
decode_jpeg
(
read_file
(
img_path
))
with
open
(
expected_file
,
'rb'
)
as
f
:
with
open
(
expected_file
,
'rb'
)
as
f
:
pil_bytes
=
f
.
read
()
pil_bytes
=
f
.
read
()
pil_bytes
=
torch
.
as_tensor
(
list
(
pil_bytes
),
dtype
=
torch
.
uint8
)
pil_bytes
=
torch
.
as_tensor
(
list
(
pil_bytes
),
dtype
=
torch
.
uint8
)
for
src_img
in
[
img
,
img
.
contiguous
()]:
for
src_img
in
[
img
,
img
.
contiguous
()]:
# PIL sets jpeg quality to 75 by default
# PIL sets jpeg quality to 75 by default
jpeg_bytes
=
encode_jpeg
(
src_img
,
quality
=
75
)
jpeg_bytes
=
encode_jpeg
(
src_img
,
quality
=
75
)
assert_equal
(
jpeg_bytes
,
pil_bytes
)
assert_equal
(
jpeg_bytes
,
pil_bytes
)
@
cpu_only
@
cpu_only
@
_collect_if
(
cond
=
IS_WINDOWS
)
@
_collect_if
(
cond
=
IS_WINDOWS
)
def
test_write_jpeg_windows
():
@
pytest
.
mark
.
parametrize
(
'img_path'
,
[
pytest
.
param
(
jpeg_path
,
id
=
_get_safe_image_name
(
jpeg_path
))
for
jpeg_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
)
])
def
test_write_jpeg_windows
(
img_path
):
# FIXME: Remove this eventually, see test_encode_jpeg_windows
# FIXME: Remove this eventually, see test_encode_jpeg_windows
with
get_tmp_dir
()
as
d
:
with
get_tmp_dir
()
as
d
:
for
img_path
in
get_images
(
ENCODE_JPEG
,
".jpg"
):
data
=
read_file
(
img_path
)
data
=
read_file
(
img_path
)
img
=
decode_jpeg
(
data
)
img
=
decode_jpeg
(
data
)
basedir
=
os
.
path
.
dirname
(
img_path
)
basedir
=
os
.
path
.
dirname
(
img_path
)
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
filename
,
_
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
img_path
))
torch_jpeg
=
os
.
path
.
join
(
torch_jpeg
=
os
.
path
.
join
(
d
,
'{0}_torch.jpg'
.
format
(
filename
))
d
,
'{0}_torch.jpg'
.
format
(
filename
))
pil_jpeg
=
os
.
path
.
join
(
pil_jpeg
=
os
.
path
.
join
(
basedir
,
'jpeg_write'
,
'{0}_pil.jpg'
.
format
(
filename
))
basedir
,
'jpeg_write'
,
'{0}_pil.jpg'
.
format
(
filename
))
write_jpeg
(
img
,
torch_jpeg
,
quality
=
75
)
write_jpeg
(
img
,
torch_jpeg
,
quality
=
75
)
with
open
(
torch_jpeg
,
'rb'
)
as
f
:
with
open
(
torch_jpeg
,
'rb'
)
as
f
:
torch_bytes
=
f
.
read
()
torch_bytes
=
f
.
read
()
with
open
(
pil_jpeg
,
'rb'
)
as
f
:
with
open
(
pil_jpeg
,
'rb'
)
as
f
:
pil_bytes
=
f
.
read
()
pil_bytes
=
f
.
read
()
assert_equal
(
torch_bytes
,
pil_bytes
)
assert_equal
(
torch_bytes
,
pil_bytes
)
@
cpu_only
@
cpu_only
...
@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
...
@@ -408,5 +446,5 @@ def test_write_jpeg(img_path):
assert_equal
(
torch_bytes
,
pil_bytes
)
assert_equal
(
torch_bytes
,
pil_bytes
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
unit
test
.
main
()
py
test
.
main
(
[
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment