Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
31ee79e6
Unverified
Commit
31ee79e6
authored
May 21, 2021
by
Nicolas Hug
Committed by
GitHub
May 21, 2021
Browse files
Use torch.testing.assert_close in test_utils (#3887)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
b2f188eb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
19 deletions
+17
-19
test/test_utils.py
test/test_utils.py
+17
-19
No files found.
test/test_utils.py
View file @
31ee79e6
...
@@ -9,6 +9,7 @@ import unittest
...
@@ -9,6 +9,7 @@ import unittest
from
io
import
BytesIO
from
io
import
BytesIO
import
torchvision.transforms.functional
as
F
import
torchvision.transforms.functional
as
F
from
PIL
import
Image
,
__version__
as
PILLOW_VERSION
,
ImageColor
from
PIL
import
Image
,
__version__
as
PILLOW_VERSION
,
ImageColor
from
_assert_utils
import
assert_equal
PILLOW_VERSION
=
tuple
(
int
(
x
)
for
x
in
PILLOW_VERSION
.
split
(
'.'
))
PILLOW_VERSION
=
tuple
(
int
(
x
)
for
x
in
PILLOW_VERSION
.
split
(
'.'
))
...
@@ -48,13 +49,13 @@ class Tester(unittest.TestCase):
...
@@ -48,13 +49,13 @@ class Tester(unittest.TestCase):
t_clone
=
t
.
clone
()
t_clone
=
t
.
clone
()
utils
.
make_grid
(
t
,
normalize
=
False
)
utils
.
make_grid
(
t
,
normalize
=
False
)
self
.
assert
True
(
torch
.
equal
(
t
,
t_clone
)
,
'make_grid modified tensor in-place'
)
assert
_
equal
(
t
,
t_clone
,
msg
=
'make_grid modified tensor in-place'
)
utils
.
make_grid
(
t
,
normalize
=
True
,
scale_each
=
False
)
utils
.
make_grid
(
t
,
normalize
=
True
,
scale_each
=
False
)
self
.
assert
True
(
torch
.
equal
(
t
,
t_clone
)
,
'make_grid modified tensor in-place'
)
assert
_
equal
(
t
,
t_clone
,
msg
=
'make_grid modified tensor in-place'
)
utils
.
make_grid
(
t
,
normalize
=
True
,
scale_each
=
True
)
utils
.
make_grid
(
t
,
normalize
=
True
,
scale_each
=
True
)
self
.
assert
True
(
torch
.
equal
(
t
,
t_clone
)
,
'make_grid modified tensor in-place'
)
assert
_
equal
(
t
,
t_clone
,
msg
=
'make_grid modified tensor in-place'
)
def
test_normalize_in_make_grid
(
self
):
def
test_normalize_in_make_grid
(
self
):
t
=
torch
.
rand
(
5
,
3
,
10
,
10
)
*
255
t
=
torch
.
rand
(
5
,
3
,
10
,
10
)
*
255
...
@@ -70,8 +71,8 @@ class Tester(unittest.TestCase):
...
@@ -70,8 +71,8 @@ class Tester(unittest.TestCase):
rounded_grid_max
=
torch
.
round
(
grid_max
*
10
**
n_digits
)
/
(
10
**
n_digits
)
rounded_grid_max
=
torch
.
round
(
grid_max
*
10
**
n_digits
)
/
(
10
**
n_digits
)
rounded_grid_min
=
torch
.
round
(
grid_min
*
10
**
n_digits
)
/
(
10
**
n_digits
)
rounded_grid_min
=
torch
.
round
(
grid_min
*
10
**
n_digits
)
/
(
10
**
n_digits
)
self
.
assert
True
(
torch
.
equal
(
norm_max
,
rounded_grid_max
)
,
'Normalized max is not equal to 1'
)
assert
_
equal
(
norm_max
,
rounded_grid_max
,
msg
=
'Normalized max is not equal to 1'
)
self
.
assert
True
(
torch
.
equal
(
norm_min
,
rounded_grid_min
)
,
'Normalized min is not equal to 0'
)
assert
_
equal
(
norm_min
,
rounded_grid_min
,
msg
=
'Normalized min is not equal to 0'
)
@
unittest
.
skipIf
(
sys
.
platform
in
(
'win32'
,
'cygwin'
),
'temporarily disabled on Windows'
)
@
unittest
.
skipIf
(
sys
.
platform
in
(
'win32'
,
'cygwin'
),
'temporarily disabled on Windows'
)
def
test_save_image
(
self
):
def
test_save_image
(
self
):
...
@@ -96,8 +97,7 @@ class Tester(unittest.TestCase):
...
@@ -96,8 +97,7 @@ class Tester(unittest.TestCase):
fp
=
BytesIO
()
fp
=
BytesIO
()
utils
.
save_image
(
t
,
fp
,
format
=
'png'
)
utils
.
save_image
(
t
,
fp
,
format
=
'png'
)
img_bytes
=
Image
.
open
(
fp
)
img_bytes
=
Image
.
open
(
fp
)
self
.
assertTrue
(
torch
.
equal
(
F
.
to_tensor
(
img_orig
),
F
.
to_tensor
(
img_bytes
)),
assert_equal
(
F
.
to_tensor
(
img_orig
),
F
.
to_tensor
(
img_bytes
),
msg
=
'Image not stored in file object'
)
'Image not stored in file object'
)
@
unittest
.
skipIf
(
sys
.
platform
in
(
'win32'
,
'cygwin'
),
'temporarily disabled on Windows'
)
@
unittest
.
skipIf
(
sys
.
platform
in
(
'win32'
,
'cygwin'
),
'temporarily disabled on Windows'
)
def
test_save_image_single_pixel_file_object
(
self
):
def
test_save_image_single_pixel_file_object
(
self
):
...
@@ -108,8 +108,7 @@ class Tester(unittest.TestCase):
...
@@ -108,8 +108,7 @@ class Tester(unittest.TestCase):
fp
=
BytesIO
()
fp
=
BytesIO
()
utils
.
save_image
(
t
,
fp
,
format
=
'png'
)
utils
.
save_image
(
t
,
fp
,
format
=
'png'
)
img_bytes
=
Image
.
open
(
fp
)
img_bytes
=
Image
.
open
(
fp
)
self
.
assertTrue
(
torch
.
equal
(
F
.
to_tensor
(
img_orig
),
F
.
to_tensor
(
img_bytes
)),
assert_equal
(
F
.
to_tensor
(
img_orig
),
F
.
to_tensor
(
img_bytes
),
msg
=
'Image not stored in file object'
)
'Pixel Image not stored in file object'
)
def
test_draw_boxes
(
self
):
def
test_draw_boxes
(
self
):
img
=
torch
.
full
((
3
,
100
,
100
),
255
,
dtype
=
torch
.
uint8
)
img
=
torch
.
full
((
3
,
100
,
100
),
255
,
dtype
=
torch
.
uint8
)
...
@@ -127,11 +126,11 @@ class Tester(unittest.TestCase):
...
@@ -127,11 +126,11 @@ class Tester(unittest.TestCase):
if
PILLOW_VERSION
>=
(
8
,
2
):
if
PILLOW_VERSION
>=
(
8
,
2
):
# The reference image is only valid for new PIL versions
# The reference image is only valid for new PIL versions
expected
=
torch
.
as_tensor
(
np
.
array
(
Image
.
open
(
path
))).
permute
(
2
,
0
,
1
)
expected
=
torch
.
as_tensor
(
np
.
array
(
Image
.
open
(
path
))).
permute
(
2
,
0
,
1
)
self
.
assert
True
(
torch
.
equal
(
result
,
expected
)
)
assert
_
equal
(
result
,
expected
)
# Check if modification is not in place
# Check if modification is not in place
self
.
assert
True
(
torch
.
all
(
torch
.
eq
(
boxes
,
boxes_cp
)
).
item
())
assert
_equal
(
boxes
,
boxes_cp
)
self
.
assert
True
(
torch
.
all
(
torch
.
eq
(
img
,
img_cp
)
).
item
())
assert
_equal
(
img
,
img_cp
)
def
test_draw_boxes_vanilla
(
self
):
def
test_draw_boxes_vanilla
(
self
):
img
=
torch
.
full
((
3
,
100
,
100
),
0
,
dtype
=
torch
.
uint8
)
img
=
torch
.
full
((
3
,
100
,
100
),
0
,
dtype
=
torch
.
uint8
)
...
@@ -145,10 +144,10 @@ class Tester(unittest.TestCase):
...
@@ -145,10 +144,10 @@ class Tester(unittest.TestCase):
res
.
save
(
path
)
res
.
save
(
path
)
expected
=
torch
.
as_tensor
(
np
.
array
(
Image
.
open
(
path
))).
permute
(
2
,
0
,
1
)
expected
=
torch
.
as_tensor
(
np
.
array
(
Image
.
open
(
path
))).
permute
(
2
,
0
,
1
)
self
.
assert
True
(
torch
.
equal
(
result
,
expected
)
)
assert
_
equal
(
result
,
expected
)
# Check if modification is not in place
# Check if modification is not in place
self
.
assert
True
(
torch
.
all
(
torch
.
eq
(
boxes
,
boxes_cp
)
).
item
())
assert
_equal
(
boxes
,
boxes_cp
)
self
.
assert
True
(
torch
.
all
(
torch
.
eq
(
img
,
img_cp
)
).
item
())
assert
_equal
(
img
,
img_cp
)
def
test_draw_invalid_boxes
(
self
):
def
test_draw_invalid_boxes
(
self
):
img_tp
=
((
1
,
1
,
1
),
(
1
,
2
,
3
))
img_tp
=
((
1
,
1
,
1
),
(
1
,
2
,
3
))
...
@@ -187,7 +186,7 @@ def test_draw_segmentation_masks(colors, alpha):
...
@@ -187,7 +186,7 @@ def test_draw_segmentation_masks(colors, alpha):
# Make sure the image didn't change where there's no mask
# Make sure the image didn't change where there's no mask
masked_pixels
=
masks
[
0
]
|
masks
[
1
]
masked_pixels
=
masks
[
0
]
|
masks
[
1
]
assert
(
img
[:,
~
masked_pixels
]
==
out
[:,
~
masked_pixels
])
.
all
()
assert
_equal
(
img
[:,
~
masked_pixels
]
,
out
[:,
~
masked_pixels
])
if
colors
is
None
:
if
colors
is
None
:
colors
=
utils
.
_generate_color_palette
(
num_masks
)
colors
=
utils
.
_generate_color_palette
(
num_masks
)
...
@@ -203,9 +202,8 @@ def test_draw_segmentation_masks(colors, alpha):
...
@@ -203,9 +202,8 @@ def test_draw_segmentation_masks(colors, alpha):
elif
alpha
==
0
:
elif
alpha
==
0
:
assert
(
out
[:,
mask
]
==
img
[:,
mask
]).
all
()
assert
(
out
[:,
mask
]
==
img
[:,
mask
]).
all
()
interpolated_color
=
(
img
[:,
mask
]
*
(
1
-
alpha
)
+
color
[:,
None
]
*
alpha
)
interpolated_color
=
(
img
[:,
mask
]
*
(
1
-
alpha
)
+
color
[:,
None
]
*
alpha
).
to
(
dtype
)
max_diff
=
(
out
[:,
mask
]
-
interpolated_color
).
abs
().
max
()
torch
.
testing
.
assert_close
(
out
[:,
mask
],
interpolated_color
,
rtol
=
0.0
,
atol
=
1.0
)
assert
max_diff
<=
1
def
test_draw_segmentation_masks_errors
():
def
test_draw_segmentation_masks_errors
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment