Unverified Commit 5486b768 authored by oxabz's avatar oxabz Committed by GitHub
Browse files

Throw warning for empty masks or box tensors on draw_segmentation_masks and...


Throw warning for empty masks or box tensors on draw_segmentation_masks and draw_bounding_boxes (#5857)

* Fixing the IndexError in draw_segmentation_masks

* fixing the bug on draw_bounding_boxes

* Changing fstring to normal string

* Removing unecessary conversion

* Adding test for the change

* Adding a test for draw seqmentation mask

* Fixing small mistake

* Fixing an error in the tests

* removing useless imports

* ufmt
Co-authored-by: default avatarLEGRAND Matthieu <legrand.ma@chu-toulouse.fr>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent b969cca7
...@@ -176,6 +176,15 @@ def test_draw_boxes_warning(): ...@@ -176,6 +176,15 @@ def test_draw_boxes_warning():
utils.draw_bounding_boxes(img, boxes, font_size=11) utils.draw_bounding_boxes(img, boxes, font_size=11)
def test_draw_no_boxes():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
boxes = torch.full((0, 4), 0, dtype=torch.float)
with pytest.warns(UserWarning, match=re.escape("boxes doesn't contain any box. No box was drawn")):
res = utils.draw_bounding_boxes(img, boxes)
# Check that the function didnt change the image
assert res.eq(img).all()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"colors", "colors",
[ [
...@@ -266,6 +275,15 @@ def test_draw_segmentation_masks_errors(): ...@@ -266,6 +275,15 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
def test_draw_no_segmention_mask():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
masks = torch.full((0, 100, 100), 0, dtype=torch.bool)
with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")):
res = utils.draw_segmentation_masks(img, masks)
# Check that the function didnt change the image
assert res.eq(img).all()
def test_draw_keypoints_vanilla(): def test_draw_keypoints_vanilla():
# Keypoints is declared on top as global variable # Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone() keypoints_cp = keypoints.clone()
......
...@@ -211,6 +211,10 @@ def draw_bounding_boxes( ...@@ -211,6 +211,10 @@ def draw_bounding_boxes(
num_boxes = boxes.shape[0] num_boxes = boxes.shape[0]
if num_boxes == 0:
warnings.warn("boxes doesn't contain any box. No box was drawn")
return image
if labels is None: if labels is None:
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
elif len(labels) != num_boxes: elif len(labels) != num_boxes:
...@@ -311,6 +315,10 @@ def draw_segmentation_masks( ...@@ -311,6 +315,10 @@ def draw_segmentation_masks(
if colors is not None and num_masks > len(colors): if colors is not None and num_masks > len(colors):
raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image
if colors is None: if colors is None:
colors = _generate_color_palette(num_masks) colors = _generate_color_palette(num_masks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment