Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6c2e0ae8
Unverified
Commit
6c2e0ae8
authored
Dec 18, 2023
by
Mithra
Committed by
GitHub
Dec 18, 2023
Browse files
support of float dtypes for draw_segmentation_masks (#8150)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
c35d3855
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
9 deletions
+36
-9
test/test_utils.py
test/test_utils.py
+21
-0
torchvision/utils.py
torchvision/utils.py
+15
-9
No files found.
test/test_utils.py
View file @
6c2e0ae8
...
...
@@ -11,6 +11,7 @@ import torchvision.transforms.functional as F
import
torchvision.utils
as
utils
from
common_utils
import
assert_equal
,
cpu_and_cuda
from
PIL
import
__version__
as
PILLOW_VERSION
,
Image
,
ImageColor
from
torchvision.transforms.v2.functional
import
to_dtype
PILLOW_VERSION
=
tuple
(
int
(
x
)
for
x
in
PILLOW_VERSION
.
split
(
"."
))
...
...
@@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device):
torch
.
testing
.
assert_close
(
out
[:,
mask
],
interpolated_color
,
rtol
=
0.0
,
atol
=
1.0
)
def
test_draw_segmentation_masks_dtypes
():
num_masks
,
h
,
w
=
2
,
100
,
100
masks
=
torch
.
randint
(
0
,
2
,
(
num_masks
,
h
,
w
),
dtype
=
torch
.
bool
)
img_uint8
=
torch
.
randint
(
0
,
256
,
size
=
(
3
,
h
,
w
),
dtype
=
torch
.
uint8
)
out_uint8
=
utils
.
draw_segmentation_masks
(
img_uint8
,
masks
)
assert
img_uint8
is
not
out_uint8
assert
out_uint8
.
dtype
==
torch
.
uint8
img_float
=
to_dtype
(
img_uint8
,
torch
.
float
,
scale
=
True
)
out_float
=
utils
.
draw_segmentation_masks
(
img_float
,
masks
)
assert
img_float
is
not
out_float
assert
out_float
.
is_floating_point
()
torch
.
testing
.
assert_close
(
out_uint8
,
to_dtype
(
out_float
,
torch
.
uint8
,
scale
=
True
),
rtol
=
0
,
atol
=
1
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_draw_segmentation_masks_errors
(
device
):
h
,
w
=
10
,
10
...
...
torchvision/utils.py
View file @
6c2e0ae8
...
...
@@ -10,6 +10,7 @@ import numpy as np
import
torch
from
PIL
import
Image
,
ImageColor
,
ImageDraw
,
ImageFont
__all__
=
[
"make_grid"
,
"save_image"
,
...
...
@@ -262,10 +263,10 @@ def draw_segmentation_masks(
"""
Draws segmentation masks on given RGB image.
The values
of the input image should be uint8 between 0 and 255
.
The
image
values
should be uint8 in [0, 255] or float in [0, 1]
.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8
or float
.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
...
...
@@ -282,8 +283,8 @@ def draw_segmentation_masks(
_log_api_usage_once
(
draw_segmentation_masks
)
if
not
isinstance
(
image
,
torch
.
Tensor
):
raise
TypeError
(
f
"The image must be a tensor, got
{
type
(
image
)
}
"
)
elif
image
.
dtype
!
=
torch
.
uint8
:
raise
ValueError
(
f
"The image dtype must be uint8, got
{
image
.
dtype
}
"
)
elif
not
(
image
.
dtype
=
=
torch
.
uint8
or
image
.
is_floating_point
())
:
raise
ValueError
(
f
"The image dtype must be uint8
or float
, got
{
image
.
dtype
}
"
)
elif
image
.
dim
()
!=
3
:
raise
ValueError
(
"Pass individual images, not batches"
)
elif
image
.
size
()[
0
]
!=
3
:
...
...
@@ -303,10 +304,10 @@ def draw_segmentation_masks(
warnings
.
warn
(
"masks doesn't contain any mask. No mask was drawn"
)
return
image
o
ut
_dtype
=
torch
.
uint8
o
riginal
_dtype
=
image
.
dtype
colors
=
[
torch
.
tensor
(
color
,
dtype
=
o
ut
_dtype
,
device
=
image
.
device
)
for
color
in
_parse_colors
(
colors
,
num_objects
=
num_masks
)
torch
.
tensor
(
color
,
dtype
=
o
riginal
_dtype
,
device
=
image
.
device
)
for
color
in
_parse_colors
(
colors
,
num_objects
=
num_masks
,
dtype
=
original_dtype
)
]
img_to_draw
=
image
.
detach
().
clone
()
...
...
@@ -315,7 +316,8 @@ def draw_segmentation_masks(
img_to_draw
[:,
mask
]
=
color
[:,
None
]
out
=
image
*
(
1
-
alpha
)
+
img_to_draw
*
alpha
return
out
.
to
(
out_dtype
)
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return
out
.
to
(
original_dtype
)
@
torch
.
no_grad
()
...
...
@@ -516,6 +518,7 @@ def _parse_colors(
colors
:
Union
[
None
,
str
,
Tuple
[
int
,
int
,
int
],
List
[
Union
[
str
,
Tuple
[
int
,
int
,
int
]]]],
*
,
num_objects
:
int
,
dtype
:
torch
.
dtype
=
torch
.
uint8
,
)
->
List
[
Tuple
[
int
,
int
,
int
]]:
"""
Parses a specification of colors for a set of objects.
...
...
@@ -552,7 +555,10 @@ def _parse_colors(
else
:
# colors specifies a single color for all objects
colors
=
[
colors
]
*
num_objects
return
[
ImageColor
.
getrgb
(
color
)
if
isinstance
(
color
,
str
)
else
color
for
color
in
colors
]
colors
=
[
ImageColor
.
getrgb
(
color
)
if
isinstance
(
color
,
str
)
else
color
for
color
in
colors
]
if
dtype
.
is_floating_point
:
# [0, 255] -> [0, 1]
colors
=
[
tuple
(
v
/
255
for
v
in
color
)
for
color
in
colors
]
return
colors
def
_log_api_usage_once
(
obj
:
Any
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment