Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6c2e0ae8
Unverified
Commit
6c2e0ae8
authored
Dec 18, 2023
by
Mithra
Committed by
GitHub
Dec 18, 2023
Browse files
support of float dtypes for draw_segmentation_masks (#8150)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
c35d3855
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
9 deletions
+36
-9
test/test_utils.py
test/test_utils.py
+21
-0
torchvision/utils.py
torchvision/utils.py
+15
-9
No files found.
test/test_utils.py
View file @
6c2e0ae8
...
@@ -11,6 +11,7 @@ import torchvision.transforms.functional as F
...
@@ -11,6 +11,7 @@ import torchvision.transforms.functional as F
import
torchvision.utils
as
utils
import
torchvision.utils
as
utils
from
common_utils
import
assert_equal
,
cpu_and_cuda
from
common_utils
import
assert_equal
,
cpu_and_cuda
from
PIL
import
__version__
as
PILLOW_VERSION
,
Image
,
ImageColor
from
PIL
import
__version__
as
PILLOW_VERSION
,
Image
,
ImageColor
from
torchvision.transforms.v2.functional
import
to_dtype
PILLOW_VERSION
=
tuple
(
int
(
x
)
for
x
in
PILLOW_VERSION
.
split
(
"."
))
PILLOW_VERSION
=
tuple
(
int
(
x
)
for
x
in
PILLOW_VERSION
.
split
(
"."
))
...
@@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device):
...
@@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device):
torch
.
testing
.
assert_close
(
out
[:,
mask
],
interpolated_color
,
rtol
=
0.0
,
atol
=
1.0
)
torch
.
testing
.
assert_close
(
out
[:,
mask
],
interpolated_color
,
rtol
=
0.0
,
atol
=
1.0
)
def
test_draw_segmentation_masks_dtypes
():
num_masks
,
h
,
w
=
2
,
100
,
100
masks
=
torch
.
randint
(
0
,
2
,
(
num_masks
,
h
,
w
),
dtype
=
torch
.
bool
)
img_uint8
=
torch
.
randint
(
0
,
256
,
size
=
(
3
,
h
,
w
),
dtype
=
torch
.
uint8
)
out_uint8
=
utils
.
draw_segmentation_masks
(
img_uint8
,
masks
)
assert
img_uint8
is
not
out_uint8
assert
out_uint8
.
dtype
==
torch
.
uint8
img_float
=
to_dtype
(
img_uint8
,
torch
.
float
,
scale
=
True
)
out_float
=
utils
.
draw_segmentation_masks
(
img_float
,
masks
)
assert
img_float
is
not
out_float
assert
out_float
.
is_floating_point
()
torch
.
testing
.
assert_close
(
out_uint8
,
to_dtype
(
out_float
,
torch
.
uint8
,
scale
=
True
),
rtol
=
0
,
atol
=
1
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_draw_segmentation_masks_errors
(
device
):
def
test_draw_segmentation_masks_errors
(
device
):
h
,
w
=
10
,
10
h
,
w
=
10
,
10
...
...
torchvision/utils.py
View file @
6c2e0ae8
...
@@ -10,6 +10,7 @@ import numpy as np
...
@@ -10,6 +10,7 @@ import numpy as np
import
torch
import
torch
from
PIL
import
Image
,
ImageColor
,
ImageDraw
,
ImageFont
from
PIL
import
Image
,
ImageColor
,
ImageDraw
,
ImageFont
__all__
=
[
__all__
=
[
"make_grid"
,
"make_grid"
,
"save_image"
,
"save_image"
,
...
@@ -262,10 +263,10 @@ def draw_segmentation_masks(
...
@@ -262,10 +263,10 @@ def draw_segmentation_masks(
"""
"""
Draws segmentation masks on given RGB image.
Draws segmentation masks on given RGB image.
The values
of the input image should be uint8 between 0 and 255
.
The
image
values
should be uint8 in [0, 255] or float in [0, 1]
.
Args:
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8
or float
.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency.
0 means full transparency, 1 means no transparency.
...
@@ -282,8 +283,8 @@ def draw_segmentation_masks(
...
@@ -282,8 +283,8 @@ def draw_segmentation_masks(
_log_api_usage_once
(
draw_segmentation_masks
)
_log_api_usage_once
(
draw_segmentation_masks
)
if
not
isinstance
(
image
,
torch
.
Tensor
):
if
not
isinstance
(
image
,
torch
.
Tensor
):
raise
TypeError
(
f
"The image must be a tensor, got
{
type
(
image
)
}
"
)
raise
TypeError
(
f
"The image must be a tensor, got
{
type
(
image
)
}
"
)
elif
image
.
dtype
!
=
torch
.
uint8
:
elif
not
(
image
.
dtype
=
=
torch
.
uint8
or
image
.
is_floating_point
())
:
raise
ValueError
(
f
"The image dtype must be uint8, got
{
image
.
dtype
}
"
)
raise
ValueError
(
f
"The image dtype must be uint8
or float
, got
{
image
.
dtype
}
"
)
elif
image
.
dim
()
!=
3
:
elif
image
.
dim
()
!=
3
:
raise
ValueError
(
"Pass individual images, not batches"
)
raise
ValueError
(
"Pass individual images, not batches"
)
elif
image
.
size
()[
0
]
!=
3
:
elif
image
.
size
()[
0
]
!=
3
:
...
@@ -303,10 +304,10 @@ def draw_segmentation_masks(
...
@@ -303,10 +304,10 @@ def draw_segmentation_masks(
warnings
.
warn
(
"masks doesn't contain any mask. No mask was drawn"
)
warnings
.
warn
(
"masks doesn't contain any mask. No mask was drawn"
)
return
image
return
image
o
ut
_dtype
=
torch
.
uint8
o
riginal
_dtype
=
image
.
dtype
colors
=
[
colors
=
[
torch
.
tensor
(
color
,
dtype
=
o
ut
_dtype
,
device
=
image
.
device
)
torch
.
tensor
(
color
,
dtype
=
o
riginal
_dtype
,
device
=
image
.
device
)
for
color
in
_parse_colors
(
colors
,
num_objects
=
num_masks
)
for
color
in
_parse_colors
(
colors
,
num_objects
=
num_masks
,
dtype
=
original_dtype
)
]
]
img_to_draw
=
image
.
detach
().
clone
()
img_to_draw
=
image
.
detach
().
clone
()
...
@@ -315,7 +316,8 @@ def draw_segmentation_masks(
...
@@ -315,7 +316,8 @@ def draw_segmentation_masks(
img_to_draw
[:,
mask
]
=
color
[:,
None
]
img_to_draw
[:,
mask
]
=
color
[:,
None
]
out
=
image
*
(
1
-
alpha
)
+
img_to_draw
*
alpha
out
=
image
*
(
1
-
alpha
)
+
img_to_draw
*
alpha
return
out
.
to
(
out_dtype
)
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return
out
.
to
(
original_dtype
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -516,6 +518,7 @@ def _parse_colors(
...
@@ -516,6 +518,7 @@ def _parse_colors(
colors
:
Union
[
None
,
str
,
Tuple
[
int
,
int
,
int
],
List
[
Union
[
str
,
Tuple
[
int
,
int
,
int
]]]],
colors
:
Union
[
None
,
str
,
Tuple
[
int
,
int
,
int
],
List
[
Union
[
str
,
Tuple
[
int
,
int
,
int
]]]],
*
,
*
,
num_objects
:
int
,
num_objects
:
int
,
dtype
:
torch
.
dtype
=
torch
.
uint8
,
)
->
List
[
Tuple
[
int
,
int
,
int
]]:
)
->
List
[
Tuple
[
int
,
int
,
int
]]:
"""
"""
Parses a specification of colors for a set of objects.
Parses a specification of colors for a set of objects.
...
@@ -552,7 +555,10 @@ def _parse_colors(
...
@@ -552,7 +555,10 @@ def _parse_colors(
else
:
# colors specifies a single color for all objects
else
:
# colors specifies a single color for all objects
colors
=
[
colors
]
*
num_objects
colors
=
[
colors
]
*
num_objects
return
[
ImageColor
.
getrgb
(
color
)
if
isinstance
(
color
,
str
)
else
color
for
color
in
colors
]
colors
=
[
ImageColor
.
getrgb
(
color
)
if
isinstance
(
color
,
str
)
else
color
for
color
in
colors
]
if
dtype
.
is_floating_point
:
# [0, 255] -> [0, 1]
colors
=
[
tuple
(
v
/
255
for
v
in
color
)
for
color
in
colors
]
return
colors
def
_log_api_usage_once
(
obj
:
Any
)
->
None
:
def
_log_api_usage_once
(
obj
:
Any
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment