"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "06789888408ba3a61817782b98c721e8f96c66d1"
Commit 2f85d781 authored by Yifei Yang, committed by zhouzaida

[Enhancement] Enhance CenterCrop (#1765)

* enhance centercrop and adjust crop size to (w, h)

* fix comments

* update required keys and docstring
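For context, a minimal usage sketch of the enhanced transform follows; this is a sketch, not part of the commit. It assumes mmcv >= 2.0.0rc0 with the `TRANSFORMS` registry exported from `mmcv.transforms` (import path assumed), and the sample image and annotation values are invented for illustration:

```python
import numpy as np
from mmcv.transforms import TRANSFORMS  # registry location assumed

# crop_size is now given as (w, h); clip_object_border is the new flag.
center_crop = TRANSFORMS.build(
    dict(type='CenterCrop', crop_size=(224, 256), clip_object_border=True))

results = {
    'img': np.zeros((300, 400, 3), dtype=np.uint8),  # (h, w, c) = 300x400
    'gt_bboxes': np.array([[0, 0, 210, 160]]),       # (tl_x, tl_y, br_x, br_y)
    'gt_keypoints': np.array([[[20, 50, 1]]]),       # (x, y, visibility)
}
results = center_crop(results)
assert results['img'].shape[:2] == (256, 224)  # cropped to h=256, w=224
# gt_bboxes and gt_keypoints are shifted by the crop origin and clipped to
# the new image border because clip_object_border=True.
```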
@@ -415,14 +415,16 @@ class Pad(BaseTransform):
 @TRANSFORMS.register_module()
 class CenterCrop(BaseTransform):
-    """Crop the center of the image and segmentation masks. If the crop area
-    exceeds the original image and ``pad_mode`` is not None, the original image
-    will be padded before cropping.
+    """Crop the center of the image, segmentation masks, bounding boxes and key
+    points. If the crop area exceeds the original image and ``pad_mode`` is not
+    None, the original image will be padded before cropping.

     Required Keys:

     - img
     - gt_semantic_seg (optional)
+    - gt_bboxes (optional)
+    - gt_keypoints (optional)

     Modified Keys:
@@ -430,6 +432,8 @@ class CenterCrop(BaseTransform):
     - height
     - width
     - gt_semantic_seg (optional)
+    - gt_bboxes (optional)
+    - gt_keypoints (optional)

     Added Key:
@@ -438,8 +442,8 @@ class CenterCrop(BaseTransform):
     Args:
         crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
-            with the format of (h, w). If set to an integer, then cropping
-            height and width are equal to this integer.
+            with the format of (w, h). If set to an integer, then cropping
+            width and height are equal to this integer.
         pad_val (Union[Number, Dict[str, Number]]): A dict for
             padding value. To specify how to set this argument, please see
             the docstring of class ``Pad``. Defaults to
@@ -449,6 +453,11 @@ class CenterCrop(BaseTransform):
             docstring of class ``Pad``. Defaults to 'constant'.
         pad_cfg (dict): Base config for padding. Defaults to
             ``dict(type='Pad')``.
+        clip_object_border (bool): Whether to clip objects that lie outside
+            the border of the image. In some datasets, such as MOT17, the
+            gt bboxes are allowed to cross the image border, so the gt
+            bboxes should not be clipped in those cases.
+            Defaults to True.
     """

     def __init__(
@@ -456,7 +465,8 @@ class CenterCrop(BaseTransform):
         crop_size: Union[int, Tuple[int, int]],
         pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255),
         pad_mode: Optional[str] = None,
-        pad_cfg: dict = dict(type='Pad')
+        pad_cfg: dict = dict(type='Pad'),
+        clip_object_border: bool = True,
     ) -> None:  # flake8: noqa
         super().__init__()
         assert isinstance(crop_size, int) or (
@@ -471,6 +481,7 @@ class CenterCrop(BaseTransform):
         self.pad_val = pad_val
         self.pad_mode = pad_mode
         self.pad_cfg = pad_cfg
+        self.clip_object_border = clip_object_border

     def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
         """Crop image.
@@ -498,6 +509,47 @@ class CenterCrop(BaseTransform):
             img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes)
             results['gt_semantic_seg'] = img

+    def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None:
+        """Update bounding boxes according to CenterCrop.
+
+        Args:
+            results (dict): Result dict containing the data to transform.
+            bboxes (np.ndarray): Shape (4, ), the crop region as
+                (x1, y1, x2, y2).
+        """
+        if 'gt_bboxes' in results:
+            offset_w = bboxes[0]
+            offset_h = bboxes[1]
+            bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h])
+            # gt_bboxes has shape (num_gts, 4) in (tl_x, tl_y, br_x, br_y)
+            # order.
+            gt_bboxes = results['gt_bboxes'] - bbox_offset
+            if self.clip_object_border:
+                gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0,
+                                             results['img'].shape[1])
+                gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0,
+                                             results['img'].shape[0])
+            results['gt_bboxes'] = gt_bboxes
+
+    def _crop_keypoints(self, results: dict, bboxes: np.ndarray) -> None:
+        """Update key points according to CenterCrop.
+
+        Args:
+            results (dict): Result dict containing the data to transform.
+            bboxes (np.ndarray): Shape (4, ), the crop region as
+                (x1, y1, x2, y2).
+        """
+        if 'gt_keypoints' in results:
+            offset_w = bboxes[0]
+            offset_h = bboxes[1]
+            keypoints_offset = np.array([offset_w, offset_h, 0])
+            # gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order,
+            # NK = number of keypoints per object
+            gt_keypoints = results['gt_keypoints'] - keypoints_offset
+            gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0,
+                                            results['img'].shape[1])
+            gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0,
+                                            results['img'].shape[0])
+            results['gt_keypoints'] = gt_keypoints
+
     def transform(self, results: dict) -> dict:
         """Apply center crop on results.
@@ -508,7 +560,7 @@ class CenterCrop(BaseTransform):
             dict: Results with CenterCropped image and semantic segmentation
                 map.
         """
-        crop_height, crop_width = self.crop_size[0], self.crop_size[1]
+        crop_width, crop_height = self.crop_size[0], self.crop_size[1]

         assert 'img' in results, '`img` is not found in results'
         img = results['img']
@@ -543,6 +595,10 @@ class CenterCrop(BaseTransform):
         self._crop_img(results, bboxes)
         # crop the gt_semantic_seg
        self._crop_seg_map(results, bboxes)
+        # crop the bounding boxes
+        self._crop_bboxes(results, bboxes)
+        # crop the keypoints
+        self._crop_keypoints(results, bboxes)
         return results

     def __repr__(self) -> str:
@@ -550,6 +606,7 @@ class CenterCrop(BaseTransform):
         repr_str += f', crop_size = {self.crop_size}'
         repr_str += f', pad_val = {self.pad_val}'
         repr_str += f', pad_mode = {self.pad_mode}'
+        repr_str += f', clip_object_border = {self.clip_object_border}'
         return repr_str
...
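The two new helpers above boil down to "subtract the crop origin, then clip to the new image border". As a self-contained sketch of that arithmetic (not part of the commit), reusing the numbers from the tests below — a 400x300 image center-cropped to 224x224, so the crop origin is ((400 - 224) // 2, (300 - 224) // 2) = (88, 38):

```python
import numpy as np

offset_w, offset_h = 88, 38  # crop origin (x, y)
crop_w, crop_h = 224, 224    # size of the cropped image

# Bounding boxes: shift both corners, then clip x and y independently.
gt_bboxes = np.array([[0, 0, 210, 160], [200, 150, 400, 300]])
gt_bboxes = gt_bboxes - np.array([offset_w, offset_h, offset_w, offset_h])
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, crop_w)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, crop_h)
print(gt_bboxes)  # [[  0   0 122 122] [112 112 224 224]]

# Keypoints: shift (x, y), leave visibility untouched, then clip.
gt_keypoints = np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])
gt_keypoints = gt_keypoints - np.array([offset_w, offset_h, 0])
gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0, crop_w)
gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0, crop_h)
print(gt_keypoints)  # [[[  0  12   1]] [[112 112   1]] [[212 187   1]]]
```

These are exactly the expected values asserted in the updated tests.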
@@ -248,6 +248,10 @@ class TestCenterCrop:
     def reset_results(results, original_img, gt_semantic_map):
         results['img'] = copy.deepcopy(original_img)
         results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
+        results['gt_bboxes'] = np.array([[0, 0, 210, 160],
+                                         [200, 150, 400, 300]])
+        results['gt_keypoints'] = np.array([[[20, 50, 1]], [[200, 150, 1]],
+                                            [[300, 225, 1]]])
         return results

     @pytest.mark.skipif(
@@ -293,6 +297,12 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
                                                                88:312]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 122, 122], [112, 112, 224,
+                                                     224]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()

         # test CenterCrop when size is tuple
         transform = dict(type='CenterCrop', crop_size=(224, 224))
@@ -306,9 +316,15 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
                                                                88:312]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 122, 122], [112, 112, 224,
+                                                     224]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()

         # test CenterCrop when crop_height != crop_width
-        transform = dict(type='CenterCrop', crop_size=(256, 224))
+        transform = dict(type='CenterCrop', crop_size=(224, 256))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -319,10 +335,16 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
                                                                88:312]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 122, 138], [112, 128, 224,
+                                                     256]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 28, 1]], [[112, 128, 1]], [[212, 203, 1]]])).all()

         # test CenterCrop when crop_size is equal to img.shape
         img_height, img_width, _ = self.original_img.shape
-        transform = dict(type='CenterCrop', crop_size=(img_height, img_width))
+        transform = dict(type='CenterCrop', crop_size=(img_width, img_height))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -331,10 +353,16 @@ class TestCenterCrop:
         assert results['width'] == 400
         assert (results['img'] == self.original_img).all()
         assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 210, 160], [200, 150, 400,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])).all()

         # test CenterCrop when crop_size is larger than img.shape
         transform = dict(
-            type='CenterCrop', crop_size=(img_height * 2, img_width * 2))
+            type='CenterCrop', crop_size=(img_width * 2, img_height * 2))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -343,11 +371,17 @@ class TestCenterCrop:
         assert results['width'] == 400
         assert (results['img'] == self.original_img).all()
         assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 210, 160], [200, 150, 400,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])).all()

         # test with padding
         transform = dict(
             type='CenterCrop',
-            crop_size=(img_height * 2, img_width // 2),
+            crop_size=(img_width // 2, img_height * 2),
             pad_mode='constant',
             pad_val=12)
         center_crop_module = TRANSFORMS.build(transform)
@@ -359,10 +393,16 @@ class TestCenterCrop:
         assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
         assert (results['img'][300:600, 100:300, ...] == 12).all()
         assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 110, 160], [100, 150, 200,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()

         transform = dict(
             type='CenterCrop',
-            crop_size=(img_height * 2, img_width // 2),
+            crop_size=(img_width // 2, img_height * 2),
             pad_mode='constant',
             pad_val=dict(img=13, seg=33))
         center_crop_module = TRANSFORMS.build(transform)
@@ -373,10 +413,16 @@ class TestCenterCrop:
         assert results['width'] == 200
         assert (results['img'][300:600, 100:300, ...] == 13).all()
         assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 110, 160], [100, 150, 200,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()

         # test CenterCrop when crop_width is smaller than img_width
         transform = dict(
-            type='CenterCrop', crop_size=(img_height, img_width // 2))
+            type='CenterCrop', crop_size=(img_width // 2, img_height))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -387,10 +433,16 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[:,
                                                                100:300]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 110, 160], [100, 150, 200,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()

         # test CenterCrop when crop_height is smaller than img_height
         transform = dict(
-            type='CenterCrop', crop_size=(img_height // 2, img_width))
+            type='CenterCrop', crop_size=(img_width, img_height // 2))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -400,6 +452,12 @@ class TestCenterCrop:
         assert (results['img'] == self.original_img[75:225, ...]).all()
         assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
                                                                    ...]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 210, 85], [200, 75, 400,
+                                                    150]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[20, 0, 1]], [[200, 75, 1]], [[300, 150, 1]]])).all()

     @pytest.mark.skipif(
         condition=torch is None, reason='No torch in current env')
...
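As a cross-check of the padding cases above: with crop_size=(img_width // 2, img_height * 2) = (200, 600) on the 400x300 test image, the image is padded to 400x600 before cropping, so the crop origin is ((400 - 200) // 2, (600 - 600) // 2) = (100, 0). Applying the same subtract-then-clip logic (a sketch, not part of the commit):

```python
import numpy as np

offset = np.array([100, 0, 100, 0])  # crop origin after padding
crop_w, crop_h = 200, 600            # size of the padded-and-cropped image

gt_bboxes = np.array([[0, 0, 210, 160], [200, 150, 400, 300]]) - offset
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, crop_w)
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, crop_h)
print(gt_bboxes)  # [[  0   0 110 160] [100 150 200 300]] -- as asserted above
```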