Unverified Commit 6c2e0ae8 authored by Mithra's avatar Mithra Committed by GitHub
Browse files

support of float dtypes for draw_segmentation_masks (#8150)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent c35d3855
...@@ -11,6 +11,7 @@ import torchvision.transforms.functional as F ...@@ -11,6 +11,7 @@ import torchvision.transforms.functional as F
import torchvision.utils as utils import torchvision.utils as utils
from common_utils import assert_equal, cpu_and_cuda from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
from torchvision.transforms.v2.functional import to_dtype
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
...@@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device): ...@@ -246,6 +247,26 @@ def test_draw_segmentation_masks(colors, alpha, device):
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
def test_draw_segmentation_masks_dtypes():
num_masks, h, w = 2, 100, 100
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)
img_uint8 = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)
out_uint8 = utils.draw_segmentation_masks(img_uint8, masks)
assert img_uint8 is not out_uint8
assert out_uint8.dtype == torch.uint8
img_float = to_dtype(img_uint8, torch.float, scale=True)
out_float = utils.draw_segmentation_masks(img_float, masks)
assert img_float is not out_float
assert out_float.is_floating_point()
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device): def test_draw_segmentation_masks_errors(device):
h, w = 10, 10 h, w = 10, 10
......
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
import torch import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont from PIL import Image, ImageColor, ImageDraw, ImageFont
__all__ = [ __all__ = [
"make_grid", "make_grid",
"save_image", "save_image",
...@@ -262,10 +263,10 @@ def draw_segmentation_masks( ...@@ -262,10 +263,10 @@ def draw_segmentation_masks(
""" """
Draws segmentation masks on given RGB image. Draws segmentation masks on given RGB image.
The values of the input image should be uint8 between 0 and 255. The image values should be uint8 in [0, 255] or float in [0, 1].
Args: Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8. image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
alpha (float): Float number between 0 and 1 denoting the transparency of the masks. alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
0 means full transparency, 1 means no transparency. 0 means full transparency, 1 means no transparency.
...@@ -282,8 +283,8 @@ def draw_segmentation_masks( ...@@ -282,8 +283,8 @@ def draw_segmentation_masks(
_log_api_usage_once(draw_segmentation_masks) _log_api_usage_once(draw_segmentation_masks)
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}") raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8: elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8, got {image.dtype}") raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3: elif image.dim() != 3:
raise ValueError("Pass individual images, not batches") raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3: elif image.size()[0] != 3:
...@@ -303,10 +304,10 @@ def draw_segmentation_masks( ...@@ -303,10 +304,10 @@ def draw_segmentation_masks(
warnings.warn("masks doesn't contain any mask. No mask was drawn") warnings.warn("masks doesn't contain any mask. No mask was drawn")
return image return image
out_dtype = torch.uint8 original_dtype = image.dtype
colors = [ colors = [
torch.tensor(color, dtype=out_dtype, device=image.device) torch.tensor(color, dtype=original_dtype, device=image.device)
for color in _parse_colors(colors, num_objects=num_masks) for color in _parse_colors(colors, num_objects=num_masks, dtype=original_dtype)
] ]
img_to_draw = image.detach().clone() img_to_draw = image.detach().clone()
...@@ -315,7 +316,8 @@ def draw_segmentation_masks( ...@@ -315,7 +316,8 @@ def draw_segmentation_masks(
img_to_draw[:, mask] = color[:, None] img_to_draw[:, mask] = color[:, None]
out = image * (1 - alpha) + img_to_draw * alpha out = image * (1 - alpha) + img_to_draw * alpha
return out.to(out_dtype) # Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return out.to(original_dtype)
@torch.no_grad() @torch.no_grad()
...@@ -516,6 +518,7 @@ def _parse_colors( ...@@ -516,6 +518,7 @@ def _parse_colors(
colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]], colors: Union[None, str, Tuple[int, int, int], List[Union[str, Tuple[int, int, int]]]],
*, *,
num_objects: int, num_objects: int,
dtype: torch.dtype = torch.uint8,
) -> List[Tuple[int, int, int]]: ) -> List[Tuple[int, int, int]]:
""" """
Parses a specification of colors for a set of objects. Parses a specification of colors for a set of objects.
...@@ -552,7 +555,10 @@ def _parse_colors( ...@@ -552,7 +555,10 @@ def _parse_colors(
else: # colors specifies a single color for all objects else: # colors specifies a single color for all objects
colors = [colors] * num_objects colors = [colors] * num_objects
return [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors] colors = [ImageColor.getrgb(color) if isinstance(color, str) else color for color in colors]
if dtype.is_floating_point: # [0, 255] -> [0, 1]
colors = [tuple(v / 255 for v in color) for color in colors]
return colors
def _log_api_usage_once(obj: Any) -> None: def _log_api_usage_once(obj: Any) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment