Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
195bb86e
Unverified
Commit
195bb86e
authored
May 21, 2021
by
Nicolas Hug
Committed by
GitHub
May 21, 2021
Browse files
Use torch.testing.assert_close in test_image.py (#3877)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
05e061f5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
test/test_image.py
test/test_image.py
+6
-5
No files found.
test/test_image.py
View file @
195bb86e
...
...
@@ -8,6 +8,7 @@ import numpy as np
import
torch
from
PIL
import
Image
from
common_utils
import
get_tmp_dir
,
needs_cuda
from
_assert_utils
import
assert_equal
from
torchvision.io.image
import
(
decode_png
,
decode_jpeg
,
encode_jpeg
,
write_jpeg
,
decode_image
,
read_file
,
...
...
@@ -107,7 +108,7 @@ class ImageTester(unittest.TestCase):
for
src_img
in
[
img
,
img
.
contiguous
()]:
# PIL sets jpeg quality to 75 by default
jpeg_bytes
=
encode_jpeg
(
src_img
,
quality
=
75
)
self
.
assert
True
(
jpeg_bytes
.
equal
(
pil_bytes
)
)
assert
_equal
(
jpeg_bytes
,
pil_bytes
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Input tensor dtype should be uint8"
):
...
...
@@ -191,7 +192,7 @@ class ImageTester(unittest.TestCase):
rec_img
=
torch
.
from_numpy
(
np
.
array
(
rec_img
))
rec_img
=
rec_img
.
permute
(
2
,
0
,
1
)
self
.
assert
True
(
img_pil
.
equal
(
rec_img
)
)
assert
_equal
(
img_pil
,
rec_img
)
with
self
.
assertRaisesRegex
(
RuntimeError
,
"Input tensor dtype should be uint8"
):
...
...
@@ -224,7 +225,7 @@ class ImageTester(unittest.TestCase):
saved_image
=
torch
.
from_numpy
(
np
.
array
(
Image
.
open
(
torch_png
)))
saved_image
=
saved_image
.
permute
(
2
,
0
,
1
)
self
.
assert
True
(
img_pil
.
equal
(
saved_image
)
)
assert
_equal
(
img_pil
,
saved_image
)
def
test_read_file
(
self
):
with
get_tmp_dir
()
as
d
:
...
...
@@ -235,7 +236,7 @@ class ImageTester(unittest.TestCase):
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
self
.
assert
True
(
data
.
equal
(
expected
)
)
assert
_equal
(
data
,
expected
)
os
.
unlink
(
fpath
)
with
self
.
assertRaisesRegex
(
...
...
@@ -251,7 +252,7 @@ class ImageTester(unittest.TestCase):
data
=
read_file
(
fpath
)
expected
=
torch
.
tensor
(
list
(
content
),
dtype
=
torch
.
uint8
)
self
.
assert
True
(
data
.
equal
(
expected
)
)
assert
_equal
(
data
,
expected
)
os
.
unlink
(
fpath
)
def
test_write_file
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment