Unverified Commit 2925df7c authored by Riza Velioglu's avatar Riza Velioglu Committed by GitHub
Browse files

fix color in draw_segmentation_masks (#7520)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 4344da3d
...@@ -120,6 +120,9 @@ def test_draw_boxes_colors(colors): ...@@ -120,6 +120,9 @@ def test_draw_boxes_colors(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors) utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors)
with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
utils.draw_bounding_boxes(image=img, boxes=boxes, colors=[])
def test_draw_boxes_vanilla(): def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
...@@ -268,12 +271,12 @@ def test_draw_segmentation_masks_errors(): ...@@ -268,12 +271,12 @@ def test_draw_segmentation_masks_errors():
with pytest.raises(ValueError, match="must have the same height and width"): with pytest.raises(ValueError, match="must have the same height and width"):
masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool) masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
utils.draw_segmentation_masks(image=img, masks=masks_bad_shape) utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
with pytest.raises(ValueError, match="There are more masks"): with pytest.raises(ValueError, match="Number of colors must be equal or larger than the number of objects"):
utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): with pytest.raises(ValueError, match="`colors` must be a tuple or a string, or a list thereof"):
bad_colors = np.array(["red", "blue"]) # should be a list bad_colors = np.array(["red", "blue"]) # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): with pytest.raises(ValueError, match="If passed as tuple, colors should be an RGB triplet"):
bad_colors = ("red", "blue") # should be a list bad_colors = ("red", "blue") # should be a list
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
......
...@@ -217,15 +217,7 @@ def draw_bounding_boxes( ...@@ -217,15 +217,7 @@ def draw_bounding_boxes(
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
) )
if colors is None: colors = _parse_colors(colors, num_objects=num_boxes)
colors = _generate_color_palette(num_boxes)
elif isinstance(colors, list):
if len(colors) < num_boxes:
raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
else: # colors specifies a single color for all boxes
colors = [colors] * num_boxes
colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
if font is None: if font is None:
if font_size is not None: if font_size is not None:
...@@ -307,34 +299,17 @@ def draw_segmentation_masks( ...@@ -307,34 +299,17 @@ def draw_segmentation_masks(
raise ValueError("The image and the masks must have the same height and width") raise ValueError("The image and the masks must have the same height and width")
num_masks = masks.size()[0] num_masks = masks.size()[0]
if colors is not None and num_masks > len(colors):
raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
if num_masks == 0: if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn") warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image return image
if colors is None:
colors = _generate_color_palette(num_masks)
if not isinstance(colors, list):
colors = [colors]
if not isinstance(colors[0], (tuple, str)):
raise ValueError("colors must be a tuple or a string, or a list thereof")
if isinstance(colors[0], tuple) and len(colors[0]) != 3:
raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
out_dtype = torch.uint8 out_dtype = torch.uint8
colors = [torch.tensor(color, dtype=out_dtype) for color in _parse_colors(colors, num_objects=num_masks)]
colors_ = []
for color in colors:
if isinstance(color, str):
color = ImageColor.getrgb(color)
colors_.append(torch.tensor(color, dtype=out_dtype))
img_to_draw = image.detach().clone() img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this # TODO: There might be a way to vectorize this
for mask, color in zip(masks, colors_): for mask, color in zip(masks, colors):
img_to_draw[:, mask] = color[:, None] img_to_draw[:, mask] = color[:, None]
out = image * (1 - alpha) + img_to_draw * alpha out = image * (1 - alpha) + img_to_draw * alpha
...@@ -535,6 +510,49 @@ def _generate_color_palette(num_objects: int): ...@@ -535,6 +510,49 @@ def _generate_color_palette(num_objects: int):
return [tuple((i * palette) % 255) for i in range(num_objects)] return [tuple((i * palette) % 255) for i in range(num_objects)]
def _parse_colors(
colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]],
*,
num_objects: int,
) -> List[Tuple[int, int, int]]:
"""
Parses a specification of colors for a set of objects.
Args:
colors: A specification of colors for the objects. This can be one of the following:
- None: to generate a color palette automatically.
- A list of colors: where each color is either a string (specifying a named color) or an RGB tuple.
- A string or an RGB tuple: to use the same color for all objects.
If `colors` is a tuple, it should be a 3-tuple specifying the RGB values of the color.
If `colors` is a list, it should have at least as many elements as the number of objects to color.
num_objects (int): The number of objects to color.
Returns:
A list of 3-tuples, specifying the RGB values of the colors.
Raises:
ValueError: If the number of colors in the list is less than the number of objects to color.
If `colors` is not a list, tuple, string or None.
"""
if colors is None:
colors = _generate_color_palette(num_objects)
elif isinstance(colors, list):
if len(colors) < num_objects:
raise ValueError(
f"Number of colors must be equal or larger than the number of objects, but got {len(colors)} < {num_objects}."
)
elif not isinstance(colors, (tuple, str)):
raise ValueError("`colors` must be a tuple or a string, or a list thereof, but got {colors}.")
elif isinstance(colors, tuple) and len(colors) != 3:
raise ValueError("If passed as tuple, colors should be an RGB triplet, but got {colors}.")
else: # colors specifies a single color for all objects
colors = [colors] * num_objects
return [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
def _log_api_usage_once(obj: Any) -> None: def _log_api_usage_once(obj: Any) -> None:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment