Unverified Commit 5bb997c6 authored by Nicolas Hug, committed by GitHub

Rewrite draw_segmentation_masks and update gallery example to illustrate both instance and semantic segmentation models (#3824)
parent 32bccc53
gallery/plot_visualization_utils.py
@@ -24,7 +24,8 @@ def show(imgs):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
@@ -50,9 +51,8 @@ show(grid)
# Visualizing bounding boxes
# --------------------------
# We can use :func:`~torchvision.utils.draw_bounding_boxes` to draw boxes on an
# image. We can set the colors, labels, width as well as font and font size.
# The boxes are in ``(xmin, ymin, xmax, ymax)`` format.

from torchvision.utils import draw_bounding_boxes
@@ -74,9 +74,8 @@ from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms.functional import convert_image_dtype

batch_int = torch.stack([dog1_int, dog2_int])
batch = convert_image_dtype(batch_int, dtype=torch.float)

model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()
@@ -91,7 +90,7 @@ print(outputs)
threshold = .8
dogs_with_boxes = [
    draw_bounding_boxes(dog_int, boxes=output['boxes'][output['scores'] > threshold], width=4)
    for dog_int, output in zip(batch_int, outputs)
]
show(dogs_with_boxes)
@@ -99,33 +98,255 @@ show(dogs_with_boxes)
# Visualizing segmentation masks
# ------------------------------
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
# draw segmentation masks on images. Semantic segmentation and instance
# segmentation models have different outputs, so we will treat each
# independently.
#
# Semantic segmentation models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We will see how to use it with torchvision's FCN Resnet-50, loaded with
# :func:`~torchvision.models.segmentation.fcn_resnet50`. You can also try using
# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`) or
# lraspp mobilenet models
# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
#
# Let's start by looking at the output of the model. Remember that in general,
# images must be normalized before they're passed to a semantic segmentation
# model.

from torchvision.models.segmentation import fcn_resnet50

model = fcn_resnet50(pretrained=True, progress=False)
model = model.eval()

normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
output = model(normalized_batch)['out']
print(output.shape, output.min().item(), output.max().item())
#####################################
# As we can see above, the output of the segmentation model is a tensor of shape
# ``(batch_size, num_classes, H, W)``. Each value is a non-normalized score, and
# we can normalize them into ``[0, 1]`` by using a softmax. After the softmax,
# we can interpret each value as a probability indicating how likely a given
# pixel is to belong to a given class.
#
# Let's plot the masks that have been detected for the dog class and for the
# boat class:
sem_classes = [
    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}

normalized_masks = torch.nn.functional.softmax(output, dim=1)

dog_and_boat_masks = [
    normalized_masks[img_idx, sem_class_to_idx[cls]]
    for img_idx in range(batch.shape[0])
    for cls in ('dog', 'boat')
]
show(dog_and_boat_masks)
#####################################
# As expected, the model is confident about the dog class, but not so much for
# the boat class.
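
#####################################
# One way to inspect this (a quick sketch, not part of the original example)
# is to print each class's maximum probability across all pixels:

print(normalized_masks[0, sem_class_to_idx['dog']].max().item())
print(normalized_masks[0, sem_class_to_idx['boat']].max().item())

#####################################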
#
# The :func:`~torchvision.utils.draw_segmentation_masks` function can be used to
# plot those masks on top of the original image. This function expects the
# masks to be boolean masks, but our masks above contain probabilities in ``[0,
# 1]``. To get boolean masks, we can do the following:
class_dim = 1
boolean_dog_masks = (normalized_masks.argmax(class_dim) == sem_class_to_idx['dog'])
print(f"shape = {boolean_dog_masks.shape}, dtype = {boolean_dog_masks.dtype}")
show([m.float() for m in boolean_dog_masks])
#####################################
# The line above where we define ``boolean_dog_masks`` is a bit cryptic, but you
# can read it as the following query: "For which pixels is 'dog' the most likely
# class?"
#
# .. note::
#     While we're using the ``normalized_masks`` here, we would have
#     gotten the same result by using the non-normalized scores of the model
#     directly (as the softmax operation preserves the order).
#
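# As a quick check of that note — an illustrative sketch, not part of the
# original example — the argmax of the raw scores matches the argmax of the
# softmaxed scores:

print((output.argmax(class_dim) == normalized_masks.argmax(class_dim)).all())

#####################################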
# Now that we have boolean masks, we can use them with
# :func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
# original images:
from torchvision.utils import draw_segmentation_masks
dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7)
    for img, mask in zip(batch_int, boolean_dog_masks)
]
show(dogs_with_masks)
#####################################
# We can plot more than one mask per image! Remember that the model returned as
# many masks as there are classes. Let's ask the same query as above, but this
# time for *all* classes, not just the dog class: "For each pixel and each class
# C, is class C the most likely class?"
#
# This one is a bit more involved, so we'll first show how to do it with a
# single image, and then we'll generalize to the batch.
num_classes = normalized_masks.shape[1]
dog1_masks = normalized_masks[0]
class_dim = 0
dog1_all_classes_masks = dog1_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None]
print(f"dog1_masks shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}")
print(f"dog1_all_classes_masks = {dog1_all_classes_masks.shape}, dtype = {dog1_all_classes_masks.dtype}")
dog_with_all_masks = draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks, alpha=.6)
show(dog_with_all_masks)
#####################################
# We can see in the image above that only 2 masks were drawn: the mask for the
# background and the mask for the dog. This is because the model thinks that
# only these 2 classes are the most likely ones across all the pixels. If the
# model had detected another class as the most likely among other pixels, we
# would have seen its mask above.
#
# Removing the background mask is as simple as passing
# ``masks=dog1_all_classes_masks[1:]``, because the background class is the
# class with index 0.
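# For instance (a small illustrative sketch, not in the original example):

show(draw_segmentation_masks(dog1_int, masks=dog1_all_classes_masks[1:], alpha=.6))

#####################################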
#
# Let's now do the same but for an entire batch of images. The code is similar
# but involves a bit more juggling with the dimensions.
class_dim = 1
all_classes_masks = normalized_masks.argmax(class_dim) == torch.arange(num_classes)[:, None, None, None]
print(f"shape = {all_classes_masks.shape}, dtype = {all_classes_masks.dtype}")
# The first dimension is the classes now, so we need to swap it
all_classes_masks = all_classes_masks.swapaxes(0, 1)
dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=.6)
    for img, mask in zip(batch_int, all_classes_masks)
]
show(dogs_with_masks)
#####################################
# Instance segmentation models
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Instance segmentation models have a significantly different output from the
# semantic segmentation models. We will see here how to plot the masks for such
# models. Let's start by analyzing the output of a Mask R-CNN model. Note that
# these models don't require the images to be normalized, so we don't need to
# use the normalized batch.
from torchvision.models.detection import maskrcnn_resnet50_fpn
model = maskrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()
output = model(batch)
print(output)
#####################################
# Let's break this down. For each image in the batch, the model outputs some
# detections (or instances). The number of detections varies for each input
# image. Each instance is described by its bounding box, its label, its score
# and its mask.
#
# The way the output is organized is as follows: the output is a list of length
# ``batch_size``. Each entry in the list corresponds to an input image, and it
# is a dict with keys 'boxes', 'labels', 'scores', and 'masks'. Each value
# associated to those keys has ``num_instances`` elements in it. In our case
# above there are 3 instances detected in the first image, and 2 instances in
# the second one.
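# We can verify this programmatically — a small sketch, not part of the
# original example:

print([out['masks'].shape[0] for out in output])

#####################################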
#
# The boxes can be plotted with :func:`~torchvision.utils.draw_bounding_boxes`
# as above, but here we're more interested in the masks. These masks are quite
# different from the masks that we saw above for the semantic segmentation
# models.
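# As a hedged aside (not part of the original example), those boxes could be
# drawn with the function we used earlier:

show([
    draw_bounding_boxes(img, boxes=out['boxes'], width=4)
    for img, out in zip(batch_int, output)
])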
dog1_output = output[0]
dog1_masks = dog1_output['masks']
print(f"shape = {dog1_masks.shape}, dtype = {dog1_masks.dtype}, "
f"min = {dog1_masks.min()}, max = {dog1_masks.max()}")
#####################################
# Here the masks correspond to probabilities indicating, for each pixel, how
# likely it is to belong to the predicted label of that instance. Those
# predicted labels correspond to the 'labels' element in the same output dict.
# Let's see which labels were predicted for the instances of the first image.
inst_classes = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
inst_class_to_idx = {cls: idx for (idx, cls) in enumerate(inst_classes)}
print("For the first dog, the following instances were detected:")
print([inst_classes[label] for label in dog1_output['labels']])
#####################################
# Interestingly, the model detects two persons in the image. Let's go ahead and
# plot those masks. Since :func:`~torchvision.utils.draw_segmentation_masks`
# expects boolean masks, we need to convert those probabilities into boolean
# values. Remember that the semantics of those masks is "How likely is this
# pixel to belong to the predicted class?". As a result, a natural way of
# converting those masks into boolean values is to threshold them at a
# probability of 0.5 (one could also choose a different threshold).
proba_threshold = 0.5
dog1_bool_masks = dog1_output['masks'] > proba_threshold
print(f"shape = {dog1_bool_masks.shape}, dtype = {dog1_bool_masks.dtype}")
# There's an extra dimension (1) to the masks. We need to remove it
dog1_bool_masks = dog1_bool_masks.squeeze(1)
show(draw_segmentation_masks(dog1_int, dog1_bool_masks, alpha=0.9))
#####################################
# The model seems to have properly detected the dog, but it also confused trees
# with people. Looking more closely at the scores will help us plot more
# relevant masks:
print(dog1_output['scores'])
#####################################
# Clearly the model is less confident about the dog detection than it is about
# the people detections. That's good news. When plotting the masks, we can ask
# for only those that have a good score. Let's use a score threshold of .75
# here, and also plot the masks of the second dog.
score_threshold = .75
boolean_masks = [
    out['masks'][out['scores'] > score_threshold] > proba_threshold
    for out in output
]

dogs_with_masks = [
    draw_segmentation_masks(img, mask.squeeze(1))
    for img, mask in zip(batch_int, boolean_masks)
]
show(dogs_with_masks)
#####################################
# The two 'people' masks in the first image were not selected because they have
# a lower score than the score threshold. Similarly in the second image, the
# instance with class 15 (which corresponds to 'bench') was not selected.
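
#####################################
# To double-check that, a small sketch (not part of the original example) that
# prints the predicted class and score of each instance in the second image:

print([(inst_classes[label], score.item())
       for label, score in zip(output[1]['labels'], output[1]['scores'])])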
test/test_utils.py

import pytest
import numpy as np
import os
import sys
@@ -7,7 +8,7 @@ import torchvision.utils as utils
import unittest
from io import BytesIO
import torchvision.transforms.functional as F
from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
@@ -159,55 +160,88 @@ class Tester(unittest.TestCase):
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong1, boxes)
        self.assertRaises(ValueError, utils.draw_bounding_boxes, img_wrong2, boxes)
@pytest.mark.parametrize('colors', [
    None,
    ['red', 'blue'],
    ['#FF00FF', (1, 34, 122)],
])
@pytest.mark.parametrize('alpha', (0, .5, .7, 1))
def test_draw_segmentation_masks(colors, alpha):
    """This test makes sure that masks draw their corresponding color where they should"""
    num_masks, h, w = 2, 100, 100
    dtype = torch.uint8
    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype)
    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

    # For testing we enforce that there's no overlap between the masks. The
    # current behaviour is that the last mask's color will take priority when
    # masks overlap, but this makes testing slightly harder so we don't really
    # care
    overlap = masks[0] & masks[1]
    masks[:, overlap] = False

    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
    assert out.dtype == dtype
    assert out is not img

    # Make sure the image didn't change where there's no mask
    masked_pixels = masks[0] | masks[1]
    assert (img[:, ~masked_pixels] == out[:, ~masked_pixels]).all()

    if colors is None:
        colors = utils._generate_color_palette(num_masks)

    # Make sure each mask draws with its own color
    for mask, color in zip(masks, colors):
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        color = torch.tensor(color, dtype=dtype)

        if alpha == 1:
            assert (out[:, mask] == color[:, None]).all()
        elif alpha == 0:
            assert (out[:, mask] == img[:, mask]).all()

        interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha)
        max_diff = (out[:, mask] - interpolated_color).abs().max()
        assert max_diff <= 1
def test_draw_segmentation_masks_errors():
    h, w = 10, 10
    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)

    with pytest.raises(TypeError, match="The image must be a tensor"):
        utils.draw_segmentation_masks(image="Not A Tensor Image", masks=masks)
    with pytest.raises(ValueError, match="The image dtype must be"):
        img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64)
        utils.draw_segmentation_masks(image=img_bad_dtype, masks=masks)
    with pytest.raises(ValueError, match="Pass individual images, not batches"):
        batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
        utils.draw_segmentation_masks(image=batch, masks=masks)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
        utils.draw_segmentation_masks(image=one_channel, masks=masks)
    with pytest.raises(ValueError, match="The masks must be of dtype bool"):
        masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float)
        utils.draw_segmentation_masks(image=img, masks=masks_bad_dtype)
    with pytest.raises(ValueError, match="masks must be of shape"):
        masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)
        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
    with pytest.raises(ValueError, match="must have the same height and width"):
        masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
        utils.draw_segmentation_masks(image=img, masks=masks_bad_shape)
    with pytest.raises(ValueError, match="There are more masks"):
        utils.draw_segmentation_masks(image=img, masks=masks, colors=[])
    with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"):
        bad_colors = np.array(['red', 'blue'])  # should be a list
        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
    with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"):
        bad_colors = ('red', 'blue')  # should be a list
        utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
if __name__ == '__main__':

torchvision/utils.py
@@ -220,7 +220,7 @@ def draw_bounding_boxes(
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:
@@ -229,49 +229,68 @@ def draw_segmentation_masks(
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (list or None): List containing the colors of the masks. The colors can
            be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
            with one element. By default, random colors are generated for each mask.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if colors is not None and num_masks > len(colors):
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")

    if colors is None:
        colors = _generate_color_palette(num_masks)

    if not isinstance(colors, list):
        colors = [colors]
    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        color = torch.tensor(color, dtype=out_dtype)
        colors_.append(color)

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)


def _generate_color_palette(num_masks):
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    return [tuple((i * palette) % 255) for i in range(num_masks)]
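
# A minimal usage sketch of the API above (an editor's illustration, not part
# of this commit): a single (H, W) boolean mask can be passed along with a
# single color instead of a list.
if __name__ == "__main__":
    img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)
    mask = torch.zeros((64, 64), dtype=torch.bool)
    mask[16:48, 16:48] = True  # hypothetical square mask, for illustration only
    out = draw_segmentation_masks(img, mask, alpha=0.8, colors="blue")
    print(out.shape, out.dtype)  # expected: torch.Size([3, 64, 64]) torch.uint8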