Unverified commit 3f33eeb1, authored by Aditya Oke and committed by GitHub
Browse files

Support random colors by default for draw_bounding_boxes (#5127)



* Add random colors

* Update error message, pretty the code

* Update edge cases

* Change implementation to tuples

* Fix bugs

* Add tests

* Reuse palette

* small rename fix

* Update tests and code

* Simplify code

* ufmt

* fixed colors -> random colors in docstring

* Actually simplify further

* Silence mypy. Twice. lol.
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent 435eddf7
...@@ -124,7 +124,7 @@ def test_draw_boxes_vanilla(): ...@@ -124,7 +124,7 @@ def test_draw_boxes_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone() img_cp = img.clone()
boxes_cp = boxes.clone() boxes_cp = boxes.clone()
result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7) result = utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors="white")
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_boxes_vanilla.png")
if not os.path.exists(path): if not os.path.exists(path):
...@@ -149,7 +149,11 @@ def test_draw_invalid_boxes(): ...@@ -149,7 +149,11 @@ def test_draw_invalid_boxes():
img_tp = ((1, 1, 1), (1, 2, 3)) img_tp = ((1, 1, 1), (1, 2, 3))
img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float)
img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8)
img_correct = torch.zeros((3, 10, 10), dtype=torch.uint8)
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
labels_wrong = ["one", "two"]
colors_wrong = ["pink", "blue"]
with pytest.raises(TypeError, match="Tensor expected"): with pytest.raises(TypeError, match="Tensor expected"):
utils.draw_bounding_boxes(img_tp, boxes) utils.draw_bounding_boxes(img_tp, boxes)
with pytest.raises(ValueError, match="Tensor uint8 expected"): with pytest.raises(ValueError, match="Tensor uint8 expected"):
...@@ -158,6 +162,10 @@ def test_draw_invalid_boxes(): ...@@ -158,6 +162,10 @@ def test_draw_invalid_boxes():
utils.draw_bounding_boxes(img_wrong2, boxes) utils.draw_bounding_boxes(img_wrong2, boxes)
with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"): with pytest.raises(ValueError, match="Only grayscale and RGB images are supported"):
utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) utils.draw_bounding_boxes(img_wrong2[0][:2], boxes)
with pytest.raises(ValueError, match="Number of boxes"):
utils.draw_bounding_boxes(img_correct, boxes, labels_wrong)
with pytest.raises(ValueError, match="Number of colors"):
utils.draw_bounding_boxes(img_correct, boxes, colors=colors_wrong)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -176,6 +176,7 @@ def draw_bounding_boxes( ...@@ -176,6 +176,7 @@ def draw_bounding_boxes(
colors (color or list of colors, optional): List containing the colors colors (color or list of colors, optional): List containing the colors
of the boxes or single color for all boxes. The color can be represented as of the boxes or single color for all boxes. The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
By default, random colors are generated for boxes.
fill (bool): If `True` fills the bounding box with specified color. fill (bool): If `True` fills the bounding box with specified color.
width (int): Width of bounding box. width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
...@@ -198,45 +199,50 @@ def draw_bounding_boxes( ...@@ -198,45 +199,50 @@ def draw_bounding_boxes(
elif image.size(0) not in {1, 3}: elif image.size(0) not in {1, 3}:
raise ValueError("Only grayscale and RGB images are supported") raise ValueError("Only grayscale and RGB images are supported")
num_boxes = boxes.shape[0]
if labels is None:
labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef]
elif len(labels) != num_boxes:
raise ValueError(
f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
)
if colors is None:
colors = _generate_color_palette(num_boxes)
elif isinstance(colors, list):
if len(colors) < num_boxes:
raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
else: # colors specifies a single color for all boxes
colors = [colors] * num_boxes
colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors]
# Handle Grayscale images
if image.size(0) == 1: if image.size(0) == 1:
image = torch.tile(image, (3, 1, 1)) image = torch.tile(image, (3, 1, 1))
ndarr = image.permute(1, 2, 0).cpu().numpy() ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr) img_to_draw = Image.fromarray(ndarr)
img_boxes = boxes.to(torch.int64).tolist() img_boxes = boxes.to(torch.int64).tolist()
if fill: if fill:
draw = ImageDraw.Draw(img_to_draw, "RGBA") draw = ImageDraw.Draw(img_to_draw, "RGBA")
else: else:
draw = ImageDraw.Draw(img_to_draw) draw = ImageDraw.Draw(img_to_draw)
txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
for i, bbox in enumerate(img_boxes): for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type]
if colors is None:
color = None
elif isinstance(colors, list):
color = colors[i]
else:
color = colors
if fill: if fill:
if color is None: fill_color = color + (100,)
fill_color = (255, 255, 255, 100)
elif isinstance(color, str):
# This will automatically raise Error if rgb cannot be parsed.
fill_color = ImageColor.getrgb(color) + (100,)
elif isinstance(color, tuple):
fill_color = color + (100,)
draw.rectangle(bbox, width=width, outline=color, fill=fill_color) draw.rectangle(bbox, width=width, outline=color, fill=fill_color)
else: else:
draw.rectangle(bbox, width=width, outline=color) draw.rectangle(bbox, width=width, outline=color)
if labels is not None: if label is not None:
margin = width + 1 margin = width + 1
draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=txt_font) draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
...@@ -505,9 +511,9 @@ def _make_colorwheel() -> torch.Tensor: ...@@ -505,9 +511,9 @@ def _make_colorwheel() -> torch.Tensor:
return colorwheel return colorwheel
def _generate_color_palette(num_masks: int): def _generate_color_palette(num_objects: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)] return [tuple((i * palette) % 255) for i in range(num_objects)]
def _log_api_usage_once(obj: Any) -> None: def _log_api_usage_once(obj: Any) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment