Unverified commit 3f33eeb1, authored by Aditya Oke and committed by GitHub
Browse files

Support random colors by default for draw_bounding_boxes (#5127)



* Add random colors

* Update error message, prettify the code

* Update edge cases

* Change implementation to tuples

* Fix bugs

* Add tests

* Reuse palette

* small rename fix

* Update tests and code

* Simplify code

* ufmt

* fixed colors -> random colors in docstring

* Actually simplify further

* Silence mypy. Twice. lol.
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent 435eddf7
......@@ -124,7 +124,7 @@ def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7)
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path):
......@@ -149,7 +149,11 @@ def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
labels_wrong = ["one", "two"]
colors_wrong = ["pink", "blue"]
with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"):
......@@ -158,6 +162,10 @@ def test_draw_invalid_boxes():
utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
with pytest.raises(ValueError, match="Number of boxes"):
utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
with pytest.raises(ValueError, match="Number of colors"):
utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
@pytest.mark.parametrize(
......
......@@ -176,6 +176,7 @@ def draw_bounding_boxes(
colors (color or list of colors, optional): List containing the colors
of the boxes or single color for all boxes. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for boxes.
fill (bool): If `True` fills the bounding box with specified color.
width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
......@@ -198,45 +199,50 @@ def draw_bounding_boxes(
elif image.size(0) not in {1, 3}:
raise ValueError("Only grayscale and RGB images are supported")
num_boxes = boxes.shape[0]
if labels is None:
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
elif len(labels) != num_boxes:
raise ValueError(
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
)
if colors is None:
colors = _generate_color_palette(num_boxes)
elif isinstance(colors, list):
if len(colors) < num_boxes:
raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
else: # colors specifies a single color for all boxes
colors = [colors] * num_boxes
colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
# Handle Grayscale images
if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1))
ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
img_boxes = boxes.to(torch.int64).tolist()
if fill:
draw = ImageDraw.Draw(img_to_draw, "RGBA")
else:
draw = ImageDraw.Draw(img_to_draw)
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
for i, bbox in enumerate(img_boxes):
if colors is None:
color = None
elif isinstance(colors, list):
color = colors[i]
else:
color = colors
for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
if fill:
if color is None:
fill_color = (255, 255, 255, 100)
elif isinstance(color, str):
# This will automatically raise Error if rgb cannot be parsed.
fill_color = ImageColor.getrgb(color) + (100,)
elif isinstance(color, tuple):
fill_color = color + (100,)
fill_color = color + (100,)
draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
else:
draw.rectangle(bbox, width=width, outline=color)
if labels is not None:
if label is not None:
margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font)
draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
......@@ -505,9 +511,9 @@ def _make_colorwheel() -> torch.Tensor:
return colorwheel
def _generate_color_palette(num_masks: int):
def _generate_color_palette(num_objects: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)]
return [tuple((i * palette) % 255) for i in range(num_objects)]
def _log_api_usage_once(obj: Any) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment