Unverified Commit 19ad0bbc authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Added utility to draw segmentation masks (#3330)



* add draw segm masks

* rewrites with new api

* fix flaky colors

* fix resize bug

* resize for sanity

* cleanup

* project the image

* Minor refactor to adopt num classes

* add uint8 in docstring

* adds alpha and docstring

* move code a bit down

* Minor fix

* fix type check

* Fixing resize bug.

* Fix type of alpha.

* Remove unnecessary RGBA conversions.

* update docs to supported only rgb

* minor edits

* adds tests

* shifts masks up

* change tests and impelementation for bool

* change mode to L

* convert to float

* fixes docs
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 0c051d0d
......@@ -7,4 +7,6 @@ torchvision.utils
.. autofunction:: save_image
.. autofunction:: draw_bounding_boxes
\ No newline at end of file
.. autofunction:: draw_bounding_boxes
.. autofunction:: draw_segmentation_masks
......@@ -9,6 +9,30 @@ from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image
masks = torch.tensor([
[
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799]
],
[
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[-2.2799, -2.2799, -2.2799, -2.2799, -2.2799],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541]
],
[
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[-1.4541, -1.4541, -1.4541, -1.4541, -1.4541],
[5.0914, 5.0914, 5.0914, 5.0914, 5.0914],
]
], dtype=torch.float)
class Tester(unittest.TestCase):
......@@ -96,6 +120,35 @@ class Tester(unittest.TestCase):
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
def test_draw_segmentation_masks_colors(self):
img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
colors = ["#FF00FF", (0, 255, 0), "red"]
result = utils.draw_segmentation_masks(img, masks, colors=colors)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
"fakedata", "draw_segm_masks_colors_util.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
def test_draw_segmentation_masks_no_colors(self):
img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
result = utils.draw_segmentation_masks(img, masks, colors=None)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
"fakedata", "draw_segm_masks_no_colors_util.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
self.assertTrue(torch.equal(result, expected))
if __name__ == '__main__':
unittest.main()
......@@ -6,7 +6,7 @@ import warnings
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageColor
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
@torch.no_grad()
......@@ -153,7 +153,7 @@ def draw_bounding_boxes(
If filled, Resulting Tensor should be saved as PNG image.
Args:
image (Tensor): Tensor of shape (C x H x W)
image (Tensor): Tensor of shape (C x H x W) and dtype uint8.
boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
`0 <= ymin < ymax < H`.
......@@ -210,3 +210,61 @@ def draw_bounding_boxes(
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
@torch.no_grad()
def draw_segmentation_masks(
image: torch.Tensor,
masks: torch.Tensor,
alpha: float = 0.2,
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:
"""
Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (3 x H x W) and dtype uint8.
masks (Tensor): Tensor of shape (num_masks, H, W). Each containing probability of predicted class.
alpha (float): Float number between 0 and 1 denoting factor of transpaerency of masks.
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of masks. The colors can
be represented as `str` or `Tuple[int, int, int]`.
"""
if not isinstance(image, torch.Tensor):
raise TypeError(f"Tensor expected, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
num_masks = masks.size()[0]
masks = masks.argmax(0)
if colors is None:
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palette
color_arr = (colors_t % 255).numpy().astype("uint8")
else:
color_list = []
for color in colors:
if isinstance(color, str):
# This will automatically raise Error if rgb cannot be parsed.
fill_color = ImageColor.getrgb(color)
color_list.append(fill_color)
elif isinstance(color, tuple):
color_list.append(color)
color_arr = np.array(color_list).astype("uint8")
_, h, w = image.size()
img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h))
img_to_draw.putpalette(color_arr)
img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))
img_to_draw = img_to_draw.permute((2, 0, 1))
return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment