Unverified Commit bf48ca03 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Fix] Add `swap_label_pairs` in `RandomFlip` (#2332)

* [Fix] Add `swap_label_pairs` in `RandomFlip`

* [Fix] Add `swap_label_pairs` in `RandomFlip`

* add reference info

* add swap_label_pairs in results

* revise according to comments

* revise according to comments

* revise according to comments

* docstring

* docstring
parent a4c82617
...@@ -1025,21 +1025,25 @@ class RandomFlip(BaseTransform): ...@@ -1025,21 +1025,25 @@ class RandomFlip(BaseTransform):
- flip - flip
- flip_direction - flip_direction
- swap_seg_labels (optional)
Args: Args:
prob (float | list[float], optional): The flipping probability. prob (float | list[float], optional): The flipping probability.
Defaults to None. Defaults to None.
direction(str | list[str]): The flipping direction. Options direction(str | list[str]): The flipping direction. Options
If input is a list, the length must equal ``prob``. Each If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to 'horizontal'. corresponding direction. Defaults to 'horizontal'.
swap_seg_labels (list, optional): The label pairs that need to be
swapped in the ground truth after flipping, e.g. 'left arm' and
'right arm' after a horizontal flip. For example, ``[(1, 5)]``,
where 1/5 is the label of the left/right arm. Defaults to None.
""" """
def __init__( def __init__(self,
self, prob: Optional[Union[float, Iterable[float]]] = None,
prob: Optional[Union[float, Iterable[float]]] = None, direction: Union[str, Sequence[Optional[str]]] = 'horizontal',
direction: Union[str, swap_seg_labels: Optional[Sequence] = None) -> None:
Sequence[Optional[str]]] = 'horizontal') -> None:
if isinstance(prob, list): if isinstance(prob, list):
assert mmengine.is_list_of(prob, float) assert mmengine.is_list_of(prob, float)
assert 0 <= sum(prob) <= 1 assert 0 <= sum(prob) <= 1
...@@ -1049,6 +1053,7 @@ class RandomFlip(BaseTransform): ...@@ -1049,6 +1053,7 @@ class RandomFlip(BaseTransform):
raise ValueError(f'probs must be float or list of float, but \ raise ValueError(f'probs must be float or list of float, but \
got `{type(prob)}`.') got `{type(prob)}`.')
self.prob = prob self.prob = prob
self.swap_seg_labels = swap_seg_labels
valid_directions = ['horizontal', 'vertical', 'diagonal'] valid_directions = ['horizontal', 'vertical', 'diagonal']
if isinstance(direction, str): if isinstance(direction, str):
...@@ -1064,8 +1069,8 @@ class RandomFlip(BaseTransform): ...@@ -1064,8 +1069,8 @@ class RandomFlip(BaseTransform):
if isinstance(prob, list): if isinstance(prob, list):
assert len(prob) == len(self.direction) assert len(prob) == len(self.direction)
def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int], def _flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray: direction: str) -> np.ndarray:
"""Flip bboxes horizontally. """Flip bboxes horizontally.
Args: Args:
...@@ -1096,8 +1101,12 @@ class RandomFlip(BaseTransform): ...@@ -1096,8 +1101,12 @@ class RandomFlip(BaseTransform):
or 'diagonal', but got '{direction}'") or 'diagonal', but got '{direction}'")
return flipped return flipped
def flip_keypoints(self, keypoints: np.ndarray, img_shape: Tuple[int, int], def _flip_keypoints(
direction: str) -> np.ndarray: self,
keypoints: np.ndarray,
img_shape: Tuple[int, int],
direction: str,
) -> np.ndarray:
"""Flip keypoints horizontally, vertically or diagonally. """Flip keypoints horizontally, vertically or diagonally.
Args: Args:
...@@ -1127,6 +1136,33 @@ class RandomFlip(BaseTransform): ...@@ -1127,6 +1136,33 @@ class RandomFlip(BaseTransform):
flipped = np.concatenate([keypoints, meta_info], axis=-1) flipped = np.concatenate([keypoints, meta_info], axis=-1)
return flipped return flipped
def _flip_seg_map(self, seg_map: np.ndarray, direction: str) -> np.ndarray:
    """Flip segmentation map horizontally, vertically or diagonally.

    When ``self.swap_seg_labels`` is not None, each configured label pair
    is additionally exchanged in the flipped map.

    Args:
        seg_map (numpy.ndarray): segmentation map, shape (H, W).
        direction (str): Flip direction. Options are 'horizontal',
            'vertical' and 'diagonal'.

    Returns:
        numpy.ndarray: Flipped segmentation map.

    Raises:
        AssertionError: If ``self.swap_seg_labels`` is neither None nor a
            tuple/list whose items are all tuples/lists of length 2.
    """
    swap_pairs = self.swap_seg_labels
    if swap_pairs is not None:
        # Validate the whole configuration before touching any data, so
        # a malformed entry cannot leave a partially relabelled map.
        assert isinstance(swap_pairs, (tuple, list))
        for pair in swap_pairs:
            assert isinstance(pair, (tuple, list)) and len(pair) == 2, \
                'swap_seg_labels must be a sequence with pair, but got ' \
                f'{self.swap_seg_labels}.'

    seg_map = mmcv.imflip(seg_map, direction=direction)
    if swap_pairs is not None:
        # to handle datasets with left/right annotations
        # like 'Left-arm' and 'Right-arm' in LIP dataset
        # Modified from https://github.com/openseg-group/openseg.pytorch/blob/master/lib/datasets/tools/cv2_aug_transforms.py # noqa:E501
        # Licensed under MIT license
        # Masks are computed on an untouched snapshot so that the second
        # assignment of a pair does not undo the first one.
        temp = seg_map.copy()
        for first, second in swap_pairs:
            seg_map[temp == first] = second
            seg_map[temp == second] = first
    return seg_map
@cache_randomness @cache_randomness
def _choose_direction(self) -> str: def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`""" """Choose the flip direction according to `prob` and `direction`"""
...@@ -1162,19 +1198,20 @@ class RandomFlip(BaseTransform): ...@@ -1162,19 +1198,20 @@ class RandomFlip(BaseTransform):
# flip bboxes # flip bboxes
if results.get('gt_bboxes', None) is not None: if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'] = self.flip_bbox(results['gt_bboxes'], results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'],
img_shape, img_shape,
results['flip_direction']) results['flip_direction'])
# flip keypoints # flip keypoints
if results.get('gt_keypoints', None) is not None: if results.get('gt_keypoints', None) is not None:
results['gt_keypoints'] = self.flip_keypoints( results['gt_keypoints'] = self._flip_keypoints(
results['gt_keypoints'], img_shape, results['flip_direction']) results['gt_keypoints'], img_shape, results['flip_direction'])
# flip segs # flip seg map
if results.get('gt_seg_map', None) is not None: if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = mmcv.imflip( results['gt_seg_map'] = self._flip_seg_map(
results['gt_seg_map'], direction=results['flip_direction']) results['gt_seg_map'], direction=results['flip_direction'])
results['swap_seg_labels'] = self.swap_seg_labels
def _flip_on_direction(self, results: dict) -> None: def _flip_on_direction(self, results: dict) -> None:
"""Function to flip images, bounding boxes, semantic segmentation map """Function to flip images, bounding boxes, semantic segmentation map
......
...@@ -777,7 +777,9 @@ class TestRandomFlip: ...@@ -777,7 +777,9 @@ class TestRandomFlip:
'img': np.random.random((224, 224, 3)), 'img': np.random.random((224, 224, 3)),
'gt_bboxes': np.array([[0, 1, 100, 101]]), 'gt_bboxes': np.array([[0, 1, 100, 101]]),
'gt_keypoints': np.array([[[100, 100, 1.0]]]), 'gt_keypoints': np.array([[[100, 100, 1.0]]]),
'gt_seg_map': np.random.random((224, 224, 3)) # seg map flip is independent of the image, so there is no requirement
# that gt_seg_map of test data matches the image.
'gt_seg_map': np.array([[0, 1], [2, 3]])
} }
# horizontal flip # horizontal flip
...@@ -785,41 +787,65 @@ class TestRandomFlip: ...@@ -785,41 +787,65 @@ class TestRandomFlip:
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 1, 224, assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
101]])).all() 101]])).all()
assert (results_update['gt_seg_map'] == np.array([[1, 0], [3,
2]])).all()
# diagnal flip # diagonal flip
TRANSFORMS = RandomFlip([1.0], ['diagonal']) TRANSFORMS = RandomFlip([1.0], ['diagonal'])
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 123, 224, assert (results_update['gt_bboxes'] == np.array([[124, 123, 224,
223]])).all() 223]])).all()
assert (results_update['gt_seg_map'] == np.array([[3, 2], [1,
0]])).all()
# vertical flip # vertical flip
TRANSFORMS = RandomFlip([1.0], ['vertical']) TRANSFORMS = RandomFlip([1.0], ['vertical'])
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[0, 123, 100, assert (results_update['gt_bboxes'] == np.array([[0, 123, 100,
223]])).all() 223]])).all()
assert (results_update['gt_seg_map'] == np.array([[2, 3], [0,
1]])).all()
# horizontal flip when direction is None # horizontal flip when direction is None
TRANSFORMS = RandomFlip(1.0) TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 1, 224, assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
101]])).all() 101]])).all()
assert (results_update['gt_seg_map'] == np.array([[1, 0], [3,
2]])).all()
# horizontal flip and swap label pair
TRANSFORMS = RandomFlip([1.0], ['horizontal'],
swap_seg_labels=[[0, 1]])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_seg_map'] == np.array([[0, 1], [3,
2]])).all()
assert results_update['swap_seg_labels'] == [[0, 1]]
TRANSFORMS = RandomFlip(0.0) TRANSFORMS = RandomFlip(0.0)
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[0, 1, 100, assert (results_update['gt_bboxes'] == np.array([[0, 1, 100,
101]])).all() 101]])).all()
assert (results_update['gt_seg_map'] == np.array([[0, 1], [2,
3]])).all()
# flip direction is invalid in bbox flip # flip direction is invalid in bbox flip
with pytest.raises(ValueError): with pytest.raises(ValueError):
TRANSFORMS = RandomFlip(1.0) TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.flip_bbox(results['gt_bboxes'], results_update = TRANSFORMS._flip_bbox(results['gt_bboxes'],
(224, 224), 'invalid') (224, 224), 'invalid')
# flip direction is invalid in keypoints flip # flip direction is invalid in keypoints flip
with pytest.raises(ValueError): with pytest.raises(ValueError):
TRANSFORMS = RandomFlip(1.0) TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.flip_keypoints(results['gt_keypoints'], results_update = TRANSFORMS._flip_keypoints(
(224, 224), 'invalid') results['gt_keypoints'], (224, 224), 'invalid')
# swap pair is invalid
with pytest.raises(AssertionError):
TRANSFORMS = RandomFlip(1.0, swap_seg_labels='invalid')
results_update = TRANSFORMS._flip_seg_map(results['gt_seg_map'],
'horizontal')
def test_repr(self): def test_repr(self):
TRANSFORMS = RandomFlip(0.1) TRANSFORMS = RandomFlip(0.1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment