Commit ea84b674 authored by Yifei Yang, committed by zhouzaida
Browse files

update scales and img_shape (#1871)

parent 30b38446
......@@ -438,8 +438,7 @@ class CenterCrop(BaseTransform):
Modified Keys:
- img
- height
- width
- img_shape
- gt_seg_map (optional)
- gt_bboxes (optional)
- gt_keypoints (optional)
......@@ -499,10 +498,9 @@ class CenterCrop(BaseTransform):
"""
if results.get('img', None) is not None:
img = mmcv.imcrop(results['img'], bboxes=bboxes)
img_shape = img.shape
img_shape = img.shape[:2]
results['img'] = img
results['height'] = img_shape[0]
results['width'] = img_shape[1]
results['img_shape'] = img_shape
results['pad_shape'] = img_shape
def _crop_seg_map(self, results: dict, bboxes: np.ndarray) -> None:
......@@ -725,7 +723,7 @@ class MultiScaleFlipAug(BaseTransform):
dict(
type='MultiScaleFlipAug',
img_scale=[(1333, 400), (1333, 800)],
scales=[(1333, 400), (1333, 800)],
allow_flip=True,
transforms=[
dict(type='Normalize', **img_norm_cfg),
......@@ -734,7 +732,7 @@ class MultiScaleFlipAug(BaseTransform):
dict(type='Collect', keys=['img'])
])
``results`` will be resized using all the sizes in ``img_scale``.
``results`` will be resized using all the sizes in ``scales``.
If ``allow_flip`` is True, then flipped results will also be added into the
output list.
......@@ -769,7 +767,7 @@ class MultiScaleFlipAug(BaseTransform):
Args:
transforms (list[dict]): Transforms to be applied to each resized
and flipped data.
img_scale (tuple | list[tuple] | None): Images scales for resizing.
scales (tuple | list[tuple] | None): Image scales for resizing.
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
allow_flip (bool): Whether to apply flip augmentation. Defaults to False.
......@@ -787,7 +785,7 @@ class MultiScaleFlipAug(BaseTransform):
def __init__(
self,
transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
scales: Optional[Union[Tuple, List[Tuple]]] = None,
scale_factor: Optional[Union[float, List[float]]] = None,
allow_flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal',
......@@ -797,17 +795,16 @@ class MultiScaleFlipAug(BaseTransform):
super().__init__()
self.transforms = Compose(transforms) # type: ignore
if img_scale is not None:
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
if scales is not None:
self.scales = scales if isinstance(scales, list) else [scales]
self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple)
assert mmcv.is_list_of(self.scales, tuple)
else:
# if ``img_scale`` and ``scale_factor`` both be ``None``
# if ``scales`` and ``scale_factor`` are both ``None``
if scale_factor is None:
self.img_scale = [1.]
self.scales = [1.]
else:
self.img_scale = scale_factor if isinstance(
self.scales = scale_factor if isinstance(
scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor'
......@@ -838,7 +835,7 @@ class MultiScaleFlipAug(BaseTransform):
if self.allow_flip:
flip_args += [(True, direction)
for direction in self.flip_direction]
for scale in self.img_scale:
for scale in self.scales:
for flip, direction in flip_args:
_resize_cfg = self.resize_cfg.copy()
_resize_cfg.update({self.scale_key: scale})
......@@ -864,7 +861,7 @@ class MultiScaleFlipAug(BaseTransform):
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', transforms={self.transforms}'
repr_str += f', img_scale={self.img_scale}'
repr_str += f', scales={self.scales}'
repr_str += f', allow_flip={self.allow_flip}'
repr_str += f', flip_direction={self.flip_direction}'
return repr_str
......@@ -933,8 +930,8 @@ class RandomChoiceResize(BaseTransform):
"""Randomly select a scale from the given candidates.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_dix)``,
where ``img_scale`` is the selected image scale and
(tuple, int): Returns a tuple ``(scale, scale_idx)``,
where ``scale`` is the selected image scale and
``scale_idx`` is the selected index in the given candidates.
"""
......
......@@ -291,8 +291,7 @@ class TestCenterCrop:
transform = dict(type='CenterCrop', crop_size=224)
center_crop_module = TRANSFORMS.build(transform)
results = center_crop_module(results)
assert results['height'] == 224
assert results['width'] == 224
assert results['img_shape'] == (224, 224)
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
88:312]).all()
......@@ -309,8 +308,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 224
assert results['width'] == 224
assert results['img_shape'] == (224, 224)
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
88:312]).all()
......@@ -327,8 +325,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 256
assert results['width'] == 224
assert results['img_shape'] == (256, 224)
assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[22:278,
88:312]).all()
......@@ -346,8 +343,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 300
assert results['width'] == 400
assert results['img_shape'] == (300, 400)
assert (results['img'] == self.original_img).all()
assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'],
......@@ -364,8 +360,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 300
assert results['width'] == 400
assert results['img_shape'] == (300, 400)
assert (results['img'] == self.original_img).all()
assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'],
......@@ -385,8 +380,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert results['img_shape'] == (600, 200)
assert results['img'].shape[:2] == results['gt_seg_map'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_seg_map'][300:600, 100:300] == 255).all()
......@@ -409,8 +403,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert results['img_shape'] == (600, 200)
assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_seg_map'][300:600, 100:300] == 33).all()
assert np.equal(results['gt_bboxes'],
......@@ -427,8 +420,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == img_height
assert results['width'] == img_width // 2
assert results['img_shape'] == (img_height, img_width // 2)
assert (results['img'] == self.original_img[:, 100:300, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[:,
100:300]).all()
......@@ -446,8 +438,7 @@ class TestCenterCrop:
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == img_height // 2
assert results['width'] == img_width
assert results['img_shape'] == (img_height // 2, img_width)
assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[75:225,
...]).all()
......@@ -557,17 +548,17 @@ class TestMultiScaleFlipAug:
cls.original_img = copy.deepcopy(cls.img)
def test_error(self):
# test assertion if img_scale is not tuple or list of tuple
# test assertion if scales is not tuple or list of tuple
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=[1333, 800], transforms=[])
type='MultiScaleFlipAug', scales=[1333, 800], transforms=[])
TRANSFORMS.build(transform)
# test assertion if flip_direction is not str or list of str
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug',
img_scale=[(1333, 800)],
scales=[(1333, 800)],
flip_direction=1,
transforms=[])
TRANSFORMS.build(transform)
......@@ -579,7 +570,7 @@ class TestMultiScaleFlipAug:
transform = dict(
type='MultiScaleFlipAug',
transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)],
scales=[(1333, 800), (800, 600), (640, 480)],
allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
......@@ -592,7 +583,7 @@ class TestMultiScaleFlipAug:
transform = dict(
type='MultiScaleFlipAug',
transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)],
scales=[(1333, 800), (800, 600), (640, 480)],
allow_flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
......@@ -615,7 +606,7 @@ class TestMultiScaleFlipAug:
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
img_scale=[(1333, 800), (800, 600), (640, 480)],
scales=[(1333, 800), (800, 600), (640, 480)],
allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment