Commit 2844a116 authored by Yifei Yang's avatar Yifei Yang Committed by zhouzaida
Browse files

[Fix] Fix MultiScaleFlipAug (#1801)

* Fix MultiScaleFlipAug

* fix as comment
parent 169f098d
......@@ -762,6 +762,8 @@ class MultiScaleFlipAug(BaseTransform):
transforms (list[dict]): Transforms to be applied to each resized
and flipped data.
img_scale (tuple | list[tuple] | None): Images scales for resizing.
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
flip (bool): Whether apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If
......@@ -778,6 +780,7 @@ class MultiScaleFlipAug(BaseTransform):
self,
transforms: List[dict],
img_scale: Optional[Union[Tuple, List[Tuple]]] = None,
scale_factor: Optional[Union[float, List[float]]] = None,
flip: bool = False,
flip_direction: Union[str, List[str]] = 'horizontal',
resize_cfg: dict = dict(type='Resize', keep_ratio=True),
......@@ -785,11 +788,20 @@ class MultiScaleFlipAug(BaseTransform):
) -> None:
super().__init__()
self.transforms = Compose(transforms) # type: ignore
assert img_scale is not None
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple)
if img_scale is not None:
self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale]
self.scale_key = 'scale'
assert mmcv.is_list_of(self.img_scale, tuple)
else:
# if ``img_scale`` and ``scale_factor`` both be ``None``
if scale_factor is None:
self.img_scale = [1.]
else:
self.img_scale = scale_factor if isinstance(
scale_factor, list) else [scale_factor]
self.scale_key = 'scale_factor'
self.flip = flip
self.flip_direction = flip_direction if isinstance(
......@@ -801,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform):
self.resize_cfg = resize_cfg
self.flip_cfg = flip_cfg
def transform(self, results: dict) -> dict:
def transform(self, results: dict) -> Tuple[List, List]:
"""Apply test time augment transforms on results.
Args:
......@@ -813,6 +825,7 @@ class MultiScaleFlipAug(BaseTransform):
"""
aug_data = []
input_data = []
flip_args = [(False, '')]
if self.flip:
flip_args += [(True, direction)
......@@ -820,7 +833,7 @@ class MultiScaleFlipAug(BaseTransform):
for scale in self.img_scale:
for flip, direction in flip_args:
_resize_cfg = self.resize_cfg.copy()
_resize_cfg.update(scale=scale)
_resize_cfg.update({self.scale_key: scale})
_resize_flip = [_resize_cfg]
if flip:
......@@ -834,14 +847,11 @@ class MultiScaleFlipAug(BaseTransform):
resize_flip = Compose(_resize_flip)
_results = results.copy()
_results = resize_flip(_results)
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
for key, val in data.items():
aug_data_dict[key].append(val)
return aug_data_dict
input_image, data_sample = self.transforms(_results)
input_data.append(input_image)
aug_data.append(data_sample)
return input_data, aug_data
def __repr__(self) -> str:
repr_str = self.__class__.__name__
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from unittest.mock import Mock
import numpy as np
import pytest
......@@ -8,6 +9,7 @@ import pytest
import mmcv
from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
RandomResize, Resize)
from mmcv.transforms.base import BaseTransform
try:
import torch
......@@ -538,6 +540,17 @@ class TestRandomGrayscale:
assert img.shape == (10, 10, 1)
@TRANSFORMS.register_module()
class MockFormatBundle(BaseTransform):
def __init__(self) -> None:
super().__init__()
def transform(self, results):
data_sample = Mock()
return results['img'], data_sample
class TestMultiScaleFlipAug:
@classmethod
......@@ -547,12 +560,6 @@ class TestMultiScaleFlipAug:
cls.original_img = copy.deepcopy(cls.img)
def test_error(self):
# test assertion if img_scale is None
with pytest.raises(AssertionError):
transform = dict(
type='MultiScaleFlipAug', img_scale=None, transforms=[])
TRANSFORMS.build(transform)
# test assertion if img_scale is not tuple or list of tuple
with pytest.raises(AssertionError):
transform = dict(
......@@ -574,28 +581,30 @@ class TestMultiScaleFlipAug:
# test with empty transforms
transform = dict(
type='MultiScaleFlipAug',
transforms=[],
transforms=[dict(type='MockFormatBundle')],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
# test with flip=False
transform = dict(
type='MultiScaleFlipAug',
transforms=[],
transforms=[dict(type='MockFormatBundle')],
img_scale=[(1333, 800), (800, 600), (640, 480)],
flip=False,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 3
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 3
assert len(data_sample) == 3
# test with transforms
img_norm_cfg = dict(
......@@ -606,6 +615,7 @@ class TestMultiScaleFlipAug:
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
]
transform = dict(
type='MultiScaleFlipAug',
......@@ -616,8 +626,56 @@ class TestMultiScaleFlipAug:
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
results = multi_scale_flip_aug_module(results)
assert len(results['img']) == 12
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
# test with scale_factor
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transforms_cfg = [
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
scale_factor=[0.5, 1., 2.],
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 12
assert len(data_sample) == 12
# test no resize
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True)
transforms_cfg = [
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='MockFormatBundle')
]
transform = dict(
type='MultiScaleFlipAug',
transforms=transforms_cfg,
flip=True,
flip_direction=['horizontal', 'vertical', 'diagonal'])
multi_scale_flip_aug_module = TRANSFORMS.build(transform)
results = dict()
results['img'] = copy.deepcopy(self.original_img)
input, data_sample = multi_scale_flip_aug_module(results)
assert len(input) == 4
assert len(data_sample) == 4
class TestRandomMultiscaleResize:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment