"vscode:/vscode.git/clone" did not exist on "052edcecef3eb0ae9fe9e4b256fa2a488f9f395b"
Unverified Commit 30669a57 authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Add gallery example for drawing keypoints (#4892)



* Start writing gallery example

* Remove the child image fix implementation add code

* add docs

* Apply suggestions from code review
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* address review update thumbnail
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 85b78580
......@@ -4,10 +4,10 @@ Visualization utilities
=======================
This example illustrates some of the utilities that torchvision offers for
visualizing images, bounding boxes, and segmentation masks.
visualizing images, bounding boxes, segmentation masks and keypoints.
"""
# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail.png"
# sphinx_gallery_thumbnail_path = "../../gallery/assets/visualization_utils_thumbnail2.png"
import torch
import numpy as np
......@@ -366,3 +366,110 @@ show(dogs_with_masks)
# The two 'people' masks in the first image were not selected because they have
# a lower score than the score threshold. Similarly in the second image, the
# instance with class 15 (which corresponds to 'bench') was not selected.
#####################################
# Visualizing keypoints
# ------------------------------
# The :func:`~torchvision.utils.draw_keypoints` function can be used to
# draw keypoints on images. We will see how to use it with
# torchvision's KeypointRCNN loaded with :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn`.
# We will first have a look at the output of the model.
#
# Note that the keypoint detection model does not need normalized images.
#
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.io import read_image
# Read the example image; this integer tensor is reused later for drawing.
person_int = read_image(str(Path("assets") / "person1.jpg"))
# Float version of the same image, which is what gets passed to the model.
person_float = convert_image_dtype(person_int, dtype=torch.float)
model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
# Put the model in evaluation mode before running it.
model = model.eval()
# The model is called with a list of images; here the list holds one image.
outputs = model([person_float])
print(outputs)
#####################################
# As we see, the output contains a list of dictionaries.
# The output list is of length batch_size.
# We currently have just a single image, so the length of the list is 1.
# Each entry in the list corresponds to an input image,
# and it is a dict with keys `boxes`, `labels`, `scores`, `keypoints` and `keypoint_scores`.
# Each value associated with those keys has `num_instances` elements in it.
# In our case above there are 2 instances detected in the image.
# outputs[0] is the dict for the single input image; pull out its
# 'keypoints' and 'scores' entries for inspection.
kpts = outputs[0]['keypoints']
scores = outputs[0]['scores']
print(kpts)
print(scores)
#####################################
# The KeypointRCNN model detects that there are two instances in the image.
# If you plot the boxes by using :func:`~torchvision.utils.draw_bounding_boxes`
# you would recognize they are the person and the surfboard.
# If we look at the scores, we will realize that the model is much more confident about the person than about the surfboard.
# We could now set a confidence threshold and plot only the instances we are confident enough about.
# Let us set a threshold of 0.75 and keep only the keypoints corresponding to the person.
detect_threshold = 0.75
# Called with only a condition, torch.where returns a tuple of index tensors
# for the entries where the condition holds; that tuple is used to index kpts.
idx = torch.where(scores > detect_threshold)
keypoints = kpts[idx]
print(keypoints)
#####################################
# Great, now we have the keypoints corresponding to the person.
# Each keypoint is represented by x, y coordinates and the visibility.
# We can now use the :func:`~torchvision.utils.draw_keypoints` function to draw keypoints.
# Note that the utility expects uint8 images.
from torchvision.utils import draw_keypoints
# The integer image is passed here, not the float one that was given to the model.
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
show(res)
#####################################
# As we see, the keypoints appear as colored circles over the image.
# The COCO keypoints for a person are ordered and represent the following list.
# The position of a name in this list is its keypoint id
# (0 is "nose", 16 is "right_ankle").
coco_keypoints = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle",
]
#####################################
# What if we are interested in joining the keypoints?
# This is especially useful for pose detection or action recognition.
# We can join the keypoints easily using the `connectivity` parameter.
# A close observation reveals that we need to join the points in the order below
# to construct a human skeleton.
#
# nose -> left_eye -> left_ear. (0, 1), (1, 3)
#
# nose -> right_eye -> right_ear. (0, 2), (2, 4)
#
# nose -> left_shoulder -> left_elbow -> left_wrist. (0, 5), (5, 7), (7, 9)
#
# nose -> right_shoulder -> right_elbow -> right_wrist. (0, 6), (6, 8), (8, 10)
#
# left_shoulder -> left_hip -> left_knee -> left_ankle. (5, 11), (11, 13), (13, 15)
#
# right_shoulder -> right_hip -> right_knee -> right_ankle. (6, 12), (12, 14), (14, 16)
#
# We will create a list containing these keypoint ids to be connected.
# Each pair holds the ids of two keypoints to be joined by a line;
# the ids are positions in the coco_keypoints list (e.g. (0, 1) is nose -> left_eye).
connect_skeleton = [
(0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6), (5, 7), (6, 8),
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), (13, 15), (14, 16)
]
#####################################
# We pass the above list to the connectivity parameter to connect the keypoints.
#
# Same image and keypoints as the earlier call, now with the skeleton pairs
# passed as connectivity, a radius of 4 and an explicit width of 3.
res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)
......@@ -256,7 +256,14 @@ def test_draw_keypoints_vanilla():
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors="red", connectivity=((0, 1),))
result = utils.draw_keypoints(
img,
keypoints,
colors="red",
connectivity=[
(0, 1),
],
)
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_keypoint_vanilla.png")
if not os.path.exists(path):
res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy())
......@@ -277,7 +284,14 @@ def test_draw_keypoints_colored(colors):
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
result = utils.draw_keypoints(img, keypoints, colors=colors, connectivity=((0, 1),))
result = utils.draw_keypoints(
img,
keypoints,
colors=colors,
connectivity=[
(0, 1),
],
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)
......
......@@ -304,7 +304,7 @@ def draw_segmentation_masks(
def draw_keypoints(
image: torch.Tensor,
keypoints: torch.Tensor,
connectivity: Optional[Tuple[Tuple[int, int]]] = None,
connectivity: Optional[List[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
width: int = 3,
......@@ -318,7 +318,7 @@ def draw_keypoints(
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) holding the K keypoint locations for each of the num_instances instances,
in the format [x, y].
connectivity (Tuple[Tuple[int, int]]]): A Tuple of tuple where,
connectivity (List[Tuple[int, int]]): A list of tuples where
each tuple contains a pair of keypoints to be connected.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment