Unverified Commit fa5b8446 authored by Mithra's avatar Mithra Committed by GitHub
Browse files

draw_keypoints() float support (#8276)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent c8c3839c
......@@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default():
assert_equal(result, expected)
def test_draw_keypoints_dtypes():
image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
image_float = to_dtype(image_uint8, torch.float, scale=True)
out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
out_float = utils.draw_keypoints(image_float, keypoints)
assert out_uint8.dtype == torch.uint8
assert out_uint8 is not image_uint8
assert out_float.is_floating_point()
assert out_float is not image_float
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
......
......@@ -336,13 +336,13 @@ def draw_keypoints(
"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].
Keypoints can be drawn for multiple instances at a time.
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
......@@ -363,7 +363,7 @@ def draw_keypoints(
For more details, see :ref:`draw_keypoints_with_visibility`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
......@@ -371,8 +371,8 @@ def draw_keypoints(
# validate image
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3:
raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3:
......@@ -397,6 +397,12 @@ def draw_keypoints(
f"Got {visibility.shape = } and {keypoints.shape = }"
)
original_dtype = image.dtype
if original_dtype.is_floating_point:
from torchvision.transforms.v2.functional import to_dtype # noqa
image = to_dtype(image, dtype=torch.uint8, scale=True)
ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
......@@ -428,7 +434,10 @@ def draw_keypoints(
width=width,
)
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
if original_dtype.is_floating_point:
out = to_dtype(out, dtype=original_dtype, scale=True)
return out
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment