Unverified Commit fadd915c authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] _flip_keypoint of RandomFlip returns a wrong result (#2527)

parent 4ae327f4
...@@ -1203,7 +1203,8 @@ class RandomFlip(BaseTransform): ...@@ -1203,7 +1203,8 @@ class RandomFlip(BaseTransform):
bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
img_shape (tuple[int]): Image shape (height, width) img_shape (tuple[int]): Image shape (height, width)
direction (str): Flip direction. Options are 'horizontal', direction (str): Flip direction. Options are 'horizontal',
'vertical'. 'vertical', and 'diagonal'.
Returns: Returns:
numpy.ndarray: Flipped bounding boxes. numpy.ndarray: Flipped bounding boxes.
""" """
...@@ -1239,7 +1240,8 @@ class RandomFlip(BaseTransform): ...@@ -1239,7 +1240,8 @@ class RandomFlip(BaseTransform):
keypoints (numpy.ndarray): Keypoints, shape (..., 2) keypoints (numpy.ndarray): Keypoints, shape (..., 2)
img_shape (tuple[int]): Image shape (height, width) img_shape (tuple[int]): Image shape (height, width)
direction (str): Flip direction. Options are 'horizontal', direction (str): Flip direction. Options are 'horizontal',
'vertical'. 'vertical', and 'diagonal'.
Returns: Returns:
numpy.ndarray: Flipped keypoints. numpy.ndarray: Flipped keypoints.
""" """
...@@ -1259,7 +1261,7 @@ class RandomFlip(BaseTransform): ...@@ -1259,7 +1261,7 @@ class RandomFlip(BaseTransform):
raise ValueError( raise ValueError(
f"Flipping direction must be 'horizontal', 'vertical', \ f"Flipping direction must be 'horizontal', 'vertical', \
or 'diagonal', but got '{direction}'") or 'diagonal', but got '{direction}'")
flipped = np.concatenate([keypoints, meta_info], axis=-1) flipped = np.concatenate([flipped, meta_info], axis=-1)
return flipped return flipped
def _flip_seg_map(self, seg_map: dict, direction: str) -> np.ndarray: def _flip_seg_map(self, seg_map: dict, direction: str) -> np.ndarray:
...@@ -1357,6 +1359,7 @@ class RandomFlip(BaseTransform): ...@@ -1357,6 +1359,7 @@ class RandomFlip(BaseTransform):
Args: Args:
results (dict): Result dict from loading pipeline. results (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Flipped results, 'img', 'gt_bboxes', 'gt_seg_map', dict: Flipped results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'flip', and 'flip_direction' keys are 'gt_keypoints', 'flip', and 'flip_direction' keys are
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment