Commit b7525fae authored by Yifei Yang, committed by zhouzaida
Browse files

[Fix] Fix typos, init funcs, cacheable method of part3 of data transforms (#1784)



* fix typos and move args into cfg

* update docstring

* fix as comment

* fix lint

* update transforms as discussed before

* linting

* fix as comment

* fix lint

* fix lint and also update according to PackInput

* remove precommit change

* rename cacheable method
Co-authored-by: liukuikun <641417025@qq.com>
parent e7592a70
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from .builder import TRANSFORMS from .builder import TRANSFORMS
from .loading import LoadAnnotation, LoadImageFromFile from .loading import LoadAnnotation, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad, from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomFlip, RandomGrayscale, RandomMultiscaleResize, RandomChoiceResize, RandomFlip, RandomGrayscale,
RandomResize, Resize) RandomResize, Resize)
from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster
...@@ -12,7 +12,7 @@ except ImportError: ...@@ -12,7 +12,7 @@ except ImportError:
__all__ = [ __all__ = [
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice', 'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop', 'Resize', 'Pad', 'RandomFlip', 'RandomChoiceResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize' 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
] ]
else: else:
...@@ -22,6 +22,6 @@ else: ...@@ -22,6 +22,6 @@ else:
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice', 'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor', 'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop', 'RandomFlip', 'RandomChoiceResize', 'CenterCrop', 'RandomGrayscale',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize' 'MultiScaleFlipAug', 'RandomResize'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings import warnings
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
...@@ -9,7 +9,7 @@ import mmcv ...@@ -9,7 +9,7 @@ import mmcv
from mmcv.image.geometric import _scale_size from mmcv.image.geometric import _scale_size
from .base import BaseTransform from .base import BaseTransform
from .builder import TRANSFORMS from .builder import TRANSFORMS
from .utils import cacheable_method from .utils import cache_randomness
from .wrappers import Compose from .wrappers import Compose
Number = Union[int, float] Number = Union[int, float]
...@@ -90,14 +90,14 @@ class Resize(BaseTransform): ...@@ -90,14 +90,14 @@ class Resize(BaseTransform):
- img - img
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
Modified Keys: Modified Keys:
- img - img
- gt_bboxes - gt_bboxes
- gt_semantic_seg - gt_seg_map
- gt_keypoints - gt_keypoints
- height - height
- width - width
...@@ -205,20 +205,20 @@ class Resize(BaseTransform): ...@@ -205,20 +205,20 @@ class Resize(BaseTransform):
def _resize_seg(self, results: dict) -> None: def _resize_seg(self, results: dict) -> None:
"""Resize semantic segmentation map with ``results['scale']``.""" """Resize semantic segmentation map with ``results['scale']``."""
if results.get('gt_semantic_seg', None) is not None: if results.get('gt_seg_map', None) is not None:
if self.keep_ratio: if self.keep_ratio:
gt_seg = mmcv.imrescale( gt_seg = mmcv.imrescale(
results['gt_semantic_seg'], results['gt_seg_map'],
results['scale'], results['scale'],
interpolation=self.interpolation, interpolation=self.interpolation,
backend=self.backend) backend=self.backend)
else: else:
gt_seg = mmcv.imresize( gt_seg = mmcv.imresize(
results['gt_semantic_seg'], results['gt_seg_map'],
results['scale'], results['scale'],
interpolation=self.interpolation, interpolation=self.interpolation,
backend=self.backend) backend=self.backend)
results['gt_semantic_seg'] = gt_seg results['gt_seg_map'] = gt_seg
def _resize_keypoints(self, results: dict) -> None: def _resize_keypoints(self, results: dict) -> None:
"""Resize keypoints with ``results['scale_factor']``.""" """Resize keypoints with ``results['scale_factor']``."""
...@@ -241,7 +241,7 @@ class Resize(BaseTransform): ...@@ -241,7 +241,7 @@ class Resize(BaseTransform):
Args: Args:
results (dict): Result dict from loading pipeline. results (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg', dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width', 'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
and 'keep_ratio' keys are updated in result dict. and 'keep_ratio' keys are updated in result dict.
""" """
...@@ -279,12 +279,13 @@ class Pad(BaseTransform): ...@@ -279,12 +279,13 @@ class Pad(BaseTransform):
Required Keys: Required Keys:
- img - img
- gt_semantic_seg (optional) - gt_bboxes (optional)
- gt_seg_map (optional)
Modified Keys: Modified Keys:
- img - img
- gt_semantic_seg - gt_seg_map
- height - height
- width - width
...@@ -387,15 +388,13 @@ class Pad(BaseTransform): ...@@ -387,15 +388,13 @@ class Pad(BaseTransform):
def _pad_seg(self, results: dict) -> None: def _pad_seg(self, results: dict) -> None:
"""Pad semantic segmentation map according to """Pad semantic segmentation map according to
``results['pad_shape']``.""" ``results['pad_shape']``."""
if results.get('gt_semantic_seg', None) is not None: if results.get('gt_seg_map', None) is not None:
pad_val = self.pad_val.get('seg', 255) pad_val = self.pad_val.get('seg', 255)
if isinstance(pad_val, if isinstance(pad_val, int) and results['gt_seg_map'].ndim == 3:
int) and results['gt_semantic_seg'].ndim == 3: pad_val = tuple(
pad_val = tuple([ [pad_val for _ in range(results['gt_seg_map'].shape[2])])
pad_val for _ in range(results['gt_semantic_seg'].shape[2]) results['gt_seg_map'] = mmcv.impad(
]) results['gt_seg_map'],
results['gt_semantic_seg'] = mmcv.impad(
results['gt_semantic_seg'],
shape=results['pad_shape'][:2], shape=results['pad_shape'][:2],
pad_val=pad_val, pad_val=pad_val,
padding_mode=self.padding_mode) padding_mode=self.padding_mode)
...@@ -426,13 +425,13 @@ class Pad(BaseTransform): ...@@ -426,13 +425,13 @@ class Pad(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class CenterCrop(BaseTransform): class CenterCrop(BaseTransform):
"""Crop the center of the image, segmentation masks, bounding boxes and key """Crop the center of the image, segmentation masks, bounding boxes and key
points. If the crop area exceeds the original image and ``pad_mode`` is not points. If the crop area exceeds the original image and ``auto_pad`` is
None, the original image will be padded before cropping. True, the original image will be padded before cropping.
Required Keys: Required Keys:
- img - img
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
...@@ -441,7 +440,7 @@ class CenterCrop(BaseTransform): ...@@ -441,7 +440,7 @@ class CenterCrop(BaseTransform):
- img - img
- height - height
- width - width
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
...@@ -454,15 +453,10 @@ class CenterCrop(BaseTransform): ...@@ -454,15 +453,10 @@ class CenterCrop(BaseTransform):
crop_size (Union[int, Tuple[int, int]]): Expected size after cropping crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
with the format of (w, h). If set to an integer, then cropping with the format of (w, h). If set to an integer, then cropping
width and height are equal to this integer. width and height are equal to this integer.
pad_val (Union[Number, Dict[str, Number]]): A dict for auto_pad (bool): Whether to pad the image if it's smaller than the
padding value. To specify how to set this argument, please see ``crop_size``. Defaults to False.
the docstring of class ``Pad``. Defaults to pad_cfg (dict): Base config for padding. Refer to ``mmcv.Pad`` for
``dict(img=0, seg=255)``. detail. Defaults to ``dict(type='Pad')``.
pad_mode (str, optional): Type of padding. Should be: 'constant',
'edge', 'reflect' or 'symmetric'. For details, please see the
docstring of class ``Pad``. Defaults to 'constant'.
pad_cfg (str): Base config for padding. Defaults to
``dict(type='Pad')``.
clip_object_border (bool): Whether to clip the objects clip_object_border (bool): Whether to clip the objects
outside the border of the image. In some dataset like MOT17, the outside the border of the image. In some dataset like MOT17, the
gt bboxes are allowed to cross the border of images. Therefore, gt bboxes are allowed to cross the border of images. Therefore,
...@@ -470,14 +464,11 @@ class CenterCrop(BaseTransform): ...@@ -470,14 +464,11 @@ class CenterCrop(BaseTransform):
Defaults to True. Defaults to True.
""" """
def __init__( def __init__(self,
self, crop_size: Union[int, Tuple[int, int]],
crop_size: Union[int, Tuple[int, int]], auto_pad: bool = False,
pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255), pad_cfg: dict = dict(type='Pad'),
pad_mode: Optional[str] = None, clip_object_border: bool = True) -> None:
pad_cfg: dict = dict(type='Pad'),
clip_object_border: bool = True,
) -> None: # flake8: noqa
super().__init__() super().__init__()
assert isinstance(crop_size, int) or ( assert isinstance(crop_size, int) or (
isinstance(crop_size, tuple) and len(crop_size) == 2 isinstance(crop_size, tuple) and len(crop_size) == 2
...@@ -488,9 +479,15 @@ class CenterCrop(BaseTransform): ...@@ -488,9 +479,15 @@ class CenterCrop(BaseTransform):
crop_size = (crop_size, crop_size) crop_size = (crop_size, crop_size)
assert crop_size[0] > 0 and crop_size[1] > 0 assert crop_size[0] > 0 and crop_size[1] > 0
self.crop_size = crop_size self.crop_size = crop_size
self.pad_val = pad_val self.auto_pad = auto_pad
self.pad_mode = pad_mode
self.pad_cfg = pad_cfg self.pad_cfg = pad_cfg.copy()
# size will be overwritten
if 'size' in self.pad_cfg and auto_pad:
warnings.warn('``size`` is set in ``pad_cfg``,'
'however this argument will be overwritten'
' according to crop size and image size')
self.clip_object_border = clip_object_border self.clip_object_border = clip_object_border
def _crop_img(self, results: dict, bboxes: np.ndarray) -> None: def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
...@@ -515,9 +512,9 @@ class CenterCrop(BaseTransform): ...@@ -515,9 +512,9 @@ class CenterCrop(BaseTransform):
results (dict): Result dict contains the data to transform. results (dict): Result dict contains the data to transform.
bboxes (np.ndarray): Shape (4, ), location of cropped bboxes. bboxes (np.ndarray): Shape (4, ), location of cropped bboxes.
""" """
if results.get('gt_semantic_seg', None) is not None: if results.get('gt_seg_map', None) is not None:
img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes) img = mmcv.imcrop(results['gt_seg_map'], bboxes=bboxes)
results['gt_semantic_seg'] = img results['gt_seg_map'] = img
def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None: def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None:
"""Update bounding boxes according to CenterCrop. """Update bounding boxes according to CenterCrop.
...@@ -587,17 +584,13 @@ class CenterCrop(BaseTransform): ...@@ -587,17 +584,13 @@ class CenterCrop(BaseTransform):
img_height, img_width = img.shape[:2] img_height, img_width = img.shape[:2]
if crop_height > img_height or crop_width > img_width: if crop_height > img_height or crop_width > img_width:
if self.pad_mode is not None: if self.auto_pad:
# pad the area # pad the area
img_height = max(img_height, crop_height) img_height = max(img_height, crop_height)
img_width = max(img_width, crop_width) img_width = max(img_width, crop_width)
pad_size = (img_width, img_height) pad_size = (img_width, img_height)
_pad_cfg = self.pad_cfg.copy() _pad_cfg = self.pad_cfg.copy()
_pad_cfg.update( _pad_cfg.update(dict(size=pad_size))
dict(
size=pad_size,
pad_val=self.pad_val,
padding_mode=self.pad_mode))
pad_transform = TRANSFORMS.build(_pad_cfg) pad_transform = TRANSFORMS.build(_pad_cfg)
results = pad_transform(results) results = pad_transform(results)
else: else:
...@@ -612,7 +605,7 @@ class CenterCrop(BaseTransform): ...@@ -612,7 +605,7 @@ class CenterCrop(BaseTransform):
# crop the image # crop the image
self._crop_img(results, bboxes) self._crop_img(results, bboxes)
# crop the gt_semantic_seg # crop the gt_seg_map
self._crop_seg_map(results, bboxes) self._crop_seg_map(results, bboxes)
# crop the bounding box # crop the bounding box
self._crop_bboxes(results, bboxes) self._crop_bboxes(results, bboxes)
...@@ -623,8 +616,8 @@ class CenterCrop(BaseTransform): ...@@ -623,8 +616,8 @@ class CenterCrop(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', crop_size = {self.crop_size}' repr_str += f', crop_size = {self.crop_size}'
repr_str += f', pad_val = {self.pad_val}' repr_str += f', auto_pad={self.auto_pad}'
repr_str += f', pad_mode = {self.pad_mode}' repr_str += f', pad_cfg={self.pad_cfg}'
repr_str += f',clip_object_border = {self.clip_object_border}' repr_str += f',clip_object_border = {self.clip_object_border}'
return repr_str return repr_str
...@@ -649,10 +642,12 @@ class RandomGrayscale(BaseTransform): ...@@ -649,10 +642,12 @@ class RandomGrayscale(BaseTransform):
Args: Args:
prob (float): Probability that image should be converted to prob (float): Probability that image should be converted to
grayscale. Defaults to 0.1. grayscale. Defaults to 0.1.
keep_channel (bool): Whether keep channel number the same as keep_channels (bool): Whether keep channel number the same as
input. Defaults to False. input. Defaults to False.
channel_weights (tuple): Channel weights to compute gray channel_weights (tuple): The grayscale weights of each channel,
image. Defaults to (1., 1., 1.). and the weights will be normalized. For example, (1, 2, 1)
will be normalized as (0.25, 0.5, 0.25). Defaults to
(1., 1., 1.).
color_format (str): Color format set to be any of 'bgr', color_format (str): Color format set to be any of 'bgr',
'rgb', 'hsv'. Note: 'hsv' image will be transformed into 'bgr' 'rgb', 'hsv'. Note: 'hsv' image will be transformed into 'bgr'
format no matter whether it is grayscaled. Defaults to 'bgr'. format no matter whether it is grayscaled. Defaults to 'bgr'.
...@@ -660,18 +655,22 @@ class RandomGrayscale(BaseTransform): ...@@ -660,18 +655,22 @@ class RandomGrayscale(BaseTransform):
def __init__(self, def __init__(self,
prob: float = 0.1, prob: float = 0.1,
keep_channel: bool = False, keep_channels: bool = False,
channel_weights: Sequence[float] = (1., 1., 1.), channel_weights: Sequence[float] = (1., 1., 1.),
color_format: str = 'bgr') -> None: color_format: str = 'bgr') -> None:
super().__init__() super().__init__()
assert 0. <= prob <= 1., ('The range of ``prob`` value is [0., 1.],' + assert 0. <= prob <= 1., ('The range of ``prob`` value is [0., 1.],' +
f' but got {prob} instead') f' but got {prob} instead')
self.prob = prob self.prob = prob
self.keep_channel = keep_channel self.keep_channels = keep_channels
self.channel_weights = channel_weights self.channel_weights = channel_weights
assert color_format in ['bgr', 'rgb', 'hsv'] assert color_format in ['bgr', 'rgb', 'hsv']
self.color_format = color_format self.color_format = color_format
@cache_randomness
def _random_prob(self):
return random.random()
def transform(self, results: dict) -> dict: def transform(self, results: dict) -> dict:
"""Apply random grayscale on results. """Apply random grayscale on results.
...@@ -687,7 +686,7 @@ class RandomGrayscale(BaseTransform): ...@@ -687,7 +686,7 @@ class RandomGrayscale(BaseTransform):
img = mmcv.hsv2bgr(img) img = mmcv.hsv2bgr(img)
img = img[..., None] if img.ndim == 2 else img img = img[..., None] if img.ndim == 2 else img
num_output_channels = img.shape[2] num_output_channels = img.shape[2]
if random.random() < self.prob: if self._random_prob() < self.prob:
if num_output_channels > 1: if num_output_channels > 1:
assert num_output_channels == len( assert num_output_channels == len(
self.channel_weights self.channel_weights
...@@ -697,7 +696,7 @@ class RandomGrayscale(BaseTransform): ...@@ -697,7 +696,7 @@ class RandomGrayscale(BaseTransform):
normalized_weights = ( normalized_weights = (
np.array(self.channel_weights) / sum(self.channel_weights)) np.array(self.channel_weights) / sum(self.channel_weights))
img = (normalized_weights * img).sum(axis=2) img = (normalized_weights * img).sum(axis=2)
if self.keep_channel: if self.keep_channels:
img = img[:, :, None] img = img[:, :, None]
results['img'] = np.dstack( results['img'] = np.dstack(
[img for _ in range(num_output_channels)]) [img for _ in range(num_output_channels)])
...@@ -710,7 +709,7 @@ class RandomGrayscale(BaseTransform): ...@@ -710,7 +709,7 @@ class RandomGrayscale(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', prob = {self.prob}' repr_str += f', prob = {self.prob}'
repr_str += f', keep_channel = {self.keep_channel}' repr_str += f', keep_channels = {self.keep_channels}'
repr_str += f', channel_weights = {self.channel_weights}' repr_str += f', channel_weights = {self.channel_weights}'
repr_str += f', color_format = {self.color_format}' repr_str += f', color_format = {self.color_format}'
return repr_str return repr_str
...@@ -753,13 +752,12 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -753,13 +752,12 @@ class MultiScaleFlipAug(BaseTransform):
.. code-block:: .. code-block::
dict( dict(
img=[...], inputs=[...],
img_shape=[...], data_samples=[...]
scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
flip=[False, True, False, True]
...
) )
Where the length of ``inputs`` and ``data_samples`` are both 4.
Required Keys: Required Keys:
- Depending on the requirements of the ``transforms`` parameter. - Depending on the requirements of the ``transforms`` parameter.
...@@ -774,7 +772,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -774,7 +772,7 @@ class MultiScaleFlipAug(BaseTransform):
img_scale (tuple | list[tuple] | None): Images scales for resizing. img_scale (tuple | list[tuple] | None): Images scales for resizing.
scale_factor (float or tuple[float]): Scale factors for resizing. scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None. Defaults to None.
flip (bool): Whether apply flip augmentation. Defaults to False. allow_flip (bool): Whether apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions, flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If options are "horizontal", "vertical" and "diagonal". If
flip_direction is a list, multiple flip augmentations will be flip_direction is a list, multiple flip augmentations will be
...@@ -791,7 +789,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -791,7 +789,7 @@ class MultiScaleFlipAug(BaseTransform):
transforms: List[dict], transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None, img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
scale_factor: Optional[Union[float, List[float]]] = None, scale_factor: Optional[Union[float, List[float]]] = None,
flip: bool = False, allow_flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal', flip_direction: Union[str, List[str]] = 'horizontal',
resize_cfg: dict = dict(type='Resize', keep_ratio=True), resize_cfg: dict = dict(type='Resize', keep_ratio=True),
flip_cfg: dict = dict(type='RandomFlip') flip_cfg: dict = dict(type='RandomFlip')
...@@ -813,14 +811,14 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -813,14 +811,14 @@ class MultiScaleFlipAug(BaseTransform):
scale_factor, list) else [scale_factor] scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor' self.scale_key = 'scale_factor'
self.flip = flip self.allow_flip = allow_flip
self.flip_direction = flip_direction if isinstance( self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction] flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str) assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']: if not self.allow_flip and self.flip_direction != ['horizontal']:
warnings.warn( warnings.warn(
'flip_direction has no effect when flip is set to False') 'flip_direction has no effect when flip is set to False')
self.resize_cfg = resize_cfg self.resize_cfg = resize_cfg.copy()
self.flip_cfg = flip_cfg self.flip_cfg = flip_cfg
def transform(self, results: dict) -> Tuple[List, List]: def transform(self, results: dict) -> Tuple[List, List]:
...@@ -834,10 +832,10 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -834,10 +832,10 @@ class MultiScaleFlipAug(BaseTransform):
into a list. into a list.
""" """
aug_data = [] data_samples = []
input_data = [] inputs = []
flip_args = [(False, '')] flip_args = [(False, '')]
if self.flip: if self.allow_flip:
flip_args += [(True, direction) flip_args += [(True, direction)
for direction in self.flip_direction] for direction in self.flip_direction]
for scale in self.img_scale: for scale in self.img_scale:
...@@ -857,23 +855,23 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -857,23 +855,23 @@ class MultiScaleFlipAug(BaseTransform):
resize_flip = Compose(_resize_flip) resize_flip = Compose(_resize_flip)
_results = results.copy() _results = results.copy()
_results = resize_flip(_results) _results = resize_flip(_results)
input_image, data_sample = self.transforms(_results) packed_results = self.transforms(_results)
input_data.append(input_image) inputs.append(packed_results['inputs'])
aug_data.append(data_sample) data_samples.append(packed_results['data_sample'])
return input_data, aug_data return dict(inputs=inputs, data_sample=data_samples)
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', transforms={self.transforms}' repr_str += f', transforms={self.transforms}'
repr_str += f', img_scale={self.img_scale}' repr_str += f', img_scale={self.img_scale}'
repr_str += f', flip={self.flip}' repr_str += f', allow_flip={self.allow_flip}'
repr_str += f', flip_direction={self.flip_direction}' repr_str += f', flip_direction={self.flip_direction}'
return repr_str return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomMultiscaleResize(BaseTransform): class RandomChoiceResize(BaseTransform):
"""Resize images & bbox & mask from a list of multiple scales. """Resize images & bbox & mask from a list of multiple scales.
This transform resizes the input image to some scale. Bboxes and masks are This transform resizes the input image to some scale. Bboxes and masks are
...@@ -891,7 +889,7 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -891,7 +889,7 @@ class RandomMultiscaleResize(BaseTransform):
- img - img
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
Modified Keys: Modified Keys:
...@@ -900,7 +898,7 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -900,7 +898,7 @@ class RandomMultiscaleResize(BaseTransform):
- height - height
- width - width
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
Added Keys: Added Keys:
...@@ -913,24 +911,14 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -913,24 +911,14 @@ class RandomMultiscaleResize(BaseTransform):
Args: Args:
scales (Union[list, Tuple]): Images scales for resizing. scales (Union[list, Tuple]): Images scales for resizing.
keep_ratio (bool): Whether to keep the aspect ratio when resize_cfg (dict): Base config for resizing. Refer to
resizing the image. Defaults to False. ``mmcv.Resize`` for detail. Defaults to
clip_object_border (bool): Whether clip the objects outside ``dict(type='Resize')``.
the border of the image. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and
'pillow'. These two backends generates slightly different results.
Defaults to 'cv2'.
interpolation (str): The mode of interpolation, support
"bilinear", "bicubic", "nearest". Defaults to "bilinear".
""" """
def __init__( def __init__(
self, self,
scales: Union[list, Tuple], scales: Sequence[Union[int, Tuple]],
keep_ratio: bool = False,
clip_object_border: bool = True,
backend: str = 'cv2',
interpolation: str = 'bilinear',
resize_cfg: dict = dict(type='Resize') resize_cfg: dict = dict(type='Resize')
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -939,19 +927,11 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -939,19 +927,11 @@ class RandomMultiscaleResize(BaseTransform):
else: else:
self.scales = [scales] self.scales = [scales]
assert mmcv.is_list_of(self.scales, tuple) assert mmcv.is_list_of(self.scales, tuple)
self.keep_ratio = keep_ratio
self.clip_object_border = clip_object_border
self.backend = backend
self.interpolation = interpolation
self.resize_cfg = resize_cfg self.resize_cfg = resize_cfg
@staticmethod @cache_randomness
def random_select(scales: List[Tuple]) -> Tuple[tuple, int]: def _random_select(self) -> Tuple[int, int]:
"""Randomly select an img_scale from given candidates. """Randomly select an scale from given candidates.
Args:
scales (list[tuple]): Images scales for selection.
Returns: Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_idx)``, (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
...@@ -959,9 +939,9 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -959,9 +939,9 @@ class RandomMultiscaleResize(BaseTransform):
``scale_idx`` is the selected index in the given candidates. ``scale_idx`` is the selected index in the given candidates.
""" """
assert mmcv.is_list_of(scales, tuple) assert mmcv.is_list_of(self.scales, tuple)
scale_idx = np.random.randint(len(scales)) scale_idx = np.random.randint(len(self.scales))
scale = scales[scale_idx] scale = self.scales[scale_idx]
return scale, scale_idx return scale, scale_idx
def transform(self, results: dict) -> dict: def transform(self, results: dict) -> dict:
...@@ -971,20 +951,14 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -971,20 +951,14 @@ class RandomMultiscaleResize(BaseTransform):
results (dict): Result dict contains the data to transform. results (dict): Result dict contains the data to transform.
Returns: Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_semantic_seg', dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'height', 'width', 'gt_keypoints', 'scale', 'scale_factor', 'height', 'width',
and 'keep_ratio' keys are updated in result dict. and 'keep_ratio' keys are updated in result dict.
""" """
target_scale, scale_idx = self.random_select(self.scales) target_scale, scale_idx = self._random_select()
_resize_cfg = self.resize_cfg.copy() _resize_cfg = self.resize_cfg.copy()
_resize_cfg.update( _resize_cfg.update(dict(scale=target_scale))
dict(
scale=target_scale,
keep_ratio=self.keep_ratio,
clip_object_border=self.clip_object_border,
backend=self.backend,
interpolation=self.interpolation))
resize_transform = TRANSFORMS.build(_resize_cfg) resize_transform = TRANSFORMS.build(_resize_cfg)
results = resize_transform(results) results = resize_transform(results)
results['scale_idx'] = scale_idx results['scale_idx'] = scale_idx
...@@ -993,17 +967,14 @@ class RandomMultiscaleResize(BaseTransform): ...@@ -993,17 +967,14 @@ class RandomMultiscaleResize(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', scales={self.scales}' repr_str += f', scales={self.scales}'
repr_str += f', keep_ratio={self.keep_ratio}' repr_str += f', resize_cfg={self.resize_cfg}'
repr_str += f', clip_object_border={self.clip_object_border}'
repr_str += f', backend={self.backend}'
repr_str += f', interpolation={self.interpolation}'
return repr_str return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomFlip(BaseTransform): class RandomFlip(BaseTransform):
"""Flip the image & bbox & keypoints & segmentation map. Added or Updated """Flip the image & bbox & keypoints & segmentation map. Added or Updated
keys: flip, flip_direction, img, gt_bboxes, gt_semantic_seg, and keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and
gt_keypoints. There are 3 flip modes: gt_keypoints. There are 3 flip modes:
- ``prob`` is float, ``direction`` is string: the image will be - ``prob`` is float, ``direction`` is string: the image will be
...@@ -1026,13 +997,13 @@ class RandomFlip(BaseTransform): ...@@ -1026,13 +997,13 @@ class RandomFlip(BaseTransform):
Required Keys: Required Keys:
- img - img
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
Modified Keys: Modified Keys:
- img - img
- gt_bboxes (optional) - gt_bboxes (optional)
- gt_semantic_seg (optional) - gt_seg_map (optional)
- gt_keypoints (optional) - gt_keypoints (optional)
Added Keys: Added Keys:
...@@ -1139,7 +1110,7 @@ class RandomFlip(BaseTransform): ...@@ -1139,7 +1110,7 @@ class RandomFlip(BaseTransform):
flipped = np.concatenate([keypoints, meta_info], axis=-1) flipped = np.concatenate([keypoints, meta_info], axis=-1)
return flipped return flipped
@cacheable_method @cache_randomness
def _choose_direction(self) -> str: def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`""" """Choose the flip direction according to `prob` and `direction`"""
if isinstance(self.direction, if isinstance(self.direction,
...@@ -1207,7 +1178,7 @@ class RandomFlip(BaseTransform): ...@@ -1207,7 +1178,7 @@ class RandomFlip(BaseTransform):
Args: Args:
results (dict): Result dict from loading pipeline. results (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Flipped results, 'img', 'gt_bboxes', 'gt_semantic_seg', dict: Flipped results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'flip', and 'flip_direction' keys are 'gt_keypoints', 'flip', and 'flip_direction' keys are
updated in result dict. updated in result dict.
""" """
...@@ -1246,14 +1217,14 @@ class RandomResize(BaseTransform): ...@@ -1246,14 +1217,14 @@ class RandomResize(BaseTransform):
- img - img
- gt_bboxes - gt_bboxes
- gt_semantic_seg - gt_seg_map
- gt_keypoints - gt_keypoints
Modified Keys: Modified Keys:
- img - img
- gt_bboxes - gt_bboxes
- gt_semantic_seg - gt_seg_map
- gt_keypoints - gt_keypoints
Added Keys: Added Keys:
...@@ -1335,7 +1306,7 @@ class RandomResize(BaseTransform): ...@@ -1335,7 +1306,7 @@ class RandomResize(BaseTransform):
scale = int(scale[0] * ratio), int(scale[1] * ratio) scale = int(scale[0] * ratio), int(scale[1] * ratio)
return scale return scale
@cacheable_method @cache_randomness
def _random_scale(self) -> tuple: def _random_scale(self) -> tuple:
"""Private function to randomly sample an scale according to the type """Private function to randomly sample an scale according to the type
of ``scale``. of ``scale``.
......
...@@ -56,7 +56,7 @@ class TestResize: ...@@ -56,7 +56,7 @@ class TestResize:
def test_resize(self): def test_resize(self):
data_info = dict( data_info = dict(
img=np.random.random((1333, 800, 3)), img=np.random.random((1333, 800, 3)),
gt_semantic_seg=np.random.random((1333, 800, 3)), gt_seg_map=np.random.random((1333, 800, 3)),
gt_bboxes=np.array([[0, 0, 112, 112]]), gt_bboxes=np.array([[0, 0, 112, 112]]),
gt_keypoints=np.array([[[20, 50, 1]]])) gt_keypoints=np.array([[[20, 50, 1]]]))
...@@ -100,7 +100,7 @@ class TestResize: ...@@ -100,7 +100,7 @@ class TestResize:
results = transform(copy.deepcopy(data_info)) results = transform(copy.deepcopy(data_info))
assert (results['gt_bboxes'] == np.array([[0, 0, 168, 224]])).all() assert (results['gt_bboxes'] == np.array([[0, 0, 168, 224]])).all()
assert (results['gt_keypoints'] == np.array([[[30, 100, 1]]])).all() assert (results['gt_keypoints'] == np.array([[[30, 100, 1]]])).all()
assert results['gt_semantic_seg'].shape[:2] == (2666, 1200) assert results['gt_seg_map'].shape[:2] == (2666, 1200)
# test clip_object_border = False # test clip_object_border = False
data_info = dict( data_info = dict(
...@@ -143,39 +143,39 @@ class TestPad: ...@@ -143,39 +143,39 @@ class TestPad:
data_info = dict( data_info = dict(
img=np.random.random((1333, 800, 3)), img=np.random.random((1333, 800, 3)),
gt_semantic_seg=np.random.random((1333, 800, 3)), gt_seg_map=np.random.random((1333, 800, 3)),
gt_bboxes=np.array([[0, 0, 112, 112]]), gt_bboxes=np.array([[0, 0, 112, 112]]),
gt_keypoints=np.array([[[20, 50, 1]]])) gt_keypoints=np.array([[[20, 50, 1]]]))
# test pad img / gt_semantic_seg with size # test pad img / gt_seg_map with size
trans = Pad(size=(1200, 2000)) trans = Pad(size=(1200, 2000))
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (2000, 1200) assert results['img'].shape[:2] == (2000, 1200)
assert results['gt_semantic_seg'].shape[:2] == (2000, 1200) assert results['gt_seg_map'].shape[:2] == (2000, 1200)
# test pad img/gt_semantic_seg with size_divisor # test pad img/gt_seg_map with size_divisor
trans = Pad(size_divisor=11) trans = Pad(size_divisor=11)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 803) assert results['img'].shape[:2] == (1342, 803)
assert results['gt_semantic_seg'].shape[:2] == (1342, 803) assert results['gt_seg_map'].shape[:2] == (1342, 803)
# test pad img/gt_semantic_seg with pad_to_square # test pad img/gt_seg_map with pad_to_square
trans = Pad(pad_to_square=True) trans = Pad(pad_to_square=True)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1333, 1333) assert results['img'].shape[:2] == (1333, 1333)
assert results['gt_semantic_seg'].shape[:2] == (1333, 1333) assert results['gt_seg_map'].shape[:2] == (1333, 1333)
# test pad img/gt_semantic_seg with pad_to_square and size_divisor # test pad img/gt_seg_map with pad_to_square and size_divisor
trans = Pad(pad_to_square=True, size_divisor=11) trans = Pad(pad_to_square=True, size_divisor=11)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 1342) assert results['img'].shape[:2] == (1342, 1342)
assert results['gt_semantic_seg'].shape[:2] == (1342, 1342) assert results['gt_seg_map'].shape[:2] == (1342, 1342)
# test pad img/gt_semantic_seg with pad_to_square and size_divisor # test pad img/gt_seg_map with pad_to_square and size_divisor
trans = Pad(pad_to_square=True, size_divisor=11) trans = Pad(pad_to_square=True, size_divisor=11)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert results['img'].shape[:2] == (1342, 1342) assert results['img'].shape[:2] == (1342, 1342)
assert results['gt_semantic_seg'].shape[:2] == (1342, 1342) assert results['gt_seg_map'].shape[:2] == (1342, 1342)
# test padding_mode # test padding_mode
new_img = np.ones((1333, 800, 3)) new_img = np.ones((1333, 800, 3))
...@@ -191,13 +191,12 @@ class TestPad: ...@@ -191,13 +191,12 @@ class TestPad:
pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10))) pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10)))
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all() assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, 800:2000, :] == 10).all() assert (results['gt_seg_map'][1333:2000, 800:2000, :] == 10).all()
trans = Pad(size=(2000, 2000), pad_val=dict(img=(12, 12, 12))) trans = Pad(size=(2000, 2000), pad_val=dict(img=(12, 12, 12)))
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all() assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, assert (results['gt_seg_map'][1333:2000, 800:2000, :] == 255).all()
800:2000, :] == 255).all()
# test rgb image, pad_to_square=True # test rgb image, pad_to_square=True
trans = Pad( trans = Pad(
...@@ -205,29 +204,28 @@ class TestPad: ...@@ -205,29 +204,28 @@ class TestPad:
pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10))) pad_val=dict(img=(12, 12, 12), seg=(10, 10, 10)))
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'][:, 800:1333, :] == 12).all() assert (results['img'][:, 800:1333, :] == 12).all()
assert (results['gt_semantic_seg'][:, 800:1333, :] == 10).all() assert (results['gt_seg_map'][:, 800:1333, :] == 10).all()
trans = Pad(pad_to_square=True, pad_val=dict(img=(12, 12, 12))) trans = Pad(pad_to_square=True, pad_val=dict(img=(12, 12, 12)))
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'][:, 800:1333, :] == 12).all() assert (results['img'][:, 800:1333, :] == 12).all()
assert (results['gt_semantic_seg'][:, 800:1333, :] == 255).all() assert (results['gt_seg_map'][:, 800:1333, :] == 255).all()
# test pad_val is int # test pad_val is int
# test rgb image # test rgb image
trans = Pad(size=(2000, 2000), pad_val=12) trans = Pad(size=(2000, 2000), pad_val=12)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000, :] == 12).all() assert (results['img'][1333:2000, 800:2000, :] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, assert (results['gt_seg_map'][1333:2000, 800:2000, :] == 255).all()
800:2000, :] == 255).all()
# test gray image # test gray image
new_img = np.random.random((1333, 800)) new_img = np.random.random((1333, 800))
data_info['img'] = new_img data_info['img'] = new_img
new_semantic_seg = np.random.random((1333, 800)) new_semantic_seg = np.random.random((1333, 800))
data_info['gt_semantic_seg'] = new_semantic_seg data_info['gt_seg_map'] = new_semantic_seg
trans = Pad(size=(2000, 2000), pad_val=12) trans = Pad(size=(2000, 2000), pad_val=12)
results = trans(copy.deepcopy(data_info)) results = trans(copy.deepcopy(data_info))
assert (results['img'][1333:2000, 800:2000] == 12).all() assert (results['img'][1333:2000, 800:2000] == 12).all()
assert (results['gt_semantic_seg'][1333:2000, 800:2000] == 255).all() assert (results['gt_seg_map'][1333:2000, 800:2000] == 255).all()
def test_repr(self): def test_repr(self):
trans = Pad(pad_to_square=True, size_divisor=11, padding_mode='edge') trans = Pad(pad_to_square=True, size_divisor=11, padding_mode='edge')
...@@ -249,7 +247,7 @@ class TestCenterCrop: ...@@ -249,7 +247,7 @@ class TestCenterCrop:
@staticmethod @staticmethod
def reset_results(results, original_img, gt_semantic_map): def reset_results(results, original_img, gt_semantic_map):
results['img'] = copy.deepcopy(original_img) results['img'] = copy.deepcopy(original_img)
results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map) results['gt_seg_map'] = copy.deepcopy(gt_semantic_map)
results['gt_bboxes'] = np.array([[0, 0, 210, 160], results['gt_bboxes'] = np.array([[0, 0, 210, 160],
[200, 150, 400, 300]]) [200, 150, 400, 300]])
results['gt_keypoints'] = np.array([[[20, 50, 1]], [[200, 150, 1]], results['gt_keypoints'] = np.array([[[20, 50, 1]], [[200, 150, 1]],
...@@ -296,9 +294,8 @@ class TestCenterCrop: ...@@ -296,9 +294,8 @@ class TestCenterCrop:
assert results['height'] == 224 assert results['height'] == 224
assert results['width'] == 224 assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all() assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert ( assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
results['gt_semantic_seg'] == self.gt_semantic_map[38:262, 88:312]).all()
88:312]).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 122], [112, 112, 224, np.array([[0, 0, 122, 122], [112, 112, 224,
224]])).all() 224]])).all()
...@@ -315,9 +312,8 @@ class TestCenterCrop: ...@@ -315,9 +312,8 @@ class TestCenterCrop:
assert results['height'] == 224 assert results['height'] == 224
assert results['width'] == 224 assert results['width'] == 224
assert (results['img'] == self.original_img[38:262, 88:312, ...]).all() assert (results['img'] == self.original_img[38:262, 88:312, ...]).all()
assert ( assert (results['gt_seg_map'] == self.gt_semantic_map[38:262,
results['gt_semantic_seg'] == self.gt_semantic_map[38:262, 88:312]).all()
88:312]).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 122], [112, 112, 224, np.array([[0, 0, 122, 122], [112, 112, 224,
224]])).all() 224]])).all()
...@@ -334,9 +330,8 @@ class TestCenterCrop: ...@@ -334,9 +330,8 @@ class TestCenterCrop:
assert results['height'] == 256 assert results['height'] == 256
assert results['width'] == 224 assert results['width'] == 224
assert (results['img'] == self.original_img[22:278, 88:312, ...]).all() assert (results['img'] == self.original_img[22:278, 88:312, ...]).all()
assert ( assert (results['gt_seg_map'] == self.gt_semantic_map[22:278,
results['gt_semantic_seg'] == self.gt_semantic_map[22:278, 88:312]).all()
88:312]).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 138], [112, 128, 224, np.array([[0, 0, 122, 138], [112, 128, 224,
256]])).all() 256]])).all()
...@@ -354,7 +349,7 @@ class TestCenterCrop: ...@@ -354,7 +349,7 @@ class TestCenterCrop:
assert results['height'] == 300 assert results['height'] == 300
assert results['width'] == 400 assert results['width'] == 400
assert (results['img'] == self.original_img).all() assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all() assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 160], [200, 150, 400, np.array([[0, 0, 210, 160], [200, 150, 400,
300]])).all() 300]])).all()
...@@ -372,7 +367,7 @@ class TestCenterCrop: ...@@ -372,7 +367,7 @@ class TestCenterCrop:
assert results['height'] == 300 assert results['height'] == 300
assert results['width'] == 400 assert results['width'] == 400
assert (results['img'] == self.original_img).all() assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all() assert (results['gt_seg_map'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 160], [200, 150, 400, np.array([[0, 0, 210, 160], [200, 150, 400,
300]])).all() 300]])).all()
...@@ -384,17 +379,17 @@ class TestCenterCrop: ...@@ -384,17 +379,17 @@ class TestCenterCrop:
transform = dict( transform = dict(
type='CenterCrop', type='CenterCrop',
crop_size=(img_width // 2, img_height * 2), crop_size=(img_width // 2, img_height * 2),
pad_mode='constant', auto_pad=True,
pad_val=12) pad_cfg=dict(type='Pad', padding_mode='constant', pad_val=12))
center_crop_module = TRANSFORMS.build(transform) center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
results = center_crop_module(results) results = center_crop_module(results)
assert results['height'] == 600 assert results['height'] == 600
assert results['width'] == 200 assert results['width'] == 200
assert results['img'].shape[:2] == results['gt_semantic_seg'].shape assert results['img'].shape[:2] == results['gt_seg_map'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all() assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all() assert (results['gt_seg_map'][300:600, 100:300] == 255).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200, np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all() 300]])).all()
...@@ -405,8 +400,11 @@ class TestCenterCrop: ...@@ -405,8 +400,11 @@ class TestCenterCrop:
transform = dict( transform = dict(
type='CenterCrop', type='CenterCrop',
crop_size=(img_width // 2, img_height * 2), crop_size=(img_width // 2, img_height * 2),
pad_mode='constant', auto_pad=True,
pad_val=dict(img=13, seg=33)) pad_cfg=dict(
type='Pad',
padding_mode='constant',
pad_val=dict(img=13, seg=33)))
center_crop_module = TRANSFORMS.build(transform) center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img, results = self.reset_results(results, self.original_img,
self.gt_semantic_map) self.gt_semantic_map)
...@@ -414,7 +412,7 @@ class TestCenterCrop: ...@@ -414,7 +412,7 @@ class TestCenterCrop:
assert results['height'] == 600 assert results['height'] == 600
assert results['width'] == 200 assert results['width'] == 200
assert (results['img'][300:600, 100:300, ...] == 13).all() assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all() assert (results['gt_seg_map'][300:600, 100:300] == 33).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200, np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all() 300]])).all()
...@@ -432,9 +430,8 @@ class TestCenterCrop: ...@@ -432,9 +430,8 @@ class TestCenterCrop:
assert results['height'] == img_height assert results['height'] == img_height
assert results['width'] == img_width // 2 assert results['width'] == img_width // 2
assert (results['img'] == self.original_img[:, 100:300, ...]).all() assert (results['img'] == self.original_img[:, 100:300, ...]).all()
assert ( assert (results['gt_seg_map'] == self.gt_semantic_map[:,
results['gt_semantic_seg'] == self.gt_semantic_map[:, 100:300]).all()
100:300]).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200, np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all() 300]])).all()
...@@ -452,8 +449,8 @@ class TestCenterCrop: ...@@ -452,8 +449,8 @@ class TestCenterCrop:
assert results['height'] == img_height // 2 assert results['height'] == img_height // 2
assert results['width'] == img_width assert results['width'] == img_width
assert (results['img'] == self.original_img[75:225, ...]).all() assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225, assert (results['gt_seg_map'] == self.gt_semantic_map[75:225,
...]).all() ...]).all()
assert np.equal(results['gt_bboxes'], assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 85], [200, 75, 400, np.array([[0, 0, 210, 85], [200, 75, 400,
150]])).all() 150]])).all()
...@@ -479,7 +476,7 @@ class TestCenterCrop: ...@@ -479,7 +476,7 @@ class TestCenterCrop:
cropped_seg = center_crop_module(pil_seg) cropped_seg = center_crop_module(pil_seg)
cropped_seg = np.array(cropped_seg) cropped_seg = np.array(cropped_seg)
assert np.equal(results['img'], cropped_img).all() assert np.equal(results['img'], cropped_img).all()
assert np.equal(results['gt_semantic_seg'], cropped_seg).all() assert np.equal(results['gt_seg_map'], cropped_seg).all()
class TestRandomGrayscale: class TestRandomGrayscale:
...@@ -494,7 +491,7 @@ class TestRandomGrayscale: ...@@ -494,7 +491,7 @@ class TestRandomGrayscale:
type='RandomGrayscale', type='RandomGrayscale',
prob=1., prob=1.,
channel_weights=(0.299, 0.587, 0.114), channel_weights=(0.299, 0.587, 0.114),
keep_channel=True) keep_channels=True)
random_gray_scale_module = TRANSFORMS.build(transform) random_gray_scale_module = TRANSFORMS.build(transform)
assert isinstance(repr(random_gray_scale_module), str) assert isinstance(repr(random_gray_scale_module), str)
...@@ -511,7 +508,7 @@ class TestRandomGrayscale: ...@@ -511,7 +508,7 @@ class TestRandomGrayscale:
type='RandomGrayscale', type='RandomGrayscale',
prob=1., prob=1.,
channel_weights=(0.299, 0.587, 0.114), channel_weights=(0.299, 0.587, 0.114),
keep_channel=True) keep_channels=True)
random_gray_scale_module = TRANSFORMS.build(transform) random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img) results['img'] = copy.deepcopy(self.img)
...@@ -541,14 +538,14 @@ class TestRandomGrayscale: ...@@ -541,14 +538,14 @@ class TestRandomGrayscale:
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class MockFormatBundle(BaseTransform): class MockPackTaskInputs(BaseTransform):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def transform(self, results): def transform(self, results):
data_sample = Mock() packed_results = dict(inputs=results['img'], data_sample=Mock())
return results['img'], data_sample return packed_results
class TestMultiScaleFlipAug: class TestMultiScaleFlipAug:
...@@ -581,30 +578,28 @@ class TestMultiScaleFlipAug: ...@@ -581,30 +578,28 @@ class TestMultiScaleFlipAug:
# test with empty transforms # test with empty transforms
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=[dict(type='MockFormatBundle')], transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)], img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True, allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results) packed_results = multi_scale_flip_aug_module(results)
assert len(input) == 12 assert len(packed_results['inputs']) == 12
assert len(data_sample) == 12
# test with flip=False # test with allow_flip=False
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=[dict(type='MockFormatBundle')], transforms=[dict(type='MockPackTaskInputs')],
img_scale=[(1333, 800), (800, 600), (640, 480)], img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=False, allow_flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results) packed_results = multi_scale_flip_aug_module(results)
assert len(input) == 3 assert len(packed_results['inputs']) == 3
assert len(data_sample) == 3
# test with transforms # test with transforms
img_norm_cfg = dict( img_norm_cfg = dict(
...@@ -615,20 +610,19 @@ class TestMultiScaleFlipAug: ...@@ -615,20 +610,19 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']), dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle') dict(type='MockPackTaskInputs')
] ]
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=transforms_cfg, transforms=transforms_cfg,
img_scale=[(1333, 800), (800, 600), (640, 480)], img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True, allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results) packed_results = multi_scale_flip_aug_module(results)
assert len(input) == 12 assert len(packed_results['inputs']) == 12
assert len(data_sample) == 12
# test with scale_factor # test with scale_factor
img_norm_cfg = dict( img_norm_cfg = dict(
...@@ -639,20 +633,19 @@ class TestMultiScaleFlipAug: ...@@ -639,20 +633,19 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']), dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle') dict(type='MockPackTaskInputs')
] ]
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=transforms_cfg, transforms=transforms_cfg,
scale_factor=[0.5, 1., 2.], scale_factor=[0.5, 1., 2.],
flip=True, allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results) packed_results = multi_scale_flip_aug_module(results)
assert len(input) == 12 assert len(packed_results['inputs']) == 12
assert len(data_sample) == 12
# test no resize # test no resize
img_norm_cfg = dict( img_norm_cfg = dict(
...@@ -663,22 +656,21 @@ class TestMultiScaleFlipAug: ...@@ -663,22 +656,21 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']), dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle') dict(type='MockPackTaskInputs')
] ]
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=transforms_cfg, transforms=transforms_cfg,
flip=True, allow_flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results) packed_results = multi_scale_flip_aug_module(results)
assert len(input) == 4 assert len(packed_results['inputs']) == 4
assert len(data_sample) == 4
class TestRandomMultiscaleResize: class TestRandomChoiceResize:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
...@@ -688,25 +680,25 @@ class TestRandomMultiscaleResize: ...@@ -688,25 +680,25 @@ class TestRandomMultiscaleResize:
def reset_results(self, results): def reset_results(self, results):
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
results['gt_semantic_seg'] = copy.deepcopy(self.original_img) results['gt_seg_map'] = copy.deepcopy(self.original_img)
def test_repr(self): def test_repr(self):
# test repr # test repr
transform = dict( transform = dict(
type='RandomMultiscaleResize', scales=[(1333, 800), (1333, 600)]) type='RandomChoiceResize', scales=[(1333, 800), (1333, 600)])
random_multiscale_resize = TRANSFORMS.build(transform) random_multiscale_resize = TRANSFORMS.build(transform)
assert isinstance(repr(random_multiscale_resize), str) assert isinstance(repr(random_multiscale_resize), str)
def test_error(self): def test_error(self):
# test assertion if size is smaller than 0 # test assertion if size is smaller than 0
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
transform = dict(type='RandomMultiscaleResize', scales=[0.5, 1, 2]) transform = dict(type='RandomChoiceResize', scales=[0.5, 1, 2])
TRANSFORMS.build(transform) TRANSFORMS.build(transform)
def test_random_multiscale_resize(self): def test_random_multiscale_resize(self):
results = dict() results = dict()
# test with one scale # test with one scale
transform = dict(type='RandomMultiscaleResize', scales=[(1333, 800)]) transform = dict(type='RandomChoiceResize', scales=[(1333, 800)])
random_multiscale_resize = TRANSFORMS.build(transform) random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results) self.reset_results(results)
results = random_multiscale_resize(results) results = random_multiscale_resize(results)
...@@ -714,7 +706,7 @@ class TestRandomMultiscaleResize: ...@@ -714,7 +706,7 @@ class TestRandomMultiscaleResize:
# test with multi scales # test with multi scales
_scale_choice = [(1333, 800), (1333, 600)] _scale_choice = [(1333, 800), (1333, 600)]
transform = dict(type='RandomMultiscaleResize', scales=_scale_choice) transform = dict(type='RandomChoiceResize', scales=_scale_choice)
random_multiscale_resize = TRANSFORMS.build(transform) random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results) self.reset_results(results)
results = random_multiscale_resize(results) results = random_multiscale_resize(results)
...@@ -723,9 +715,9 @@ class TestRandomMultiscaleResize: ...@@ -723,9 +715,9 @@ class TestRandomMultiscaleResize:
# test keep_ratio # test keep_ratio
transform = dict( transform = dict(
type='RandomMultiscaleResize', type='RandomChoiceResize',
scales=[(900, 600)], scales=[(900, 600)],
keep_ratio=True) resize_cfg=dict(type='Resize', keep_ratio=True))
random_multiscale_resize = TRANSFORMS.build(transform) random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results) self.reset_results(results)
_input_ratio = results['img'].shape[0] / results['img'].shape[1] _input_ratio = results['img'].shape[0] / results['img'].shape[1]
...@@ -736,9 +728,9 @@ class TestRandomMultiscaleResize: ...@@ -736,9 +728,9 @@ class TestRandomMultiscaleResize:
# test clip_object_border # test clip_object_border
gt_bboxes = [[200, 150, 600, 450]] gt_bboxes = [[200, 150, 600, 450]]
transform = dict( transform = dict(
type='RandomMultiscaleResize', type='RandomChoiceResize',
scales=[(200, 150)], scales=[(200, 150)],
clip_object_border=True) resize_cfg=dict(type='Resize', clip_object_border=True))
random_multiscale_resize = TRANSFORMS.build(transform) random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results) self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes) results['gt_bboxes'] = np.array(gt_bboxes)
...@@ -748,9 +740,9 @@ class TestRandomMultiscaleResize: ...@@ -748,9 +740,9 @@ class TestRandomMultiscaleResize:
150]])).all() 150]])).all()
transform = dict( transform = dict(
type='RandomMultiscaleResize', type='RandomChoiceResize',
scales=[(200, 150)], scales=[(200, 150)],
clip_object_border=False) resize_cfg=dict(type='Resize', clip_object_border=False))
random_multiscale_resize = TRANSFORMS.build(transform) random_multiscale_resize = TRANSFORMS.build(transform)
self.reset_results(results) self.reset_results(results)
results['gt_bboxes'] = np.array(gt_bboxes) results['gt_bboxes'] = np.array(gt_bboxes)
...@@ -873,7 +865,7 @@ class TestRandomResize: ...@@ -873,7 +865,7 @@ class TestRandomResize:
# keep ratio is True # keep ratio is True
results = { results = {
'img': np.random.random((224, 224, 3)), 'img': np.random.random((224, 224, 3)),
'gt_semantic_seg': np.random.random((224, 224, 3)), 'gt_seg_map': np.random.random((224, 224, 3)),
'gt_bboxes': np.array([[0, 0, 112, 112]]), 'gt_bboxes': np.array([[0, 0, 112, 112]]),
'gt_keypoints': np.array([[[112, 112]]]) 'gt_keypoints': np.array([[[112, 112]]])
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment