Commit 2844a116 authored by Yifei Yang, committed by zhouzaida
Browse files

[Fix] Fix MultiScaleFlipAug (#1801)

* Fix MultiScaleFlipAug

* fix as comment
parent 169f098d
...@@ -762,6 +762,8 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -762,6 +762,8 @@ class MultiScaleFlipAug(BaseTransform):
transforms (list[dict]): Transforms to be applied to each resized transforms (list[dict]): Transforms to be applied to each resized
and flipped data. and flipped data.
img_scale (tuple | list[tuple] | None): Images scales for resizing. img_scale (tuple | list[tuple] | None): Images scales for resizing.
scale_factor (float | list[float] | None): Scale factors for resizing.
Defaults to None.
flip (bool): Whether apply flip augmentation. Defaults to False. flip (bool): Whether apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions, flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If options are "horizontal", "vertical" and "diagonal". If
...@@ -778,6 +780,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -778,6 +780,7 @@ class MultiScaleFlipAug(BaseTransform):
self, self,
transforms: List[dict], transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None, img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
scale_factor: Optional[Union[float, List[float]]] = None,
flip: bool = False, flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal', flip_direction: Union[str, List[str]] = 'horizontal',
resize_cfg: dict = dict(type='Resize', keep_ratio=True), resize_cfg: dict = dict(type='Resize', keep_ratio=True),
...@@ -785,11 +788,20 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -785,11 +788,20 @@ class MultiScaleFlipAug(BaseTransform):
) -> None: ) -> None:
super().__init__() super().__init__()
self.transforms = Compose(transforms) # type: ignore self.transforms = Compose(transforms) # type: ignore
assert img_scale is not None
if img_scale is not None:
self.img_scale = img_scale if isinstance(img_scale, self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale] list) else [img_scale]
self.scale_key = 'scale' self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple) assert mmcv.is_list_of(self.img_scale, tuple)
else:
# If ``img_scale`` and ``scale_factor`` are both ``None``, fall back to a single scale factor of 1.
if scale_factor is None:
self.img_scale = [1.]
else:
self.img_scale = scale_factor if isinstance(
scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor'
self.flip = flip self.flip = flip
self.flip_direction = flip_direction if isinstance( self.flip_direction = flip_direction if isinstance(
...@@ -801,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -801,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform):
self.resize_cfg = resize_cfg self.resize_cfg = resize_cfg
self.flip_cfg = flip_cfg self.flip_cfg = flip_cfg
def transform(self, results: dict) -> dict: def transform(self, results: dict) -> Tuple[List, List]:
"""Apply test time augment transforms on results. """Apply test time augment transforms on results.
Args: Args:
...@@ -813,6 +825,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -813,6 +825,7 @@ class MultiScaleFlipAug(BaseTransform):
""" """
aug_data = [] aug_data = []
input_data = []
flip_args = [(False, '')] flip_args = [(False, '')]
if self.flip: if self.flip:
flip_args += [(True, direction) flip_args += [(True, direction)
...@@ -820,7 +833,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -820,7 +833,7 @@ class MultiScaleFlipAug(BaseTransform):
for scale in self.img_scale: for scale in self.img_scale:
for flip, direction in flip_args: for flip, direction in flip_args:
_resize_cfg = self.resize_cfg.copy() _resize_cfg = self.resize_cfg.copy()
_resize_cfg.update(scale=scale) _resize_cfg.update({self.scale_key: scale})
_resize_flip = [_resize_cfg] _resize_flip = [_resize_cfg]
if flip: if flip:
...@@ -834,14 +847,11 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -834,14 +847,11 @@ class MultiScaleFlipAug(BaseTransform):
resize_flip = Compose(_resize_flip) resize_flip = Compose(_resize_flip)
_results = results.copy() _results = results.copy()
_results = resize_flip(_results) _results = resize_flip(_results)
data = self.transforms(_results) input_image, data_sample = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list input_data.append(input_image)
aug_data_dict = {key: [] for key in aug_data[0]} aug_data.append(data_sample)
for data in aug_data: return input_data, aug_data
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import os.path as osp import os.path as osp
from unittest.mock import Mock
import numpy as np import numpy as np
import pytest import pytest
...@@ -8,6 +9,7 @@ import pytest ...@@ -8,6 +9,7 @@ import pytest
import mmcv import mmcv
from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip, from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
RandomResize, Resize) RandomResize, Resize)
from mmcv.transforms.base import BaseTransform
try: try:
import torch import torch
...@@ -538,6 +540,17 @@ class TestRandomGrayscale: ...@@ -538,6 +540,17 @@ class TestRandomGrayscale:
assert img.shape == (10, 10, 1) assert img.shape == (10, 10, 1)
@TRANSFORMS.register_module()
class MockFormatBundle(BaseTransform):
    """Test-only stand-in for a final formatting transform.

    Its ``transform`` returns a two-element tuple instead of a dict: the
    ``'img'`` entry of the incoming results and a fresh ``Mock`` object
    that acts as a placeholder data sample.
    """

    def __init__(self) -> None:
        super().__init__()

    def transform(self, results):
        """Return ``(results['img'], Mock())`` for the given results dict.

        Args:
            results (dict): Result dict; must contain the key ``'img'``.

        Returns:
            tuple: The image stored under ``'img'`` and a new ``Mock``
            placeholder for the data sample.
        """
        return results['img'], Mock()
class TestMultiScaleFlipAug: class TestMultiScaleFlipAug:
@classmethod @classmethod
...@@ -547,12 +560,6 @@ class TestMultiScaleFlipAug: ...@@ -547,12 +560,6 @@ class TestMultiScaleFlipAug:
cls.original_img = copy.deepcopy(cls.img) cls.original_img = copy.deepcopy(cls.img)
def test_error(self): def test_error(self):
# test assertion if img_scale is None
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=None, transforms=[])
TRANSFORMS.build(transform)
# test assertion if img_scale is not tuple or list of tuple # test assertion if img_scale is not tuple or list of tuple
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
transform = dict( transform = dict(
...@@ -574,28 +581,30 @@ class TestMultiScaleFlipAug: ...@@ -574,28 +581,30 @@ class TestMultiScaleFlipAug:
# test with empty transforms # test with empty transforms
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=[], transforms=[dict(type='MockFormatBundle')],
img_scale=[(1333, 800), (800, 600), (640, 480)], img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True, flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results) input, data_sample = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12 assert len(input) == 12
assert len(data_sample) == 12
# test with flip=False # test with flip=False
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
transforms=[], transforms=[dict(type='MockFormatBundle')],
img_scale=[(1333, 800), (800, 600), (640, 480)], img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=False, flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal']) flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results) input, data_sample = multi_scale_flip_aug_module(results)
assert len(results['img']) == 3 assert len(input) == 3
assert len(data_sample) == 3
# test with transforms # test with transforms
img_norm_cfg = dict( img_norm_cfg = dict(
...@@ -606,6 +615,7 @@ class TestMultiScaleFlipAug: ...@@ -606,6 +615,7 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32), dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']), dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
] ]
transform = dict( transform = dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
...@@ -616,8 +626,56 @@ class TestMultiScaleFlipAug: ...@@ -616,8 +626,56 @@ class TestMultiScaleFlipAug:
multi_scale_flip_aug_module = TRANSFORMS.build(transform) multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict() results = dict()
results['img'] = copy.deepcopy(self.original_img) results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results) input, data_sample = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12 assert len(input) == 12
assert len(data_sample) == 12
# test with scale_factor
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transforms_cfg = [
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
scale_factor=[0.5, 1., 2.],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
# test no resize
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transforms_cfg = [
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 4
assert len(data_sample) == 4
class TestRandomMultiscaleResize: class TestRandomMultiscaleResize:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment