Commit b7525fae authored by Yifei Yang, committed by zhouzaida

[Fix] Fix typos, init funcs, cacheable method of part3 of data transforms (#1784)



* fix typos and move args into cfg

* update docstring

* fix as comment

* fix lint

* update transforms as discussed before

* linting

* fix as comment

* fix lint

* fix lint and also update according to PackInput

* remove precommit change

* rename cacheable method
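
A minimal usage sketch of the renamed interfaces, for reviewers. This is a sketch under two assumptions: that the transforms are importable from ``mmcv.transforms`` as exported by the updated ``__init__.py``, and that a random array is an acceptable stand-in for a real image.

```python
import numpy as np

from mmcv.transforms import CenterCrop, RandomChoiceResize

# RandomMultiscaleResize is now RandomChoiceResize; per-resize options
# such as ``keep_ratio`` move into ``resize_cfg`` instead of being
# top-level constructor arguments.
resize = RandomChoiceResize(
    scales=[(1333, 800), (1333, 600)],
    resize_cfg=dict(type='Resize', keep_ratio=True))

# CenterCrop drops ``pad_val``/``pad_mode`` in favour of ``auto_pad``
# plus a base ``pad_cfg`` that is built through the TRANSFORMS registry.
crop = CenterCrop(
    crop_size=224,
    auto_pad=True,
    pad_cfg=dict(type='Pad', padding_mode='constant', pad_val=12))

results = dict(img=np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))
results = crop(resize(results))
print(results['img'].shape)  # (224, 224, 3)
```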
Co-authored-by: liukuikun <641417025@qq.com>
parent e7592a70
......@@ -2,7 +2,7 @@
from .builder import TRANSFORMS
from .loading import LoadAnnotation, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomFlip, RandomGrayscale, RandomMultiscaleResize,
RandomChoiceResize, RandomFlip, RandomGrayscale,
RandomResize, Resize)
from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster
......@@ -12,7 +12,7 @@ except ImportError:
__all__ = [
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'Resize', 'Pad', 'RandomFlip', 'RandomChoiceResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
else:
......@@ -22,6 +22,6 @@ else:
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
'RandomFlip', 'RandomChoiceResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize'
]
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
......@@ -9,7 +9,7 @@ import mmcv
from mmcv.image.geometric import _scale_size
from .base import BaseTransform
from .builder import TRANSFORMS
from .utils import cacheable_method
from .utils import cache_randomness
from .wrappers import Compose
Number = Union[int, float]
......@@ -90,14 +90,14 @@ class Resize(BaseTransform):
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_keypoints (optional)
Modified Keys:
- img
- gt_bboxes
- gt_semantic_seg
- gt_seg_map
- gt_keypoints
- height
- width
......@@ -205,20 +205,20 @@ class Resize(BaseTransform):
def _resize_seg(self, results: dict) -> None:
"""Resize semantic segmentation map with ``results['scale']``."""
if results.get('gt_semantic_seg', None) is not None:
if results.get('gt_seg_map', None) is not None:
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results['gt_semantic_seg'],
results['gt_seg_map'],
results['scale'],
interpolation=self.interpolation,
backend=self.backend)
else:
gt_seg = mmcv.imresize(
results['gt_semantic_seg'],
results['gt_seg_map'],
results['scale'],
interpolation=self.interpolation,
backend=self.backend)
results['gt_semantic_seg'] = gt_seg
results['gt_seg_map'] = gt_seg
def _resize_keypoints(self, results: dict) -> None:
"""Resize keypoints with ``results['scale_factor']``."""
......@@ -241,7 +241,7 @@ class Resize(BaseTransform):
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
and 'keep_ratio' keys are updated in result dict.
"""
......@@ -279,12 +279,13 @@ class Pad(BaseTransform):
Required Keys:
- img
- gt_semantic_seg (optional)
- gt_bboxes (optional)
- gt_seg_map (optional)
Modified Keys:
- img
- gt_semantic_seg
- gt_seg_map
- height
- width
......@@ -387,15 +388,13 @@ class Pad(BaseTransform):
def _pad_seg(self, results: dict) -> None:
"""Pad semantic segmentation map according to
``results['pad_shape']``."""
if results.get('gt_semantic_seg', None) is not None:
if results.get('gt_seg_map', None) is not None:
pad_val = self.pad_val.get('seg', 255)
if isinstance(pad_val,
int) and results['gt_semantic_seg'].ndim == 3:
pad_val = tuple([
pad_val for _ in range(results['gt_semantic_seg'].shape[2])
])
results['gt_semantic_seg'] = mmcv.impad(
results['gt_semantic_seg'],
if isinstance(pad_val, int) and results['gt_seg_map'].ndim == 3:
pad_val = tuple(
[pad_val for _ in range(results['gt_seg_map'].shape[2])])
results['gt_seg_map'] = mmcv.impad(
results['gt_seg_map'],
shape=results['pad_shape'][:2],
pad_val=pad_val,
padding_mode=self.padding_mode)
......@@ -426,13 +425,13 @@ class Pad(BaseTransform):
@TRANSFORMS.register_module()
class CenterCrop(BaseTransform):
"""Crop the center of the image, segmentation masks, bounding boxes and key
points. If the crop area exceeds the original image and ``pad_mode`` is not
None, the original image will be padded before cropping.
points. If the crop area exceeds the original image and ``auto_pad`` is
True, the original image will be padded before cropping.
Required Keys:
- img
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_bboxes (optional)
- gt_keypoints (optional)
......@@ -441,7 +440,7 @@ class CenterCrop(BaseTransform):
- img
- height
- width
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_bboxes (optional)
- gt_keypoints (optional)
......@@ -454,15 +453,10 @@ class CenterCrop(BaseTransform):
crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
with the format of (w, h). If set to an integer, then cropping
width and height are equal to this integer.
pad_val (Union[Number, Dict[str, Number]]): A dict for
padding value. To specify how to set this argument, please see
the docstring of class ``Pad``. Defaults to
``dict(img=0, seg=255)``.
pad_mode (str, optional): Type of padding. Should be: 'constant',
'edge', 'reflect' or 'symmetric'. For details, please see the
docstring of class ``Pad``. Defaults to 'constant'.
pad_cfg (str): Base config for padding. Defaults to
``dict(type='Pad')``.
auto_pad (bool): Whether to pad the image if it's smaller than the
``crop_size``. Defaults to False.
pad_cfg (dict): Base config for padding. Refer to ``mmcv.Pad`` for
details. Defaults to ``dict(type='Pad')``.
clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some datasets like MOT17, the
gt bboxes are allowed to cross the border of images. Therefore,
......@@ -470,14 +464,11 @@ class CenterCrop(BaseTransform):
Defaults to True.
"""
def __init__(
self,
crop_size: Union[int, Tuple[int, int]],
pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255),
pad_mode: Optional[str] = None,
pad_cfg: dict = dict(type='Pad'),
clip_object_border: bool = True,
) -> None: # flake8: noqa
def __init__(self,
crop_size: Union[int, Tuple[int, int]],
auto_pad: bool = False,
pad_cfg: dict = dict(type='Pad'),
clip_object_border: bool = True) -> None:
super().__init__()
assert isinstance(crop_size, int) or (
isinstance(crop_size, tuple) and len(crop_size) == 2
......@@ -488,9 +479,15 @@ class CenterCrop(BaseTransform):
crop_size = (crop_size, crop_size)
assert crop_size[0] > 0 and crop_size[1] > 0
self.crop_size = crop_size
self.pad_val = pad_val
self.pad_mode = pad_mode
self.pad_cfg = pad_cfg
self.auto_pad = auto_pad
self.pad_cfg = pad_cfg.copy()
# size will be overwritten
if 'size' in self.pad_cfg and auto_pad:
warnings.warn('``size`` is set in ``pad_cfg``, however this argument '
'will be overwritten according to crop size and image size')
self.clip_object_border = clip_object_border
def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
......@@ -515,9 +512,9 @@ class CenterCrop(BaseTransform):
results (dict): Result dict contains the data to transform.
bboxes (np.ndarray): Shape (4, ), location of cropped bboxes.
"""
if results.get('gt_semantic_seg', None) is not None:
img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes)
results['gt_semantic_seg'] = img
if results.get('gt_seg_map', None) is not None:
img = mmcv.imcrop(results['gt_seg_map'], bboxes=bboxes)
results['gt_seg_map'] = img
def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None:
"""Update bounding boxes according to CenterCrop.
......@@ -587,17 +584,13 @@ class CenterCrop(BaseTransform):
img_height, img_width = img.shape[:2]
if crop_height > img_height or crop_width > img_width:
if self.pad_mode is not None:
if self.auto_pad:
# pad the area
img_height = max(img_height, crop_height)
img_width = max(img_width, crop_width)
pad_size = (img_width, img_height)
_pad_cfg = self.pad_cfg.copy()
_pad_cfg.update(
dict(
size=pad_size,
pad_val=self.pad_val,
padding_mode=self.pad_mode))
_pad_cfg.update(dict(size=pad_size))
pad_transform = TRANSFORMS.build(_pad_cfg)
results = pad_transform(results)
else:
......@@ -612,7 +605,7 @@ class CenterCrop(BaseTransform):
# crop the image
self._crop_img(results, bboxes)
# crop the gt_semantic_seg
# crop the gt_seg_map
self._crop_seg_map(results, bboxes)
# crop the bounding box
self._crop_bboxes(results, bboxes)
......@@ -623,8 +616,8 @@ class CenterCrop(BaseTransform):
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', crop_size = {self.crop_size}'
repr_str += f', pad_val = {self.pad_val}'
repr_str += f', pad_mode = {self.pad_mode}'
repr_str += f', auto_pad={self.auto_pad}'
repr_str += f', pad_cfg={self.pad_cfg}'
repr_str += f', clip_object_border={self.clip_object_border}'
return repr_str
......@@ -649,10 +642,12 @@ class RandomGrayscale(BaseTransform):
Args:
prob (float): Probability that image should be converted to
grayscale. Defaults to 0.1.
keep_channel (bool): Whether keep channel number the same as
keep_channels (bool): Whether to keep the channel number the same as
input. Defaults to False.
channel_weights (tuple): Channel weights to compute gray
image. Defaults to (1., 1., 1.).
channel_weights (tuple): The grayscale weights of each channel,
and the weights will be normalized. For example, (1, 2, 1)
will be normalized as (0.25, 0.5, 0.25). Defaults to
(1., 1., 1.).
color_format (str): Color format set to be any of 'bgr',
'rgb', 'hsv'. Note: 'hsv' image will be transformed into 'bgr'
format no matter whether it is grayscaled. Defaults to 'bgr'.
......@@ -660,18 +655,22 @@ class RandomGrayscale(BaseTransform):
def __init__(self,
prob: float = 0.1,
keep_channel: bool = False,
keep_channels: bool = False,
channel_weights: Sequence[float] = (1., 1., 1.),
color_format: str = 'bgr') -> None:
super().__init__()
assert 0. <= prob <= 1., ('The range of ``prob`` value is [0., 1.],' +
f' but got {prob} instead')
self.prob = prob
self.keep_channel = keep_channel
self.keep_channels = keep_channels
self.channel_weights = channel_weights
assert color_format in ['bgr', 'rgb', 'hsv']
self.color_format = color_format
@cache_randomness
def _random_prob(self):
return random.random()
def transform(self, results: dict) -> dict:
"""Apply random grayscale on results.
......@@ -687,7 +686,7 @@ class RandomGrayscale(BaseTransform):
img = mmcv.hsv2bgr(img)
img = img[..., None] if img.ndim == 2 else img
num_output_channels = img.shape[2]
if random.random() < self.prob:
if self._random_prob() < self.prob:
if num_output_channels > 1:
assert num_output_channels == len(
self.channel_weights
......@@ -697,7 +696,7 @@ class RandomGrayscale(BaseTransform):
normalized_weights = (
np.array(self.channel_weights) / sum(self.channel_weights))
img = (normalized_weights * img).sum(axis=2)
if self.keep_channel:
if self.keep_channels:
img = img[:, :, None]
results['img'] = np.dstack(
[img for _ in range(num_output_channels)])
......@@ -710,7 +709,7 @@ class RandomGrayscale(BaseTransform):
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', prob = {self.prob}'
repr_str += f', keep_channel = {self.keep_channel}'
repr_str += f', keep_channels = {self.keep_channels}'
repr_str += f', channel_weights = {self.channel_weights}'
repr_str += f', color_format = {self.color_format}'
return repr_str
......@@ -753,13 +752,12 @@ class MultiScaleFlipAug(BaseTransform):
.. code-block::
dict(
img=[...],
img_shape=[...],
scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
flip=[False, True, False, True]
...
inputs=[...],
data_samples=[...]
)
Where the lengths of ``inputs`` and ``data_samples`` are both 4.
Required Keys:
- Depending on the requirements of the ``transforms`` parameter.
......@@ -774,7 +772,7 @@ class MultiScaleFlipAug(BaseTransform):
img_scale (tuple | list[tuple] | None): Image scales for resizing.
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
flip (bool): Whether apply flip augmentation. Defaults to False.
allow_flip (bool): Whether to apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If
flip_direction is a list, multiple flip augmentations will be
......@@ -791,7 +789,7 @@ class MultiScaleFlipAug(BaseTransform):
transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
scale_factor: Optional[Union[float, List[float]]] = None,
flip: bool = False,
allow_flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal',
resize_cfg: dict = dict(type='Resize', keep_ratio=True),
flip_cfg: dict = dict(type='RandomFlip')
......@@ -813,14 +811,14 @@ class MultiScaleFlipAug(BaseTransform):
scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor'
self.flip = flip
self.allow_flip = allow_flip
self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']:
if not self.allow_flip and self.flip_direction != ['horizontal']:
warnings.warn(
'flip_direction has no effect when allow_flip is set to False')
self.resize_cfg = resize_cfg
self.resize_cfg = resize_cfg.copy()
self.flip_cfg = flip_cfg
def transform(self, results: dict) -> dict:
......@@ -834,10 +832,10 @@ class MultiScaleFlipAug(BaseTransform):
into a list.
"""
aug_data = []
input_data = []
data_samples = []
inputs = []
flip_args = [(False, '')]
if self.flip:
if self.allow_flip:
flip_args += [(True, direction)
for direction in self.flip_direction]
for scale in self.img_scale:
......@@ -857,23 +855,23 @@ class MultiScaleFlipAug(BaseTransform):
resize_flip = Compose(_resize_flip)
_results = results.copy()
_results = resize_flip(_results)
input_image, data_sample = self.transforms(_results)
packed_results = self.transforms(_results)
input_data.append(input_image)
aug_data.append(data_sample)
return input_data, aug_data
inputs.append(packed_results['inputs'])
data_samples.append(packed_results['data_sample'])
return dict(inputs=inputs, data_samples=data_samples)
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', transforms={self.transforms}'
repr_str += f', img_scale={self.img_scale}'
repr_str += f', flip={self.flip}'
repr_str += f', allow_flip={self.allow_flip}'
repr_str += f', flip_direction={self.flip_direction}'
return repr_str
@TRANSFORMS.register_module()
class RandomMultiscaleResize(BaseTransform):
class RandomChoiceResize(BaseTransform):
"""Resize images & bbox & mask from a list of multiple scales.
This transform resizes the input image to some scale. Bboxes and masks are
......@@ -891,7 +889,7 @@ class RandomMultiscaleResize(BaseTransform):
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_keypoints (optional)
Modified Keys:
......@@ -900,7 +898,7 @@ class RandomMultiscaleResize(BaseTransform):
- height
- width
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_keypoints (optional)
Added Keys:
......@@ -913,24 +911,14 @@ class RandomMultiscaleResize(BaseTransform):
Args:
scales (Sequence[Union[int, Tuple]]): Image scales for resizing.
keep_ratio (bool): Whether to keep the aspect ratio when
resizing the image. Defaults to False.
clip_object_border (bool): Whether clip the objects outside
the border of the image. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and
'pillow'. These two backends generates slightly different results.
Defaults to 'cv2'.
interpolation (str): The mode of interpolation, support
"bilinear", "bicubic", "nearest". Defaults to "bilinear".
resize_cfg (dict): Base config for resizing. Refer to
``mmcv.Resize`` for detail. Defaults to
``dict(type='Resize')``.
"""
def __init__(
self,
scales: Union[list, Tuple],
keep_ratio: bool = False,
clip_object_border: bool = True,
backend: str = 'cv2',
interpolation: str = 'bilinear',
scales: Sequence[Union[int, Tuple]],
resize_cfg: dict = dict(type='Resize')
) -> None:
super().__init__()
......@@ -939,19 +927,11 @@ class RandomMultiscaleResize(BaseTransform):
else:
self.scales = [scales]
assert mmcv.is_list_of(self.scales, tuple)
self.keep_ratio = keep_ratio
self.clip_object_border = clip_object_border
self.backend = backend
self.interpolation = interpolation
self.resize_cfg = resize_cfg
@staticmethod
def random_select(scales: List[Tuple]) -> Tuple[tuple, int]:
"""Randomly select an img_scale from given candidates.
Args:
scales (list[tuple]): Images scales for selection.
@cache_randomness
def _random_select(self) -> Tuple[tuple, int]:
"""Randomly select an scale from given candidates.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
......@@ -959,9 +939,9 @@ class RandomMultiscaleResize(BaseTransform):
``scale_idx`` is the selected index in the given candidates.
"""
assert mmcv.is_list_of(scales, tuple)
scale_idx = np.random.randint(len(scales))
scale = scales[scale_idx]
assert mmcv.is_list_of(self.scales, tuple)
scale_idx = np.random.randint(len(self.scales))
scale = self.scales[scale_idx]
return scale, scale_idx
def transform(self, results: dict) -> dict:
......@@ -971,20 +951,14 @@ class RandomMultiscaleResize(BaseTransform):
results (dict): Result dict contains the data to transform.
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg',
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
and 'keep_ratio' keys are updated in result dict.
"""
target_scale, scale_idx = self.random_select(self.scales)
target_scale, scale_idx = self._random_select()
_resize_cfg = self.resize_cfg.copy()
_resize_cfg.update(
dict(
scale=target_scale,
keep_ratio=self.keep_ratio,
clip_object_border=self.clip_object_border,
backend=self.backend,
interpolation=self.interpolation))
_resize_cfg.update(dict(scale=target_scale))
resize_transform = TRANSFORMS.build(_resize_cfg)
results = resize_transform(results)
results['scale_idx'] = scale_idx
......@@ -993,17 +967,14 @@ class RandomMultiscaleResize(BaseTransform):
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f', scales={self.scales}'
repr_str += f', keep_ratio={self.keep_ratio}'
repr_str += f', clip_object_border={self.clip_object_border}'
repr_str += f', backend={self.backend}'
repr_str += f', interpolation={self.interpolation}'
repr_str += f', resize_cfg={self.resize_cfg}'
return repr_str
@TRANSFORMS.register_module()
class RandomFlip(BaseTransform):
"""Flip the image & bbox & keypoints & segmentation map. Added or Updated
keys: flip, flip_direction, img, gt_bboxes, gt_semantic_seg, and
keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and
gt_keypoints. There are 3 flip modes:
- ``prob`` is float, ``direction`` is string: the image will be
......@@ -1026,13 +997,13 @@ class RandomFlip(BaseTransform):
Required Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_keypoints (optional)
Modified Keys:
- img
- gt_bboxes (optional)
- gt_semantic_seg (optional)
- gt_seg_map (optional)
- gt_keypoints (optional)
Added Keys:
......@@ -1139,7 +1110,7 @@ class RandomFlip(BaseTransform):
flipped = np.concatenate([keypoints, meta_info], axis=-1)
return flipped
@cacheable_method
@cache_randomness
def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`"""
if isinstance(self.direction,
......@@ -1207,7 +1178,7 @@ class RandomFlip(BaseTransform):
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Flipped results, 'img', 'gt_bboxes', 'gt_semantic_seg',
dict: Flipped results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'flip', and 'flip_direction' keys are
updated in result dict.
"""
......@@ -1246,14 +1217,14 @@ class RandomResize(BaseTransform):
- img
- gt_bboxes
- gt_semantic_seg
- gt_seg_map
- gt_keypoints
Modified Keys:
- img
- gt_bboxes
- gt_semantic_seg
- gt_seg_map
- gt_keypoints
Added Keys:
......@@ -1335,7 +1306,7 @@ class RandomResize(BaseTransform):
scale = int(scale[0] * ratio), int(scale[1] * ratio)
return scale
@cacheable_method
@cache_randomness
def _random_scale(self) -> tuple:
"""Private function to randomly sample an scale according to the type
of ``scale``.
......
......@@ -56,7 +56,7 @@ class TestResize:
def test_resize(self):
data_info = dict(
img=np.random.random((1333, 800, 3)),
gt_semantic_seg=np.random.random((1333, 800, 3)),
gt_seg_map=np.random.random((1333, 800, 3)),
gt_bboxes=np.array([[0, 0, 112, 112]]),
gt_keypoints=np.array([[[20, 50, 1]]]))
......@@ -100,7 +100,7 @@ class TestResize:
results = transform(copy.deepcopy(data_info))
assert (results['gt_bboxes'] == np.array([[0, 0, 168, 224]])).all()
assert (results['gt_keypoints'] == np.array([[[30, 100, 1]]])).all()
assert results['gt_semantic_seg'].shape[:2] == (2666, 1200)
assert results['gt_seg_map'].shape[:2] == (2666, 1200)
# test clip_object_border = False
data_info = dict(
......@@ -143,39 +143,39 @@ class TestPad:
data_info = dict(
img=np.random.random((1333, 800, 3)),
gt_semantic_seg=np.random.random((1333, 800, 3)),
gt_seg_map=np.random.random((1333, 800, 3)),
gt_bboxes=np.array([[0, 0, 112, 112]]),
gt_keypoints=np.array([[[20, 50, 1]]]))
# test pad img / gt_semantic_seg with size
# test pad img / gt_seg_map with size
trans = Pad(size=(1200, 2000))
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2000, 1200)
assert results['gt_semantic_seg'].shape[:2] == (2000, 1200)
assert results['gt_seg_map'].shape[:2] == (2000, 1200)
# test pad img/gt_semantic_seg with size_divisor
# test pad img/gt_seg_map with size_divisor
trans = Pad(size_divisor=11)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 803)
assert results['gt_semantic_seg'].shape[:2] == (1342, 803)
assert results['gt_seg_map'].shape[:2] == (1342, 803)
# test pad img/gt_semantic_seg with pad_to_square
# test pad img/gt_seg_map with pad_to_square
trans = Pad(pad_to_square=True)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1333, 1333)
assert results['gt_semantic_seg'].shape[:2] == (1333, 1333)
assert results['gt_seg_map'].shape[:2] == (1333, 1333)
# test pad img/gt_semantic_seg with pad_to_square and size_divisor
# test pad img/gt_seg_map with pad_to_square and size_divisor
trans = Pad(pad_to_square=True, size_divisor=11)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 1342)
assert results['gt_semantic_seg'].shape[:2] == (1342, 1342)
assert results['gt_seg_map'].shape[:2] == (1342, 1342)
# test pad img/gt_semantic_seg with pad_to_square and size_divisor
# test pad img/gt_seg_map with pad_to_square and size_divisor
trans = Pad(pad_to_square=True, size_divisor=11)
results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 1342)
assert results['gt_semantic_seg'].shape[:2] == (1342, 1342)
assert results['gt_seg_map'].shape[:2] == (1342, 1342)
# test padding_mode
new_img = np.ones((1333, 800, 3))
......@@ -191,13 +191,12 @@ class TestPad:
pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, 800:2000, :] == 10).all()
assert (results['gt_seg_map'][1333:2000, 800:2000, :] == 10).all()
trans = Pad(size=(2000, 2000), pad_val=dict(img=(12, 12, 12)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000,
800:2000, :] == 255).all()
assert (results['gt_seg_map'][1333:2000, 800:2000, :] == 255).all()
# test rgb image, pad_to_square=True
trans = Pad(
......@@ -205,29 +204,28 @@ class TestPad:
pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][:, 800:1333, :] == 12).all()
assert (results['gt_semantic_seg'][:, 800:1333, :] == 10).all()
assert (results['gt_seg_map'][:, 800:1333, :] == 10).all()
trans = Pad(pad_to_square=True, pad_val=dict(img=(12, 12, 12)))
results = trans(copy.deepcopy(data_info))
assert (results['img'][:, 800:1333, :] == 12).all()
assert (results['gt_semantic_seg'][:, 800:1333, :] == 255).all()
assert (results['gt_seg_map'][:, 800:1333, :] == 255).all()
# test pad_val is int
# test rgb image
trans = Pad(size=(2000, 2000), pad_val=12)
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000,
800:2000, :] == 255).all()
assert (results['gt_seg_map'][1333:2000, 800:2000, :] == 255).all()
# test gray image
new_img = np.random.random((1333, 800))
data_info['img'] = new_img
new_semantic_seg = np.random.random((1333, 800))
data_info['gt_semantic_seg'] = new_semantic_seg
data_info['gt_seg_map'] = new_semantic_seg
trans = Pad(size=(2000, 2000), pad_val=12)
results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, 800:2000] == 255).all()
assert (results['gt_seg_map'][1333:2000, 800:2000] == 255).all()
def test_repr(self):
trans = Pad(pad_to_square=True, size_divisor=11, padding_mode='edge')
......@@ -249,7 +247,7 @@ class TestCenterCrop:
@staticmethod
def reset_results(results, original_img, gt_semantic_map):
results['img'] = copy.deepcopy(original_img)
results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
results['gt_seg_map'] = copy.deepcopy(gt_semantic_map)
results['gt_bboxes'] = np.array([[0, 0, 210, 160],
[200, 150, 400, 300]])
results['gt_keypoints'] = np.array([[[20, 50, 1]], [[200, 150, 1]],
......@@ -296,9 +294,8 @@ class TestCenterCrop:
assert results['height'] == 224
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
88:312]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 122], [112, 112, 224,
224]])).all()
......@@ -315,9 +312,8 @@ class TestCenterCrop:
assert results['height'] == 224
assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
88:312]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 122], [112, 112, 224,
224]])).all()
......@@ -334,9 +330,8 @@ class TestCenterCrop:
assert results['height'] == 256
assert results['width'] == 224
assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
88:312]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[22:278,
88:312]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 138], [112, 128, 224,
256]])).all()
......@@ -354,7 +349,7 @@ class TestCenterCrop:
assert results['height'] == 300
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 160], [200, 150, 400,
300]])).all()
......@@ -372,7 +367,7 @@ class TestCenterCrop:
assert results['height'] == 300
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 160], [200, 150, 400,
300]])).all()
......@@ -384,17 +379,17 @@ class TestCenterCrop:
transform = dict(
type='CenterCrop',
crop_size=(img_width // 2, img_height * 2),
pad_mode='constant',
pad_val=12)
auto_pad=True,
pad_cfg=dict(type='Pad', padding_mode='constant', pad_val=12))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
results = center_crop_module(results)
assert results['height'] == 600
assert results['width'] == 200
assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
assert results['img'].shape[:2] == results['gt_seg_map'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()
assert (results['gt_seg_map'][300:600, 100:300] == 255).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all()
......@@ -405,8 +400,11 @@ class TestCenterCrop:
transform = dict(
type='CenterCrop',
crop_size=(img_width // 2, img_height * 2),
pad_mode='constant',
pad_val=dict(img=13, seg=33))
auto_pad=True,
pad_cfg=dict(
type='Pad',
padding_mode='constant',
pad_val=dict(img=13, seg=33)))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
......@@ -414,7 +412,7 @@ class TestCenterCrop:
assert results['height'] == 600
assert results['width'] == 200
assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()
assert (results['gt_seg_map'][300:600, 100:300] == 33).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all()
......@@ -432,9 +430,8 @@ class TestCenterCrop:
assert results['height'] == img_height
assert results['width'] == img_width // 2
assert (results['img'] == self.original_img[:, 100:300, ...]).all()
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[:,
100:300]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[:,
100:300]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all()
......@@ -452,8 +449,8 @@ class TestCenterCrop:
assert results['height'] == img_height // 2
assert results['width'] == img_width
assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
...]).all()
assert (results['gt_seg_map'] == self.gt_semantic_map[75:225,
...]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 85], [200, 75, 400,
150]])).all()
......@@ -479,7 +476,7 @@ class TestCenterCrop:
cropped_seg = center_crop_module(pil_seg)
cropped_seg = np.array(cropped_seg)
assert np.equal(results['img'], cropped_img).all()
assert np.equal(results['gt_semantic_seg'], cropped_seg).all()
assert np.equal(results['gt_seg_map'], cropped_seg).all()
class TestRandomGrayscale:
......@@ -494,7 +491,7 @@ class TestRandomGrayscale:
type='RandomGrayscale',
prob=1.,
channel_weights=(0.299, 0.587, 0.114),
keep_channel=True)
keep_channels=True)
random_gray_scale_module = TRANSFORMS.build(transform)
assert isinstance(repr(random_gray_scale_module), str)
......@@ -511,7 +508,7 @@ class TestRandomGrayscale:
type='RandomGrayscale',
prob=1.,
channel_weights=(0.299, 0.587, 0.114),
keep_channel=True)
keep_channels=True)
random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img)
......@@ -541,14 +538,14 @@ class TestRandomGrayscale:
@TRANSFORMS.register_module()
class MockFormatBundle(BaseTransform):
class MockPackTaskInputs(BaseTransform):
def __init__(self) -> None:
super().__init__()
def transform(self, results):
data_sample = Mock()
return results['img'], data_sample
packed_results = dict(inputs=results['img'], data_sample=Mock())
return packed_results
class TestMultiScaleFlipAug:
......@@ -581,30 +578,28 @@ class TestMultiScaleFlipAug:
# test with empty transforms
transform = dict(
type='MultiScaleFlipAug',
transforms=[dict(type='MockFormatBundle')],
transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
packed_results = multi_scale_flip_aug_module(results)
assert len(packed_results['inputs']) == 12
# test with flip=False
# test with allow_flip=False
transform = dict(
type='MultiScaleFlipAug',
transforms=[dict(type='MockFormatBundle')],
transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=False,
allow_flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 3
assert len(data_sample) == 3
packed_results = multi_scale_flip_aug_module(results)
assert len(packed_results['inputs']) == 3
# test with transforms
img_norm_cfg = dict(
......@@ -615,20 +610,19 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
dict(type='MockPackTaskInputs')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
packed_results = multi_scale_flip_aug_module(results)
assert len(packed_results['inputs']) == 12
# test with scale_factor
img_norm_cfg = dict(
......@@ -639,20 +633,19 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
dict(type='MockPackTaskInputs')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
scale_factor=[0.5, 1., 2.],
flip=True,
allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
packed_results = multi_scale_flip_aug_module(results)
assert len(packed_results['inputs']) == 12
# test no resize
img_norm_cfg = dict(
......@@ -663,22 +656,21 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
dict(type='MockPackTaskInputs')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
flip=True,
allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 4
assert len(data_sample) == 4
packed_results = multi_scale_flip_aug_module(results)
assert len(packed_results['inputs']) == 4
class TestRandomMultiscaleResize:
class TestRandomChoiceResize:
@classmethod
def setup_class(cls):
......@@ -688,25 +680,25 @@ class TestRandomMultiscaleResize:
def reset_results(self, results):
results['img'] = copy.deepcopy(self.original_img)
results['gt_semantic_seg'] = copy.deepcopy(self.original_img)
results['gt_seg_map'] = copy.deepcopy(self.original_img)
def test_repr(self):
# test repr
transform = dict(
type='RandomMultiscaleResize', scales=[(1333, 800), (1333, 600)])
type='RandomChoiceResize', scales=[(1333, 800), (1333, 600)])
random_multiscale_resize = TRANSFORMS.build(transform)
assert isinstance(repr(random_multiscale_resize), str)
def test_error(self):
# test assertion if scales is not a list of tuples
with pytest.raises(AssertionError):
transform = dict(type='RandomMultiscaleResize', scales=[0.5, 1, 2])
transform = dict(type='RandomChoiceResize', scales=[0.5, 1, 2])
TRANSFORMS.build(transform)
def test_random_multiscale_resize(self):
results = dict()
# test with one scale
transform = dict(type='RandomMultiscaleResize', scales=[(1333, 800)])
transform = dict(type='RandomChoiceResize', scales=[(1333, 800)])
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results = random_multiscale_resize(results)
......@@ -714,7 +706,7 @@ class TestRandomMultiscaleResize:
# test with multi scales
_scale_choice = [(1333, 800), (1333, 600)]
transform = dict(type='RandomMultiscaleResize', scales=_scale_choice)
transform = dict(type='RandomChoiceResize', scales=_scale_choice)
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results = random_multiscale_resize(results)
......@@ -723,9 +715,9 @@ class TestRandomMultiscaleResize:
# test keep_ratio
transform = dict(
type='RandomMultiscaleResize',
type='RandomChoiceResize',
scales=[(900, 600)],
keep_ratio=True)
resize_cfg=dict(type='Resize', keep_ratio=True))
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
_input_ratio = results['img'].shape[0] / results['img'].shape[1]
......@@ -736,9 +728,9 @@ class TestRandomMultiscaleResize:
# test clip_object_border
gt_bboxes = [[200, 150, 600, 450]]
transform = dict(
type='RandomMultiscaleResize',
type='RandomChoiceResize',
scales=[(200, 150)],
clip_object_border=True)
resize_cfg=dict(type='Resize', clip_object_border=True))
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes)
......@@ -748,9 +740,9 @@ class TestRandomMultiscaleResize:
150]])).all()
transform = dict(
type='RandomMultiscaleResize',
type='RandomChoiceResize',
scales=[(200, 150)],
clip_object_border=False)
resize_cfg=dict(type='Resize', clip_object_border=False))
random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes)
......@@ -873,7 +865,7 @@ class TestRandomResize:
# keep ratio is True
results = {
'img': np.random.random((224, 224, 3)),
'gt_semantic_seg': np.random.random((224, 224, 3)),
'gt_seg_map': np.random.random((224, 224, 3)),
'gt_bboxes': np.array([[0, 0, 112, 112]]),
'gt_keypoints': np.array([[[112, 112]]])
}
......
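
Migration note for downstream transforms: the ``cacheable_method`` decorator is renamed to ``cache_randomness`` (imported in this diff via ``from .utils import cache_randomness``). A sketch of adopting it in a custom transform follows, mirroring the ``_random_prob`` helper added to RandomGrayscale above; ``RandomInvert`` is hypothetical, and the import paths assume the package layout shown in this diff.

```python
import random

from mmcv.transforms import TRANSFORMS
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness


@TRANSFORMS.register_module()
class RandomInvert(BaseTransform):
    """Hypothetical transform illustrating the renamed decorator."""

    def __init__(self, prob: float = 0.5) -> None:
        self.prob = prob

    # Was ``@cacheable_method`` before this commit: random draws live in
    # small decorated helpers so their results can be cached and replayed
    # when paired transforms must share the same randomness.
    @cache_randomness
    def _random_prob(self) -> float:
        return random.random()

    def transform(self, results: dict) -> dict:
        # Invert pixel values with probability ``prob`` (assumes uint8 img).
        if self._random_prob() < self.prob:
            results['img'] = 255 - results['img']
        return results
```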