Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2925df7c
Unverified
Commit
2925df7c
authored
Apr 21, 2023
by
Riza Velioglu
Committed by
GitHub
Apr 21, 2023
Browse files
fix color in draw_segmentation_masks (#7520)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
4344da3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
31 deletions
+52
-31
test/test_utils.py
test/test_utils.py
+6
-3
torchvision/utils.py
torchvision/utils.py
+46
-28
No files found.
test/test_utils.py
View file @
2925df7c
...
@@ -120,6 +120,9 @@ def test_draw_boxes_colors(colors):
...
@@ -120,6 +120,9 @@ def test_draw_boxes_colors(colors):
img
=
torch
.
full
((
3
,
100
,
100
),
0
,
dtype
=
torch
.
uint8
)
img
=
torch
.
full
((
3
,
100
,
100
),
0
,
dtype
=
torch
.
uint8
)
utils
.
draw_bounding_boxes
(
img
,
boxes
,
fill
=
False
,
width
=
7
,
colors
=
colors
)
utils
.
draw_bounding_boxes
(
img
,
boxes
,
fill
=
False
,
width
=
7
,
colors
=
colors
)
with
pytest
.
raises
(
ValueError
,
match
=
"Number of colors must be equal or larger than the number of objects"
):
utils
.
draw_bounding_boxes
(
image
=
img
,
boxes
=
boxes
,
colors
=
[])
def
test_draw_boxes_vanilla
():
def
test_draw_boxes_vanilla
():
img
=
torch
.
full
((
3
,
100
,
100
),
0
,
dtype
=
torch
.
uint8
)
img
=
torch
.
full
((
3
,
100
,
100
),
0
,
dtype
=
torch
.
uint8
)
...
@@ -268,12 +271,12 @@ def test_draw_segmentation_masks_errors():
...
@@ -268,12 +271,12 @@ def test_draw_segmentation_masks_errors():
with
pytest
.
raises
(
ValueError
,
match
=
"must have the same height and width"
):
with
pytest
.
raises
(
ValueError
,
match
=
"must have the same height and width"
):
masks_bad_shape
=
torch
.
randint
(
0
,
2
,
size
=
(
h
+
4
,
w
),
dtype
=
torch
.
bool
)
masks_bad_shape
=
torch
.
randint
(
0
,
2
,
size
=
(
h
+
4
,
w
),
dtype
=
torch
.
bool
)
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks_bad_shape
)
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks_bad_shape
)
with
pytest
.
raises
(
ValueError
,
match
=
"
There are more mask
s"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Number of colors must be equal or larger than the number of object
s"
):
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks
,
colors
=
[])
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks
,
colors
=
[])
with
pytest
.
raises
(
ValueError
,
match
=
"colors must be a tuple or a string, or a list thereof"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
`
colors
`
must be a tuple or a string, or a list thereof"
):
bad_colors
=
np
.
array
([
"red"
,
"blue"
])
# should be a list
bad_colors
=
np
.
array
([
"red"
,
"blue"
])
# should be a list
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks
,
colors
=
bad_colors
)
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks
,
colors
=
bad_colors
)
with
pytest
.
raises
(
ValueError
,
match
=
"I
t seems that you
passed a tuple
of
colors
instead of
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"I
f
passed a
s
tuple
,
colors
should be an RGB triplet
"
):
bad_colors
=
(
"red"
,
"blue"
)
# should be a list
bad_colors
=
(
"red"
,
"blue"
)
# should be a list
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks
,
colors
=
bad_colors
)
utils
.
draw_segmentation_masks
(
image
=
img
,
masks
=
masks
,
colors
=
bad_colors
)
...
...
torchvision/utils.py
View file @
2925df7c
...
@@ -217,15 +217,7 @@ def draw_bounding_boxes(
...
@@ -217,15 +217,7 @@ def draw_bounding_boxes(
f
"Number of boxes (
{
num_boxes
}
) and labels (
{
len
(
labels
)
}
) mismatch. Please specify labels for each box."
f
"Number of boxes (
{
num_boxes
}
) and labels (
{
len
(
labels
)
}
) mismatch. Please specify labels for each box."
)
)
if
colors
is
None
:
colors
=
_parse_colors
(
colors
,
num_objects
=
num_boxes
)
colors
=
_generate_color_palette
(
num_boxes
)
elif
isinstance
(
colors
,
list
):
if
len
(
colors
)
<
num_boxes
:
raise
ValueError
(
f
"Number of colors (
{
len
(
colors
)
}
) is less than number of boxes (
{
num_boxes
}
). "
)
else
:
# colors specifies a single color for all boxes
colors
=
[
colors
]
*
num_boxes
colors
=
[(
ImageColor
.
getrgb
(
color
)
if
isinstance
(
color
,
str
)
else
color
)
for
color
in
colors
]
if
font
is
None
:
if
font
is
None
:
if
font_size
is
not
None
:
if
font_size
is
not
None
:
...
@@ -307,34 +299,17 @@ def draw_segmentation_masks(
...
@@ -307,34 +299,17 @@ def draw_segmentation_masks(
raise
ValueError
(
"The image and the masks must have the same height and width"
)
raise
ValueError
(
"The image and the masks must have the same height and width"
)
num_masks
=
masks
.
size
()[
0
]
num_masks
=
masks
.
size
()[
0
]
if
colors
is
not
None
and
num_masks
>
len
(
colors
):
raise
ValueError
(
f
"There are more masks (
{
num_masks
}
) than colors (
{
len
(
colors
)
}
)"
)
if
num_masks
==
0
:
if
num_masks
==
0
:
warnings
.
warn
(
"masks doesn't contain any mask. No mask was drawn"
)
warnings
.
warn
(
"masks doesn't contain any mask. No mask was drawn"
)
return
image
return
image
if
colors
is
None
:
colors
=
_generate_color_palette
(
num_masks
)
if
not
isinstance
(
colors
,
list
):
colors
=
[
colors
]
if
not
isinstance
(
colors
[
0
],
(
tuple
,
str
)):
raise
ValueError
(
"colors must be a tuple or a string, or a list thereof"
)
if
isinstance
(
colors
[
0
],
tuple
)
and
len
(
colors
[
0
])
!=
3
:
raise
ValueError
(
"It seems that you passed a tuple of colors instead of a list of colors"
)
out_dtype
=
torch
.
uint8
out_dtype
=
torch
.
uint8
colors
=
[
torch
.
tensor
(
color
,
dtype
=
out_dtype
)
for
color
in
_parse_colors
(
colors
,
num_objects
=
num_masks
)]
colors_
=
[]
for
color
in
colors
:
if
isinstance
(
color
,
str
):
color
=
ImageColor
.
getrgb
(
color
)
colors_
.
append
(
torch
.
tensor
(
color
,
dtype
=
out_dtype
))
img_to_draw
=
image
.
detach
().
clone
()
img_to_draw
=
image
.
detach
().
clone
()
# TODO: There might be a way to vectorize this
# TODO: There might be a way to vectorize this
for
mask
,
color
in
zip
(
masks
,
colors
_
):
for
mask
,
color
in
zip
(
masks
,
colors
):
img_to_draw
[:,
mask
]
=
color
[:,
None
]
img_to_draw
[:,
mask
]
=
color
[:,
None
]
out
=
image
*
(
1
-
alpha
)
+
img_to_draw
*
alpha
out
=
image
*
(
1
-
alpha
)
+
img_to_draw
*
alpha
...
@@ -535,6 +510,49 @@ def _generate_color_palette(num_objects: int):
...
@@ -535,6 +510,49 @@ def _generate_color_palette(num_objects: int):
return
[
tuple
((
i
*
palette
)
%
255
)
for
i
in
range
(
num_objects
)]
return
[
tuple
((
i
*
palette
)
%
255
)
for
i
in
range
(
num_objects
)]
def
_parse_colors
(
colors
:
Union
[
None
,
str
,
Tuple
[
int
,
int
,
int
],
List
[
Union
[
str
,
Tuple
[
int
,
int
,
int
]]]],
*
,
num_objects
:
int
,
)
->
List
[
Tuple
[
int
,
int
,
int
]]:
"""
Parses a specification of colors for a set of objects.
Args:
colors: A specification of colors for the objects. This can be one of the following:
- None: to generate a color palette automatically.
- A list of colors: where each color is either a string (specifying a named color) or an RGB tuple.
- A string or an RGB tuple: to use the same color for all objects.
If `colors` is a tuple, it should be a 3-tuple specifying the RGB values of the color.
If `colors` is a list, it should have at least as many elements as the number of objects to color.
num_objects (int): The number of objects to color.
Returns:
A list of 3-tuples, specifying the RGB values of the colors.
Raises:
ValueError: If the number of colors in the list is less than the number of objects to color.
If `colors` is not a list, tuple, string or None.
"""
if
colors
is
None
:
colors
=
_generate_color_palette
(
num_objects
)
elif
isinstance
(
colors
,
list
):
if
len
(
colors
)
<
num_objects
:
raise
ValueError
(
f
"Number of colors must be equal or larger than the number of objects, but got
{
len
(
colors
)
}
<
{
num_objects
}
."
)
elif
not
isinstance
(
colors
,
(
tuple
,
str
)):
raise
ValueError
(
"`colors` must be a tuple or a string, or a list thereof, but got {colors}."
)
elif
isinstance
(
colors
,
tuple
)
and
len
(
colors
)
!=
3
:
raise
ValueError
(
"If passed as tuple, colors should be an RGB triplet, but got {colors}."
)
else
:
# colors specifies a single color for all objects
colors
=
[
colors
]
*
num_objects
return
[
ImageColor
.
getrgb
(
color
)
if
isinstance
(
color
,
str
)
else
color
for
color
in
colors
]
def
_log_api_usage_once
(
obj
:
Any
)
->
None
:
def
_log_api_usage_once
(
obj
:
Any
)
->
None
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment