Unverified Commit a4ca717f authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add utility to draw keypoints (#4216)

* fix

* Outline Keypoints API

* Add utility

* make it work :)

* Fix optional type

* Add connectivity, fmassa's advice 😃



* Minor code improvement

* small fix

* fix implementation

* Add tests

* Fix tests

* Update colors

* Fix bug and test more robustly

* Add a comment, merge stuff

* Fix fmt

* Support single str for merging

* Remove unnecessary vars.
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 7408cb51
...@@ -14,5 +14,6 @@ vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`. ...@@ -14,5 +14,6 @@ vizualization <sphx_glr_auto_examples_plot_visualization_utils.py>`.
draw_bounding_boxes draw_bounding_boxes
draw_segmentation_masks draw_segmentation_masks
draw_keypoints
make_grid make_grid
save_image save_image
...@@ -16,6 +16,8 @@ PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) ...@@ -16,6 +16,8 @@ PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
def test_make_grid_not_inplace(): def test_make_grid_not_inplace():
t = torch.rand(5, 3, 10, 10) t = torch.rand(5, 3, 10, 10)
...@@ -248,5 +250,58 @@ def test_draw_segmentation_masks_errors(): ...@@ -248,5 +250,58 @@ def test_draw_segmentation_masks_errors():
utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors)
def test_draw_keypoints_vanilla():
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
res.save(path)
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
assert_equal(result, expected)
# Check that keypoints are not modified inplace
assert_equal(keypoints, keypoints_cp)
# Check that image is not modified in place
assert_equal(img, img_cp)
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)
def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
with pytest.raises(TypeError, match="The image must be a tensor"):
utils.draw_keypoints(image="Not A Tensor Image", keypoints=keypoints)
with pytest.raises(ValueError, match="The image dtype must be"):
img_bad_dtype = torch.full((3, h, w), 0, dtype=torch.int64)
utils.draw_keypoints(image=img_bad_dtype, keypoints=keypoints)
with pytest.raises(ValueError, match="Pass individual images, not batches"):
batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
utils.draw_keypoints(image=batch, keypoints=keypoints)
with pytest.raises(ValueError, match="Pass an RGB image"):
one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
utils.draw_keypoints(image=one_channel, keypoints=keypoints)
with pytest.raises(ValueError, match="keypoints must be of shape"):
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from PIL import Image, ImageDraw, ImageFont, ImageColor from PIL import Image, ImageDraw, ImageFont, ImageColor
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] __all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks", "draw_keypoints"]
@torch.no_grad() @torch.no_grad()
...@@ -300,6 +300,76 @@ def draw_segmentation_masks( ...@@ -300,6 +300,76 @@ def draw_segmentation_masks(
return out.to(out_dtype) return out.to(out_dtype)
@torch.no_grad()
def draw_keypoints(
image: torch.Tensor,
keypoints: torch.Tensor,
connectivity: Optional[Tuple[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
width: int = 3,
) -> torch.Tensor:
"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
in the format [x, y].
connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
each tuple contains pair of keypoints to be connected.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
radius (int): Integer denoting radius of keypoint.
width (int): Integer denoting width of line connecting keypoints.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
"""
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")
if keypoints.ndim != 3:
raise ValueError("keypoints must be of shape (num_instances, K, 2)")
ndarr = image.permute(1, 2, 0).numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
img_kpts = keypoints.to(torch.int64).tolist()
for kpt_id, kpt_inst in enumerate(img_kpts):
for inst_id, kpt in enumerate(kpt_inst):
x1 = kpt[0] - radius
x2 = kpt[0] + radius
y1 = kpt[1] - radius
y2 = kpt[1] + radius
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)
if connectivity:
for connection in connectivity:
start_pt_x = kpt_inst[connection[0]][0]
start_pt_y = kpt_inst[connection[0]][1]
end_pt_x = kpt_inst[connection[1]][0]
end_pt_y = kpt_inst[connection[1]][1]
draw.line(
((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
width=width,
)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
def _generate_color_palette(num_masks: int): def _generate_color_palette(num_masks: int):
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
return [tuple((i * palette) % 255) for i in range(num_masks)] return [tuple((i * palette) % 255) for i in range(num_masks)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment