Unverified Commit 887b6f1f authored by Masahiro Hiramori's avatar Masahiro Hiramori Committed by GitHub
Browse files

Add GPU support for draw_segmentation_masks (#7684)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 463cdeab
...@@ -9,7 +9,7 @@ import pytest ...@@ -9,7 +9,7 @@ import pytest
import torch import torch
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
import torchvision.utils as utils import torchvision.utils as utils
from common_utils import assert_equal from common_utils import assert_equal, cpu_and_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageColor from PIL import __version__ as PILLOW_VERSION, Image, ImageColor
...@@ -203,12 +203,13 @@ def test_draw_no_boxes(): ...@@ -203,12 +203,13 @@ def test_draw_no_boxes():
], ],
) )
@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1)) @pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1))
def test_draw_segmentation_masks(colors, alpha): @pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks(colors, alpha, device):
"""This test makes sure that masks draw their corresponding color where they should""" """This test makes sure that masks draw their corresponding color where they should"""
num_masks, h, w = 2, 100, 100 num_masks, h, w = 2, 100, 100
dtype = torch.uint8 dtype = torch.uint8
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype) img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool) masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)
# For testing we enforce that there's no overlap between the masks. The # For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color will take priority when # current behaviour is that the last mask's color will take priority when
...@@ -234,7 +235,7 @@ def test_draw_segmentation_masks(colors, alpha): ...@@ -234,7 +235,7 @@ def test_draw_segmentation_masks(colors, alpha):
for mask, color in zip(masks, colors): for mask, color in zip(masks, colors):
if isinstance(color, str): if isinstance(color, str):
color = ImageColor.getrgb(color) color = ImageColor.getrgb(color)
color = torch.tensor(color, dtype=dtype) color = torch.tensor(color, dtype=dtype, device=device)
if alpha == 1: if alpha == 1:
assert (out[:, mask] == color[:, None]).all() assert (out[:, mask] == color[:, None]).all()
...@@ -245,11 +246,12 @@ def test_draw_segmentation_masks(colors, alpha): ...@@ -245,11 +246,12 @@ def test_draw_segmentation_masks(colors, alpha):
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0) torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
def test_draw_segmentation_masks_errors(): @pytest.mark.parametrize("device", cpu_and_cuda())
def test_draw_segmentation_masks_errors(device):
h, w = 10, 10 h, w = 10, 10
masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool) masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool, device=device)
img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8) img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8, device=device)
with pytest.raises(TypeError, match="The image must be a tensor"): with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks) utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
...@@ -281,9 +283,10 @@ def test_draw_segmentation_masks_errors(): ...@@ -281,9 +283,10 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
def test_draw_no_segmention_mask(): @pytest.mark.parametrize("device", cpu_and_cuda())
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) def test_draw_no_segmention_mask(device):
masks = torch.full((0, 100, 100), 0, dtype=torch.bool) img = torch.full((3, 100, 100), 0, dtype=torch.uint8, device=device)
masks = torch.full((0, 100, 100), 0, dtype=torch.bool, device=device)
with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")): with pytest.warns(UserWarning, match=re.escape("masks doesn't contain any mask. No mask was drawn")):
res = utils.draw_segmentation_masks(img, masks) res = utils.draw_segmentation_masks(img, masks)
# Check that the function didn't change the image # Check that the function didn't change the image
......
...@@ -304,7 +304,10 @@ def draw_segmentation_masks( ...@@ -304,7 +304,10 @@ def draw_segmentation_masks(
return image return image
out_dtype = torch.uint8 out_dtype = torch.uint8
colors = [torch.tensor(color, dtype=out_dtype) for color in _parse_colors(colors, num_objects=num_masks)] colors = [
torch.tensor(color, dtype=out_dtype, device=image.device)
for color in _parse_colors(colors, num_objects=num_masks)
]
img_to_draw = image.detach().clone() img_to_draw = image.detach().clone()
# TODO: There might be a way to vectorize this # TODO: There might be a way to vectorize this
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment