Commit 169f098d authored by Yifei Yang's avatar Yifei Yang Committed by zhouzaida
Browse files

[Fix] Set keypoints not in the cropped image invisible (#1804)

* set invisiblity

* fix as comment
parent 6534efd6
......@@ -531,7 +531,8 @@ class CenterCrop(BaseTransform):
results['gt_bboxes'] = gt_bboxes
def _crop_keypoints(self, results: dict, bboxes: np.ndarray) -> None:
"""Update key points according to CenterCrop.
"""Update key points according to CenterCrop. Keypoints that not in the
cropped image will be set invisible.
Args:
results (dict): Result dict contains the data to transform.
......@@ -544,6 +545,14 @@ class CenterCrop(BaseTransform):
# gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order,
# NK = number of points per object
gt_keypoints = results['gt_keypoints'] - keypoints_offset
# set gt_kepoints out of the result image invisible
height, width = results['img'].shape[:2]
valid_pos = (gt_keypoints[:, :, 0] >=
0) * (gt_keypoints[:, :, 0] <
width) * (gt_keypoints[:, :, 1] >= 0) * (
gt_keypoints[:, :, 1] < height)
gt_keypoints[:, :, 2] = np.where(valid_pos, gt_keypoints[:, :, 2],
0)
gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0,
results['img'].shape[1])
gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0,
......
......@@ -302,7 +302,7 @@ class TestCenterCrop:
224]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()
np.array([[[0, 12, 0]], [[112, 112, 1]], [[212, 187, 1]]])).all()
# test CenterCrop when size is tuple
transform = dict(type='CenterCrop', crop_size=(224, 224))
......@@ -321,7 +321,7 @@ class TestCenterCrop:
224]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()
np.array([[[0, 12, 0]], [[112, 112, 1]], [[212, 187, 1]]])).all()
# test CenterCrop when crop_height != crop_width
transform = dict(type='CenterCrop', crop_size=(224, 256))
......@@ -340,7 +340,7 @@ class TestCenterCrop:
256]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 28, 1]], [[112, 128, 1]], [[212, 203, 1]]])).all()
np.array([[[0, 28, 0]], [[112, 128, 1]], [[212, 203, 1]]])).all()
# test CenterCrop when crop_size is equal to img.shape
img_height, img_width, _ = self.original_img.shape
......@@ -398,7 +398,7 @@ class TestCenterCrop:
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
np.array([[[0, 50, 0]], [[100, 150, 1]], [[200, 225, 0]]])).all()
transform = dict(
type='CenterCrop',
......@@ -418,7 +418,7 @@ class TestCenterCrop:
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
np.array([[[0, 50, 0]], [[100, 150, 1]], [[200, 225, 0]]])).all()
# test CenterCrop when crop_width is smaller than img_width
transform = dict(
......@@ -438,7 +438,7 @@ class TestCenterCrop:
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
np.array([[[0, 50, 0]], [[100, 150, 1]], [[200, 225, 0]]])).all()
# test CenterCrop when crop_height is smaller than img_height
transform = dict(
......@@ -457,7 +457,7 @@ class TestCenterCrop:
150]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[20, 0, 1]], [[200, 75, 1]], [[300, 150, 1]]])).all()
np.array([[[20, 0, 0]], [[200, 75, 1]], [[300, 150, 0]]])).all()
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment