Commit 169f098d authored by Yifei Yang, committed by zhouzaida

[Fix] Set keypoints not in the cropped image invisible (#1804)

* set invisibility

* fix as comment
parent 6534efd6
@@ -531,7 +531,8 @@ class CenterCrop(BaseTransform):
         results['gt_bboxes'] = gt_bboxes
 
     def _crop_keypoints(self, results: dict, bboxes: np.ndarray) -> None:
-        """Update key points according to CenterCrop.
+        """Update key points according to CenterCrop. Keypoints not in the
+        cropped image will be set invisible.
 
         Args:
             results (dict): Result dict contains the data to transform.
@@ -544,6 +545,14 @@ class CenterCrop(BaseTransform):
         # gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order,
         # NK = number of points per object
         gt_keypoints = results['gt_keypoints'] - keypoints_offset
+        # set gt_keypoints outside the cropped image invisible
+        height, width = results['img'].shape[:2]
+        valid_pos = ((gt_keypoints[:, :, 0] >= 0) *
+                     (gt_keypoints[:, :, 0] < width) *
+                     (gt_keypoints[:, :, 1] >= 0) *
+                     (gt_keypoints[:, :, 1] < height))
+        gt_keypoints[:, :, 2] = np.where(valid_pos, gt_keypoints[:, :, 2],
+                                         0)
         gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0,
                                         results['img'].shape[1])
         gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0,
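For reference, the masking step added above can be exercised in isolation. The sketch below is not part of the commit: `mask_out_of_crop_keypoints` is a hypothetical helper that mirrors the new lines (using `&` in place of `*` for the boolean mask) together with the pre-existing coordinate clipping.

```python
import numpy as np


def mask_out_of_crop_keypoints(gt_keypoints: np.ndarray, height: int,
                               width: int) -> np.ndarray:
    """Zero the visibility flag of keypoints that fall outside the crop."""
    # keypoints are (x, y, visibility); a point is valid only if it lies
    # inside the [0, width) x [0, height) cropped image
    valid_pos = ((gt_keypoints[:, :, 0] >= 0) &
                 (gt_keypoints[:, :, 0] < width) &
                 (gt_keypoints[:, :, 1] >= 0) &
                 (gt_keypoints[:, :, 1] < height))
    out = gt_keypoints.copy()
    out[:, :, 2] = np.where(valid_pos, out[:, :, 2], 0)
    # mirror the existing clipping of coordinates into the cropped image
    out[:, :, 0] = np.clip(out[:, :, 0], 0, width)
    out[:, :, 1] = np.clip(out[:, :, 1], 0, height)
    return out


# a keypoint pushed to x = -78 by the crop offset ends up clipped and invisible
kps = np.array([[[-78., 12., 1.], [112., 112., 1.]]])
print(mask_out_of_crop_keypoints(kps, height=224, width=224))
# [[[  0.  12.   0.]
#   [112. 112.   1.]]]
```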
@@ -302,7 +302,7 @@ class TestCenterCrop:
                                                      224]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()
+            np.array([[[0, 12, 0]], [[112, 112, 1]], [[212, 187, 1]]])).all()
 
         # test CenterCrop when size is tuple
         transform = dict(type='CenterCrop', crop_size=(224, 224))
@@ -321,7 +321,7 @@ class TestCenterCrop:
                                                      224]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()
+            np.array([[[0, 12, 0]], [[112, 112, 1]], [[212, 187, 1]]])).all()
 
         # test CenterCrop when crop_height != crop_width
         transform = dict(type='CenterCrop', crop_size=(224, 256))
@@ -340,7 +340,7 @@ class TestCenterCrop:
                                                      256]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[0, 28, 1]], [[112, 128, 1]], [[212, 203, 1]]])).all()
+            np.array([[[0, 28, 0]], [[112, 128, 1]], [[212, 203, 1]]])).all()
 
         # test CenterCrop when crop_size is equal to img.shape
         img_height, img_width, _ = self.original_img.shape
@@ -398,7 +398,7 @@ class TestCenterCrop:
                                                      300]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
+            np.array([[[0, 50, 0]], [[100, 150, 1]], [[200, 225, 0]]])).all()
 
         transform = dict(
             type='CenterCrop',
@@ -418,7 +418,7 @@ class TestCenterCrop:
                                                      300]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
+            np.array([[[0, 50, 0]], [[100, 150, 1]], [[200, 225, 0]]])).all()
 
         # test CenterCrop when crop_width is smaller than img_width
         transform = dict(
@@ -438,7 +438,7 @@ class TestCenterCrop:
                                                      300]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
+            np.array([[[0, 50, 0]], [[100, 150, 1]], [[200, 225, 0]]])).all()
 
         # test CenterCrop when crop_height is smaller than img_height
         transform = dict(
@@ -457,7 +457,7 @@ class TestCenterCrop:
                                                      150]])).all()
         assert np.equal(
             results['gt_keypoints'],
-            np.array([[[20, 0, 1]], [[200, 75, 1]], [[300, 150, 1]]])).all()
+            np.array([[[20, 0, 0]], [[200, 75, 1]], [[300, 150, 0]]])).all()
 
     @pytest.mark.skipif(
         condition=torch is None, reason='No torch in current env')
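The flipped visibility flags in the test expectations above all follow the same rule: a keypoint whose coordinate was clipped to the crop border lay outside the crop, so its flag is now 0. A rough usage sketch, assuming mmcv >= 2.0 where `CenterCrop` is importable from `mmcv.transforms` and accepts a results dict containing only `img` and `gt_keypoints` (the image shape and keypoint values here are illustrative, not the test fixture):

```python
import numpy as np
from mmcv.transforms import CenterCrop  # assumption: mmcv >= 2.0

# illustrative data, not the fixture used by TestCenterCrop
results = dict(
    img=np.zeros((300, 400, 3), dtype=np.uint8),
    # one object with one keypoint near the left border of the original image
    gt_keypoints=np.array([[[10., 150., 1.]]]),
)

# a 224x224 center crop of a 400x300 image starts at x=88, y=38, so the
# keypoint lands at (-78, 112) and should come back clipped to x=0 with
# its visibility flag set to 0
results = CenterCrop(crop_size=(224, 224))(results)
print(results['gt_keypoints'])  # expected: [[[  0. 112.   0.]]]
```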