Unverified Commit 337fa343 authored by haarisr's avatar haarisr Committed by GitHub
Browse files

Allow K=1 in `draw_keypoints` (#8439)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 61d97f41
...@@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla(): ...@@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla():
assert_equal(img, img_cp) assert_equal(img, img_cp)
def test_draw_keypoins_K_equals_one():
# Non-regression test for https://github.com/pytorch/vision/pull/8439
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
keypoints = torch.tensor([[[10, 10]]], dtype=torch.float)
utils.draw_keypoints(img, keypoints)
@pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) @pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)])
def test_draw_keypoints_colored(colors): def test_draw_keypoints_colored(colors):
# Keypoints is declared on top as global variable # Keypoints is declared on top as global variable
......
...@@ -392,9 +392,9 @@ def draw_keypoints( ...@@ -392,9 +392,9 @@ def draw_keypoints(
# validate visibility # validate visibility
if visibility is None: # set default if visibility is None: # set default
visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool) visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
# If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction if visibility.ndim == 3:
# model, make sure visibility has shape (num_instances, K). # If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1).
# Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place. # We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice.
visibility = visibility.squeeze(-1) visibility = visibility.squeeze(-1)
if visibility.ndim != 2: if visibility.ndim != 2:
raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}") raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment