Unverified Commit 30669a57 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add gallery example for drawing keypoints (#4892)



* Start writing gallery example

* Remove the child image fix implementation add code

* add docs

* Apply suggestions from code review
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* address review update thumbnail
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 85b78580
...@@ -4,10 +4,10 @@ Visualization utilities ...@@ -4,10 +4,10 @@ Visualization utilities
======================= =======================
This example illustrates some of the utilities that torchvision offers for This example illustrates some of the utilities that torchvision offers for
visualizing images, bounding boxes, and segmentation masks. visualizing images, bounding boxes, segmentation masks and keypoints.
""" """
# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png" # sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"
import torch import torch
import numpy as np import numpy as np
...@@ -366,3 +366,110 @@ show(dogs_with_masks) ...@@ -366,3 +366,110 @@ show(dogs_with_masks)
# The two 'people' masks in the first image were not selected because they have # The two 'people' masks in the first image were not selected because they have
# a lower score than the score threshold. Similarly in the second image, the # a lower score than the score threshold. Similarly in the second image, the
# instance with class 15 (which corresponds to 'bench') was not selected. # instance with class 15 (which corresponds to 'bench') was not selected.
#####################################
# Visualizing keypoints
# ------------------------------
# The :func:`~torchvision.utils.draw_keypoints` function can be used to
# draw keypoints on images. We will see how to use it with
# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
# We will first have a look at the output of the model.
#
# Note that the keypoint detection model does not need normalized images.
#
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.io import read_image
person_int = read_image(str(Path("assets") / "person1.jpg"))
person_float = convert_image_dtype(person_int, dtype=torch.float)
model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
model = model.eval()
outputs = model([person_float])
print(outputs)
#####################################
# As we see, the output contains a list of dictionaries.
# The output list is of length batch_size.
# We currently have just a single image so the length of the list is 1.
# Each entry in the list corresponds to an input image,
# and it is a dict with keys `boxes`, `labels`, `scores`, `keypoints` and `keypoint_scores`.
# Each value associated to those keys has `num_instances` elements in it.
# In our case above there are 2 instances detected in the image.
kpts = outputs[0]['keypoints']
scores = outputs[0]['scores']
print(kpts)
print(scores)
#####################################
# The KeypointRCNN model detects there are two instances in the image.
# If you plot the boxes by using :func:`~draw_bounding_boxes`
# you would recognize they are the person and the surfboard.
# If we look at the scores, we will realize that the model is much more confident about the person than the surfboard.
# We could now set a threshold confidence and plot the instances which we are confident enough about.
# Let us set a threshold of 0.75 and keep only the keypoints corresponding to the person.
detect_threshold = 0.75
# `torch.where` with a single condition argument returns a tuple of index
# tensors; indexing `kpts` with it selects along the instance dimension.
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
print(keypoints)
#####################################
# Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by x, y coordinates and the visibility.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
# Note that the utility expects uint8 images.
from torchvision.utils import draw_keypoints
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)
#####################################
# As we see, the keypoints appear as colored circles over the image.
# The coco keypoints for a person are ordered and represent the following list.
coco_keypoints = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle",
]
#####################################
# What if we are interested in joining the keypoints?
# This is especially useful in creating pose detection or action recognition.
# We can join the keypoints easily using the `connectivity` parameter.
# A close observation would reveal that we would need to join the points in the
# order below to construct a human skeleton.
#
# nose -> left_eye -> left_ear. (0, 1), (1, 3)
#
# nose -> right_eye -> right_ear. (0, 2), (2, 4)
#
# nose -> left_shoulder -> left_elbow -> left_wrist. (0, 5), (5, 7), (7, 9)
#
# nose -> right_shoulder -> right_elbow -> right_wrist. (0, 6), (6, 8), (8, 10)
#
# left_shoulder -> left_hip -> left_knee -> left_ankle. (5, 11), (11, 13), (13, 15)
#
# right_shoulder -> right_hip -> right_knee -> right_ankle. (6, 12), (12, 14), (14, 16)
#
# We will create a list containing these keypoint ids to be connected.
connect_skeleton = [
(0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
]
#####################################
# We pass the above list to the `connectivity` parameter to connect the keypoints.
#
res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)
...@@ -256,7 +256,14 @@ def test_draw_keypoints_vanilla(): ...@@ -256,7 +256,14 @@ def test_draw_keypoints_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone() img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),)) result = utils.draw_keypoints(
img,
keypoints,
colors="red",
connectivity=[
(0, 1),
],
)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png") path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
if not os.path.exists(path): if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
...@@ -277,7 +284,14 @@ def test_draw_keypoints_colored(colors): ...@@ -277,7 +284,14 @@ def test_draw_keypoints_colored(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone() img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),)) result = utils.draw_keypoints(
img,
keypoints,
colors=colors,
connectivity=[
(0, 1),
],
)
assert result.size(0) == 3 assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp) assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp) assert_equal(img, img_cp)
......
...@@ -304,7 +304,7 @@ def draw_segmentation_masks( ...@@ -304,7 +304,7 @@ def draw_segmentation_masks(
def draw_keypoints( def draw_keypoints(
image: torch.Tensor, image: torch.Tensor,
keypoints: torch.Tensor, keypoints: torch.Tensor,
connectivity: Optional[Tuple[Tuple[int, int]]] = None, connectivity: Optional[List[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None, colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2, radius: int = 2,
width: int = 3, width: int = 3,
...@@ -318,7 +318,7 @@ def draw_keypoints( ...@@ -318,7 +318,7 @@ def draw_keypoints(
image (Tensor): Tensor of shape (3, H, W) and dtype uint8. image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
in the format [x, y]. in the format [x, y].
connectivity (Tuple[Tuple[int, int]]): A Tuple of tuples where, connectivity (List[Tuple[int, int]]): A List of tuples where,
each tuple contains pair of keypoints to be connected. each tuple contains pair of keypoints to be connected.
colors (str, Tuple): The color can be represented as colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment