Commit 0a5b4125 authored by Yuan Liu's avatar Yuan Liu Committed by zhouzaida
Browse files

[Feature]: Reformat resize config (#1826)



* [Feature]: Add cache to random func in data transform

* [Fix]: Fix lint

* [Fix]: Fix cache decorate problem

* [Refactor]: Initialize Resize with config

* [Refactor]: Move other resize config into Resize config

* [Fix]: Scale can not be None in RandomResize

* [Fix]: Change semantic seg to gt seg map

* [Fix]: Delete unnecessary assert

* [Fix]: Fix docstring

* [Fix]: Add double quot to Resize in config

* [Fix]: Fix the return type

* [Fix]: Improve docstring

* [Fix]: Specify the order of width and height for ratio range

* [Fix]: Specify resize order
Co-authored-by: default avatarYour <you@example.com>
parent 2844a116
...@@ -9,6 +9,7 @@ import mmcv ...@@ -9,6 +9,7 @@ import mmcv
from mmcv.image.geometric import _scale_size from mmcv.image.geometric import _scale_size
from .base import BaseTransform from .base import BaseTransform
from .builder import TRANSFORMS from .builder import TRANSFORMS
from .utils import cacheable_method
from .wrappers import Compose from .wrappers import Compose
Number = Union[int, float] Number = Union[int, float]
...@@ -548,9 +549,9 @@ class CenterCrop(BaseTransform): ...@@ -548,9 +549,9 @@ class CenterCrop(BaseTransform):
# set gt_kepoints out of the result image invisible # set gt_kepoints out of the result image invisible
height, width = results['img'].shape[:2] height, width = results['img'].shape[:2]
valid_pos = (gt_keypoints[:, :, 0] >= valid_pos = (gt_keypoints[:, :, 0] >=
0) * (gt_keypoints[:, :, 0] < 0) * (gt_keypoints[:, :, 0] <
width) * (gt_keypoints[:, :, 1] >= 0) * ( width) * (gt_keypoints[:, :, 1] >= 0) * (
gt_keypoints[:, :, 1] < height) gt_keypoints[:, :, 1] < height)
gt_keypoints[:, :, 2] = np.where(valid_pos, gt_keypoints[:, :, 2], gt_keypoints[:, :, 2] = np.where(valid_pos, gt_keypoints[:, :, 2],
0) 0)
gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0, gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0,
...@@ -1129,6 +1130,7 @@ class RandomFlip(BaseTransform): ...@@ -1129,6 +1130,7 @@ class RandomFlip(BaseTransform):
flipped = np.concatenate([keypoints, meta_info], axis=-1) flipped = np.concatenate([keypoints, meta_info], axis=-1)
return flipped return flipped
@cacheable_method
def _choose_direction(self) -> str: def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`""" """Choose the flip direction according to `prob` and `direction`"""
if isinstance(self.direction, if isinstance(self.direction,
...@@ -1173,10 +1175,9 @@ class RandomFlip(BaseTransform): ...@@ -1173,10 +1175,9 @@ class RandomFlip(BaseTransform):
results['gt_keypoints'], img_shape, results['flip_direction']) results['gt_keypoints'], img_shape, results['flip_direction'])
# flip segs # flip segs
if results.get('gt_semantic_seg', None) is not None: if results.get('gt_seg_map', None) is not None:
results['gt_semantic_seg'] = mmcv.imflip( results['gt_seg_map'] = mmcv.imflip(
results['gt_semantic_seg'], results['gt_seg_map'], direction=results['flip_direction'])
direction=results['flip_direction'])
def _flip_on_direction(self, results: dict) -> None: def _flip_on_direction(self, results: dict) -> None:
"""Function to flip images, bounding boxes, semantic segmentation map """Function to flip images, bounding boxes, semantic segmentation map
...@@ -1217,18 +1218,20 @@ class RandomFlip(BaseTransform): ...@@ -1217,18 +1218,20 @@ class RandomFlip(BaseTransform):
class RandomResize(BaseTransform): class RandomResize(BaseTransform):
"""Random resize images & bbox & keypoints. """Random resize images & bbox & keypoints.
Added or updated keys: scale, scale_factor, keep_ratio, img, height, width,
gt_bboxes, gt_semantic_seg, and gt_keypoints.
How to choose the target scale to resize the image will follow the rules How to choose the target scale to resize the image will follow the rules
below: below:
- if `scale` is a list of tuple, the first value of the target scale is - if ``scale`` is a list of tuple, the first value of the target scale is
sampled from [`scale[0][0]`, `scale[1][0]`] uniformally and the second sampled from [``scale[0][0]``, ``scale[1][0]``] uniformally and the
value of the target scale is sampled from [`scale[0][1]`, `scale[1][1]`] second value of the target scale is sampled from
uniformally. [``scale[0][1]``, ``scale[1][1]``] uniformally. Following the resize
- if `scale` is a tuple, the first and second values of the target scale order of weight and height in cv2, scale[i][0] is for width, and
is equal to the first and second values of `scale` multiplied by a value scale[i][1] is for height.
sampled from [`ratio_range[0]`, `ratio_range[1]`] uniformally. - if ``scale`` is a tuple, the first and second values of the target scale
is equal to the first and second values of ``scale`` multiplied by a
value sampled from [``ratio_range[0]``, ``ratio_range[1]``] uniformally.
Following the resize order of weight and height in cv2, ratio_range[0] is
for width, and ratio_range[1] is for height.
Required Keys: Required Keys:
...@@ -1251,50 +1254,37 @@ class RandomResize(BaseTransform): ...@@ -1251,50 +1254,37 @@ class RandomResize(BaseTransform):
- keep_ratio - keep_ratio
Args: Args:
scale (tuple or list[tuple], optional): Images scales for resizing. scale (tuple or list[tuple]): Images scales for resizing.
Defaults to None. Defaults to None.
ratio_range (tuple[float], optional): (min_ratio, max_ratio). ratio_range (tuple[float], optional): (min_ratio, max_ratio).
Defaults to None. Defaults to None.
keep_ratio (bool): Whether to keep the aspect ratio when resizing the resize_cfg (dict): Config to initialize a ``Resize`` transform.
image. Defaults to True. Defaults to dict(type='Resize', keep_ratio=True,
clip_object_border (bool): Whether to clip the objects clip_object_border=True, backend='cv2', interpolation='bilinear').
outside the border of the image. In some dataset like MOT17, the
gt bboxes are allowed to cross the border of images. Therefore,
we don't need to clip the gt bboxes in these cases.
Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generates slightly different results. Defaults
to 'cv2'.
interpolation (str): How to interpolate the original image when
resizing. Defaults to 'bilinear'.
""" """
def __init__(self, def __init__(
scale: Union[Tuple[int, int], List[Tuple[int, int]]] = None, self,
ratio_range: Tuple[float, float] = None, scale: Union[Tuple[int, int], List[Tuple[int, int]]],
keep_ratio: bool = True, ratio_range: Tuple[float, float] = None,
clip_object_border: bool = True, resize_cfg: dict = dict(
backend: str = 'cv2', type='Resize',
interpolation: str = 'bilinear') -> None: keep_ratio=True,
clip_object_border=True,
assert scale is not None backend='cv2',
interpolation='bilinear')
) -> None:
self.scale = scale self.scale = scale
self.ratio_range = ratio_range self.ratio_range = ratio_range
self.keep_ratio = keep_ratio self.resize_cfg = resize_cfg
self.clip_object_border = clip_object_border
self.backend = backend
self.interpolation = interpolation
# create a empty Reisize object # create a empty Reisize object
self.resize = Resize(0) self.resize_cfg.update(dict(scale=0))
self.resize.keep_ratio = keep_ratio self.resize = TRANSFORMS.build(self.resize_cfg)
self.resize.clip_object_border = clip_object_border
self.resize.backend = backend
self.resize.interpolation = interpolation
@staticmethod @staticmethod
def _random_sample(scales: Sequence[Tuple[int, int]]) -> Tuple[int, int]: def _random_sample(scales: Sequence[Tuple[int, int]]) -> tuple:
"""Private function to randomly sample a scale from a list of tuples. """Private function to randomly sample a scale from a list of tuples.
Args: Args:
...@@ -1302,7 +1292,7 @@ class RandomResize(BaseTransform): ...@@ -1302,7 +1292,7 @@ class RandomResize(BaseTransform):
There must be two tuples in scales, which specify the lower There must be two tuples in scales, which specify the lower
and upper bound of image scales. and upper bound of image scales.
Returns: Returns:
tuple: Returns the target scale. tuple: The targeted scale of the image to be resized.
""" """
assert mmcv.is_list_of(scales, tuple) and len(scales) == 2 assert mmcv.is_list_of(scales, tuple) and len(scales) == 2
...@@ -1314,8 +1304,8 @@ class RandomResize(BaseTransform): ...@@ -1314,8 +1304,8 @@ class RandomResize(BaseTransform):
return scale return scale
@staticmethod @staticmethod
def _random_sample_ratio( def _random_sample_ratio(scale: tuple, ratio_range: Tuple[float,
scale: tuple, ratio_range: Tuple[float, float]) -> Tuple[int, int]: float]) -> tuple:
"""Private function to randomly sample a scale from a tuple. """Private function to randomly sample a scale from a tuple.
A ratio will be randomly sampled from the range specified by A ratio will be randomly sampled from the range specified by
...@@ -1326,7 +1316,7 @@ class RandomResize(BaseTransform): ...@@ -1326,7 +1316,7 @@ class RandomResize(BaseTransform):
ratio_range (tuple[float]): The minimum and maximum ratio to scale ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``scale``. the ``scale``.
Returns: Returns:
tuple: Returns the target scale. tuple: The targeted scale of the image to be resized.
""" """
assert isinstance(scale, tuple) and len(scale) == 2 assert isinstance(scale, tuple) and len(scale) == 2
...@@ -1336,15 +1326,13 @@ class RandomResize(BaseTransform): ...@@ -1336,15 +1326,13 @@ class RandomResize(BaseTransform):
scale = int(scale[0] * ratio), int(scale[1] * ratio) scale = int(scale[0] * ratio), int(scale[1] * ratio)
return scale return scale
def _random_scale(self, results: dict) -> None: @cacheable_method
def _random_scale(self) -> tuple:
"""Private function to randomly sample an scale according to the type """Private function to randomly sample an scale according to the type
of `scale`. of ``scale``.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns: Returns:
dict: One new key 'scale`is added into ``results``, tuple: The targeted scale of the image to be resized.
which would be used by subsequent pipelines.
""" """
if isinstance(self.scale, tuple): if isinstance(self.scale, tuple):
...@@ -1357,7 +1345,7 @@ class RandomResize(BaseTransform): ...@@ -1357,7 +1345,7 @@ class RandomResize(BaseTransform):
raise NotImplementedError(f"Do not support sampling function \ raise NotImplementedError(f"Do not support sampling function \
for '{self.scale}'") for '{self.scale}'")
results['scale'] = scale return scale
def transform(self, results: dict) -> dict: def transform(self, results: dict) -> dict:
"""Transform function to resize images, bounding boxes, semantic """Transform function to resize images, bounding boxes, semantic
...@@ -1366,11 +1354,11 @@ class RandomResize(BaseTransform): ...@@ -1366,11 +1354,11 @@ class RandomResize(BaseTransform):
Args: Args:
results (dict): Result dict from loading pipeline. results (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg', dict: Resized results, ``img``, ``gt_bboxes``, ``gt_semantic_seg``,
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width', ``gt_keypoints``, ``scale``, ``scale_factor``, ``height``,
and 'keep_ratio' keys are updated in result dict. ``width``, and ``keep_ratio`` keys are updated in result dict.
""" """
self._random_scale(results) results['scale'] = self._random_scale()
self.resize.scale = results['scale'] self.resize.scale = results['scale']
results = self.resize.transform(results) results = self.resize.transform(results)
return results return results
...@@ -1379,8 +1367,5 @@ class RandomResize(BaseTransform): ...@@ -1379,8 +1367,5 @@ class RandomResize(BaseTransform):
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(scale={self.scale}, ' repr_str += f'(scale={self.scale}, '
repr_str += f'ratio_range={self.ratio_range}, ' repr_str += f'ratio_range={self.ratio_range}, '
repr_str += f'keep_ratio={self.keep_ratio}, ' repr_str += f'resize_cfg={self.resize_cfg})'
repr_str += f'bbox_clip_border={self.clip_object_border}, '
repr_str += f'backend={self.backend}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str return repr_str
...@@ -792,7 +792,7 @@ class TestRandomFlip: ...@@ -792,7 +792,7 @@ class TestRandomFlip:
'img': np.random.random((224, 224, 3)), 'img': np.random.random((224, 224, 3)),
'gt_bboxes': np.array([[0, 1, 100, 101]]), 'gt_bboxes': np.array([[0, 1, 100, 101]]),
'gt_keypoints': np.array([[[100, 100, 1.0]]]), 'gt_keypoints': np.array([[[100, 100, 1.0]]]),
'gt_semantic_seg': np.random.random((224, 224, 3)) 'gt_seg_map': np.random.random((224, 224, 3))
} }
# horizontal flip # horizontal flip
...@@ -877,9 +877,10 @@ class TestRandomResize: ...@@ -877,9 +877,10 @@ class TestRandomResize:
'gt_bboxes': np.array([[0, 0, 112, 112]]), 'gt_bboxes': np.array([[0, 0, 112, 112]]),
'gt_keypoints': np.array([[[112, 112]]]) 'gt_keypoints': np.array([[[112, 112]]])
} }
# import pdb
# pdb.set_trace() TRANSFORMS = RandomResize(
TRANSFORMS = RandomResize((224, 224), (1.0, 2.0), keep_ratio=True) (224, 224), (1.0, 2.0),
resize_cfg=dict(type='Resize', keep_ratio=True))
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert 224 <= results_update['height'] assert 224 <= results_update['height']
assert 448 >= results_update['height'] assert 448 >= results_update['height']
...@@ -890,13 +891,17 @@ class TestRandomResize: ...@@ -890,13 +891,17 @@ class TestRandomResize:
assert results['gt_bboxes'][0][2] <= 112 assert results['gt_bboxes'][0][2] <= 112
# keep ratio is False # keep ratio is False
TRANSFORMS = RandomResize((224, 224), (1.0, 2.0), keep_ratio=False) TRANSFORMS = RandomResize(
(224, 224), (1.0, 2.0),
resize_cfg=dict(type='Resize', keep_ratio=False))
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
# choose target scale from init when override is False and scale is a # choose target scale from init when override is False and scale is a
# list of tuples # list of tuples
results = {} results = {}
TRANSFORMS = RandomResize([(224, 448), (112, 224)], keep_ratio=True) TRANSFORMS = RandomResize([(224, 448), (112, 224)],
resize_cfg=dict(
type='Resize', keep_ratio=True))
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert results_update['scale'][0] >= 224 and results_update['scale'][ assert results_update['scale'][0] >= 224 and results_update['scale'][
0] <= 448 0] <= 448
...@@ -907,5 +912,6 @@ class TestRandomResize: ...@@ -907,5 +912,6 @@ class TestRandomResize:
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
results = {} results = {}
TRANSFORMS = RandomResize([(224, 448), [112, 224]], TRANSFORMS = RandomResize([(224, 448), [112, 224]],
keep_ratio=True) resize_cfg=dict(
type='Resize', keep_ratio=True))
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment