Unverified Commit 337fa343 authored by haarisr's avatar haarisr Committed by GitHub
Browse files

Allow K=1 in `draw_keypoints` (#8439)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 61d97f41
......@@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla():
assert_equal(img, img_cp)
def test_draw_keypoins_K_equals_one():
# Non-regression test for https://github.com/pytorch/vision/pull/8439
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
keypoints = torch.tensor([[[10, 10]]], dtype=torch.float)
utils.draw_keypoints(img, keypoints)
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors):
# Keypoints is declared on top as global variable
......
......@@ -392,10 +392,10 @@ def draw_keypoints(
# validate visibility
if visibility is None: # set default
visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
# If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction
# model, make sure visibility has shape (num_instances, K).
# Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place.
visibility = visibility.squeeze(-1)
if visibility.ndim == 3:
# If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1).
# We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice.
visibility = visibility.squeeze(-1)
if visibility.ndim != 2:
raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
if visibility.shape != keypoints.shape[:-1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment