Commit ea84b674 authored by Yifei Yang's avatar Yifei Yang Committed by zhouzaida
Browse files

update scales and img_shape (#1871)

parent 30b38446
...@@ -438,8 +438,7 @@ class CenterCrop(BaseTransform): ...@@ -438,8 +438,7 @@ class CenterCrop(BaseTransform):
Modified Keys: Modified Keys:
- img - img
- height - img_shape
- width
- gt_seg_map (optional) - gt_seg_map (optional)
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
...@@ -499,10 +498,9 @@ class CenterCrop(BaseTransform): ...@@ -499,10 +498,9 @@ class CenterCrop(BaseTransform):
""" """
if results.get('img', None) is not None: if results.get('img', None) is not None:
img = mmcv.imcrop(results['img'], bboxes=bboxes) img = mmcv.imcrop(results['img'], bboxes=bboxes)
img_shape = img.shape img_shape = img.shape[:2]
results['img'] = img results['img'] = img
results['height'] = img_shape[0] results['img_shape'] = img_shape
results['width'] = img_shape[1]
results['pad_shape'] = img_shape results['pad_shape'] = img_shape
def _crop_seg_map(self, results: dict, bboxes: np.ndarray) -> None: def _crop_seg_map(self, results: dict, bboxes: np.ndarray) -> None:
...@@ -725,7 +723,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -725,7 +723,7 @@ class MultiScaleFlipAug(BaseTransform):
dict( dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
img_scale=[(1333, 400), (1333, 800)], scales=[(1333, 400), (1333, 800)],
flip=True, flip=True,
transforms=[ transforms=[
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
...@@ -734,7 +732,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -734,7 +732,7 @@ class MultiScaleFlipAug(BaseTransform):
dict(type='Collect', keys=['img']) dict(type='Collect', keys=['img'])
]) ])
``results`` will be resized using all the sizes in ``img_scale``. ``results`` will be resized using all the sizes in ``scales``.
If ``flip`` is True, then flipped results will also be added into output If ``flip`` is True, then flipped results will also be added into output
list. list.
...@@ -769,7 +767,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -769,7 +767,7 @@ class MultiScaleFlipAug(BaseTransform):
Args: Args:
transforms (list[dict]): Transforms to be applied to each resized transforms (list[dict]): Transforms to be applied to each resized
and flipped data. and flipped data.
img_scale (tuple | list[tuple] | None): Images scales for resizing. scales (tuple | list[tuple] | None): Images scales for resizing.
scale_factor (float or tuple[float]): Scale factors for resizing. scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None. Defaults to None.
allow_flip (bool): Whether apply flip augmentation. Defaults to False. allow_flip (bool): Whether apply flip augmentation. Defaults to False.
...@@ -787,7 +785,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -787,7 +785,7 @@ class MultiScaleFlipAug(BaseTransform):
def __init__( def __init__(
self, self,
transforms: List[dict], transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None, scales: Optional[Union[Tuple, List[Tuple]]] = None,
scale_factor: Optional[Union[float, List[float]]] = None, scale_factor: Optional[Union[float, List[float]]] = None,
allow_flip: bool = False, allow_flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal', flip_direction: Union[str, List[str]] = 'horizontal',
...@@ -797,17 +795,16 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -797,17 +795,16 @@ class MultiScaleFlipAug(BaseTransform):
super().__init__() super().__init__()
self.transforms = Compose(transforms) # type: ignore self.transforms = Compose(transforms) # type: ignore
if img_scale is not None: if scales is not None:
self.img_scale = img_scale if isinstance(img_scale, self.scales = scales if isinstance(scales, list) else [scales]
list) else [img_scale]
self.scale_key = 'scale' self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple) assert mmcv.is_list_of(self.scales, tuple)
else: else:
# if ``img_scale`` and ``scale_factor`` both be ``None`` # if ``scales`` and ``scale_factor`` both be ``None``
if scale_factor is None: if scale_factor is None:
self.img_scale = [1.] self.scales = [1.]
else: else:
self.img_scale = scale_factor if isinstance( self.scales = scale_factor if isinstance(
scale_factor, list) else [scale_factor] scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor' self.scale_key = 'scale_factor'
...@@ -838,7 +835,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -838,7 +835,7 @@ class MultiScaleFlipAug(BaseTransform):
if self.allow_flip: if self.allow_flip:
flip_args += [(True, direction) flip_args += [(True, direction)
for direction in self.flip_direction] for direction in self.flip_direction]
for scale in self.img_scale: for scale in self.scales:
for flip, direction in flip_args: for flip, direction in flip_args:
_resize_cfg = self.resize_cfg.copy() _resize_cfg = self.resize_cfg.copy()
_resize_cfg.update({self.scale_key: scale}) _resize_cfg.update({self.scale_key: scale})
...@@ -864,7 +861,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -864,7 +861,7 @@ class MultiScaleFlipAug(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', transforms={self.transforms}' repr_str += f', transforms={self.transforms}'
repr_str += f', img_scale={self.img_scale}' repr_str += f', scales={self.scales}'
repr_str += f', allow_flip={self.allow_flip}' repr_str += f', allow_flip={self.allow_flip}'
repr_str += f', flip_direction={self.flip_direction}' repr_str += f', flip_direction={self.flip_direction}'
return repr_str return repr_str
...@@ -933,8 +930,8 @@ class RandomChoiceResize(BaseTransform): ...@@ -933,8 +930,8 @@ class RandomChoiceResize(BaseTransform):
"""Randomly select an scale from given candidates. """Randomly select an scale from given candidates.
Returns: Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_dix)``, (tuple, int): Returns a tuple ``(scale, scale_dix)``,
where ``img_scale`` is the selected image scale and where ``scale`` is the selected image scale and
``scale_idx`` is the selected index in the given candidates. ``scale_idx`` is the selected index in the given candidates.
""" """
......
...@@ -291,8 +291,7 @@ class TestCenterCrop: ...@@ -291,8 +291,7 @@ class TestCenterCrop:
transform = dict(type='CenterCrop', crop_size=224) transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform) center_crop_module = TRANSFORMS.build(transform)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 224 assert results['img_shape'] == (224, 224)
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all() assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[38:262, assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
88:312]).all() 88:312]).all()
...@@ -309,8 +308,7 @@ class TestCenterCrop: ...@@ -309,8 +308,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 224 assert results['img_shape'] == (224, 224)
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all() assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[38:262, assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
88:312]).all() 88:312]).all()
...@@ -327,8 +325,7 @@ class TestCenterCrop: ...@@ -327,8 +325,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 256 assert results['img_shape'] == (256, 224)
assert results['width'] == 224
assert (results['img'] == self.original_img[22:278, 88:312, ...]).all() assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[22:278, assert (results['gt_seg_map'] == self.gt_semantic_map[22:278,
88:312]).all() 88:312]).all()
...@@ -346,8 +343,7 @@ class TestCenterCrop: ...@@ -346,8 +343,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 300 assert results['img_shape'] == (300, 400)
assert results['width'] == 400
assert (results['img'] == self.original_img).all() assert (results['img'] == self.original_img).all()
assert (results['gt_seg_map'] == self.gt_semantic_map).all() assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
...@@ -364,8 +360,7 @@ class TestCenterCrop: ...@@ -364,8 +360,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 300 assert results['img_shape'] == (300, 400)
assert results['width'] == 400
assert (results['img'] == self.original_img).all() assert (results['img'] == self.original_img).all()
assert (results['gt_seg_map'] == self.gt_semantic_map).all() assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
...@@ -385,8 +380,7 @@ class TestCenterCrop: ...@@ -385,8 +380,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 600 assert results['img_shape'] == (600, 200)
assert results['width'] == 200
assert results['img'].shape[:2] == results['gt_seg_map'].shape assert results['img'].shape[:2] == results['gt_seg_map'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all() assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_seg_map'][300:600, 100:300] == 255).all() assert (results['gt_seg_map'][300:600, 100:300] == 255).all()
...@@ -409,8 +403,7 @@ class TestCenterCrop: ...@@ -409,8 +403,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 600 assert results['img_shape'] == (600, 200)
assert results['width'] == 200
assert (results['img'][300:600, 100:300, ...] == 13).all() assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_seg_map'][300:600, 100:300] == 33).all() assert (results['gt_seg_map'][300:600, 100:300] == 33).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
...@@ -427,8 +420,7 @@ class TestCenterCrop: ...@@ -427,8 +420,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == img_height assert results['img_shape'] == (img_height, img_width // 2)
assert results['width'] == img_width // 2
assert (results['img'] == self.original_img[:, 100:300, ...]).all() assert (results['img'] == self.original_img[:, 100:300, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[:, assert (results['gt_seg_map'] == self.gt_semantic_map[:,
100:300]).all() 100:300]).all()
...@@ -446,8 +438,7 @@ class TestCenterCrop: ...@@ -446,8 +438,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == img_height // 2 assert results['img_shape'] == (img_height // 2, img_width)
assert results['width'] == img_width
assert (results['img'] == self.original_img[75:225, ...]).all() assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[75:225, assert (results['gt_seg_map'] == self.gt_semantic_map[75:225,
...]).all() ...]).all()
...@@ -557,17 +548,17 @@ class TestMultiScaleFlipAug: ...@@ -557,17 +548,17 @@ class TestMultiScaleFlipAug:
cls.original_img = copy.deepcopy(cls.img) cls.original_img = copy.deepcopy(cls.img)
def test_error(self): def test_error(self):
# test assertion if img_scale is not tuple or list of tuple # test assertion if scales is not tuple or list of tuple
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
transform = dict( transform = dict(
type='MultiScaleFlipAug', img_scale=[1333, 800], transforms=[]) type='MultiScaleFlipAug', scales=[1333, 800], transforms=[])
TRANSFORMS.build(transform) TRANSFORMS.build(transform)
# test assertion if flip_direction is not str or list of str # test assertion if flip_direction is not str or list of str
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
img_scale=[(1333, 800)], scales=[(1333, 800)],
flip_direction=1, flip_direction=1,
transforms=[]) transforms=[])
TRANSFORMS.build(transform) TRANSFORMS.build(transform)
...@@ -579,7 +570,7 @@ class TestMultiScaleFlipAug: ...@@ -579,7 +570,7 @@ class TestMultiScaleFlipAug:
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=[dict(type='MockPackTaskInputs')], transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)], scales=[(1333, 800), (800, 600), (640, 480)],
allow_flip=True, allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
...@@ -592,7 +583,7 @@ class TestMultiScaleFlipAug: ...@@ -592,7 +583,7 @@ class TestMultiScaleFlipAug:
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=[dict(type='MockPackTaskInputs')], transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)], scales=[(1333, 800), (800, 600), (640, 480)],
allow_flip=False, allow_flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
...@@ -615,7 +606,7 @@ class TestMultiScaleFlipAug: ...@@ -615,7 +606,7 @@ class TestMultiScaleFlipAug:
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=transforms_cfg, transforms=transforms_cfg,
img_scale=[(1333, 800), (800, 600), (640, 480)], scales=[(1333, 800), (800, 600), (640, 480)],
allow_flip=True, allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment