Unverified Commit fa5b8446 authored by Mithra's avatar Mithra Committed by GitHub
Browse files

draw_keypoints() float support (#8276)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent c8c3839c
...@@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default(): ...@@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default():
assert_equal(result, expected) assert_equal(result, expected)
def test_draw_keypoints_dtypes():
image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
image_float = to_dtype(image_uint8, torch.float, scale=True)
out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
out_float = utils.draw_keypoints(image_float, keypoints)
assert out_uint8.dtype == torch.uint8
assert out_uint8 is not image_uint8
assert out_float.is_floating_point()
assert out_float is not image_float
torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
def test_draw_keypoints_errors(): def test_draw_keypoints_errors():
h, w = 10, 10 h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8) img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
......
...@@ -336,13 +336,13 @@ def draw_keypoints( ...@@ -336,13 +336,13 @@ def draw_keypoints(
""" """
Draws Keypoints on given RGB image. Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255. The image values should be uint8 in [0, 255] or float in [0, 1].
Keypoints can be drawn for multiple instances at a time. Keypoints can be drawn for multiple instances at a time.
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint. This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
Args: Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8. image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances, keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y]. in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
...@@ -363,7 +363,7 @@ def draw_keypoints( ...@@ -363,7 +363,7 @@ def draw_keypoints(
For more details, see :ref:`draw_keypoints_with_visibility`. For more details, see :ref:`draw_keypoints_with_visibility`.
Returns: Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
...@@ -371,8 +371,8 @@ def draw_keypoints( ...@@ -371,8 +371,8 @@ def draw_keypoints(
# validate image # validate image
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}") raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8: elif not (image.dtype == torch.uint8 or image.is_floating_point()):
raise ValueError(f"The image dtype must be uint8, got {image.dtype}") raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
elif image.dim() != 3: elif image.dim() != 3:
raise ValueError("Pass individual images, not batches") raise ValueError("Pass individual images, not batches")
elif image.size()[0] != 3: elif image.size()[0] != 3:
...@@ -397,6 +397,12 @@ def draw_keypoints( ...@@ -397,6 +397,12 @@ def draw_keypoints(
f"Got {visibility.shape = } and {keypoints.shape = }" f"Got {visibility.shape = } and {keypoints.shape = }"
) )
original_dtype = image.dtype
if original_dtype.is_floating_point:
from torchvision.transforms.v2.functional import to_dtype # noqa
image = to_dtype(image, dtype=torch.uint8, scale=True)
ndarr = image.permute(1, 2, 0).cpu().numpy() ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr) img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw) draw = ImageDraw.Draw(img_to_draw)
...@@ -428,7 +434,10 @@ def draw_keypoints( ...@@ -428,7 +434,10 @@ def draw_keypoints(
width=width, width=width,
) )
return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
if original_dtype.is_floating_point:
out = to_dtype(out, dtype=original_dtype, scale=True)
return out
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment