"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "edf5ba6a17d012411c1fe3ceaf24f71f1899bc48"
Unverified Commit a75dc89a authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Fix annotation of draw_segmentation_masks (#4527)



* Add str param

* Update test to include str

* Fix mypy

* Remove a small bracket

* Test more robustly

* Update docstring and test:

* Apply suggestions from code review
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/utils.py

Small docstring fix

* Update torchvision/utils.py

* remove unnecessary renaming
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 4d711fdc
...@@ -162,6 +162,9 @@ def test_draw_invalid_boxes(): ...@@ -162,6 +162,9 @@ def test_draw_invalid_boxes():
"colors", "colors",
[ [
None, None,
"blue",
"#FF00FF",
(1, 34, 122),
["red", "blue"], ["red", "blue"],
["#FF00FF", (1, 34, 122)], ["#FF00FF", (1, 34, 122)],
], ],
...@@ -191,6 +194,8 @@ def test_draw_segmentation_masks(colors, alpha): ...@@ -191,6 +194,8 @@ def test_draw_segmentation_masks(colors, alpha):
if colors is None: if colors is None:
colors = utils._generate_color_palette(num_masks) colors = utils._generate_color_palette(num_masks)
elif isinstance(colors, str) or isinstance(colors, tuple):
colors = [colors]
# Make sure each mask draws with its own color # Make sure each mask draws with its own color
for mask, color in zip(masks, colors): for mask, color in zip(masks, colors):
......
...@@ -160,9 +160,9 @@ def draw_bounding_boxes( ...@@ -160,9 +160,9 @@ def draw_bounding_boxes(
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`. `0 <= ymin < ymax < H`.
labels (List[str]): List containing the labels of bounding boxes. labels (List[str]): List containing the labels of bounding boxes.
colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors colors (color or list of colors, optional): List containing the colors
or a single color for all of the bounding boxes. The colors can be represented as `str` or of the boxes or single color for all boxes. The color can be represented as
`Tuple[int, int, int]`. PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
fill (bool): If `True` fills the bounding box with specified color. fill (bool): If `True` fills the bounding box with specified color.
width (int): Width of bounding box. width (int): Width of bounding box.
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
...@@ -231,7 +231,7 @@ def draw_segmentation_masks( ...@@ -231,7 +231,7 @@ def draw_segmentation_masks(
image: torch.Tensor, image: torch.Tensor,
masks: torch.Tensor, masks: torch.Tensor,
alpha: float = 0.8, alpha: float = 0.8,
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -243,10 +243,10 @@ def draw_segmentation_masks( ...@@ -243,10 +243,10 @@ def draw_segmentation_masks(
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks. alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency. 0 means full transparency, 1 means no transparency.
colors (list or None): List containing the colors of the masks. The colors can colors (color or list of colors, optional): List containing the colors
be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. of the masks or single color for all masks. The color can be represented as
When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
with one element. By default, random colors are generated for each mask. By default, random colors are generated for each mask.
Returns: Returns:
img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
...@@ -289,8 +289,7 @@ def draw_segmentation_masks( ...@@ -289,8 +289,7 @@ def draw_segmentation_masks(
for color in colors: for color in colors:
if isinstance(color, str): if isinstance(color, str):
color = ImageColor.getrgb(color) color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=out_dtype) colors_.append(torch.tensor(color, dtype=out_dtype))
colors_.append(color)
img_to_draw = image.detach().clone() img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this # TODO: There might be a way to vectorize this
...@@ -301,6 +300,6 @@ def draw_segmentation_masks( ...@@ -301,6 +300,6 @@ def draw_segmentation_masks(
return out.to(out_dtype) return out.to(out_dtype)
def _generate_color_palette(num_masks): def _generate_color_palette(num_masks: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)] return [tuple((i * palette) % 255) for i in range(num_masks)]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment