Unverified Commit a6c42ad3 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Feature] Add TTA transform (#2146)



* Add TestRandomResize

* rename ut class

* minor refine

* Replace for loop with itertools.product

* Support accept built transforms

* Fix unit test

* Refine docstring

* minor refine
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Minor refine
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent bf48ca03
...@@ -4,7 +4,7 @@ from .builder import TRANSFORMS ...@@ -4,7 +4,7 @@ from .builder import TRANSFORMS
from .loading import LoadAnnotations, LoadImageFromFile from .loading import LoadAnnotations, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad, from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomChoiceResize, RandomFlip, RandomGrayscale, RandomChoiceResize, RandomFlip, RandomGrayscale,
RandomResize, Resize) RandomResize, Resize, TestTimeAug)
from .wrappers import (Compose, KeyMapper, RandomApply, RandomChoice, from .wrappers import (Compose, KeyMapper, RandomApply, RandomChoice,
TransformBroadcaster) TransformBroadcaster)
...@@ -16,7 +16,7 @@ except ImportError: ...@@ -16,7 +16,7 @@ except ImportError:
'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations', 'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations',
'Normalize', 'Resize', 'Pad', 'RandomFlip', 'RandomChoiceResize', 'Normalize', 'Resize', 'Pad', 'RandomFlip', 'RandomChoiceResize',
'CenterCrop', 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize', 'CenterCrop', 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize',
'RandomApply' 'RandomApply', 'TestTimeAug'
] ]
else: else:
from .formatting import ImageToTensor, ToTensor, to_tensor from .formatting import ImageToTensor, ToTensor, to_tensor
...@@ -26,5 +26,5 @@ else: ...@@ -26,5 +26,5 @@ else:
'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations', 'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations',
'Normalize', 'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor', 'Normalize', 'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomChoiceResize', 'CenterCrop', 'RandomGrayscale', 'RandomFlip', 'RandomChoiceResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize', 'RandomApply' 'MultiScaleFlipAug', 'RandomResize', 'RandomApply', 'TestTimeAug'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy
import random import random
import warnings import warnings
from itertools import product
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mmengine import mmengine
...@@ -746,7 +748,7 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -746,7 +748,7 @@ class MultiScaleFlipAug(BaseTransform):
- resize to (1333, 800) + flip - resize to (1333, 800) + flip
The four results are then transformed with ``transforms`` argument. The four results are then transformed with ``transforms`` argument.
After that, results are wrapped into lists of the same length as followed: After that, results are wrapped into lists of the same length as below:
.. code-block:: .. code-block::
...@@ -870,6 +872,130 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -870,6 +872,130 @@ class MultiScaleFlipAug(BaseTransform):
return repr_str return repr_str
@TRANSFORMS.register_module()
class TestTimeAug(BaseTransform):
    """Test-time augmentation transform.

    An example configuration is as follows:

    .. code-block::

        dict(type='TestTimeAug',
             transforms=[
                 [dict(type='Resize', scale=(1333, 400), keep_ratio=True),
                  dict(type='Resize', scale=(1333, 800), keep_ratio=True)],
                 [dict(type='RandomFlip', prob=1.),
                  dict(type='RandomFlip', prob=0.)],
                 [dict(type='PackDetInputs',
                       meta_keys=('img_id', 'img_path', 'ori_shape',
                                  'img_shape', 'scale_factor', 'flip',
                                  'flip_direction'))]])

    ``results`` will be transformed using all transforms defined in
    ``transforms`` arguments.

    For the above configuration, there are four combinations of resize
    and flip:

    - Resize to (1333, 400) + no flip
    - Resize to (1333, 400) + flip
    - Resize to (1333, 800) + no flip
    - Resize to (1333, 800) + flip

    After that, results are wrapped into lists of the same length as below:

    .. code-block::

        dict(
            inputs=[...],
            data_samples=[...]
        )

    The length of ``inputs`` and ``data_samples`` are both 4.

    Required Keys:

    - Depending on the requirements of the ``transforms`` parameter.

    Modified Keys:

    - All output keys of each transform.

    Args:
        transforms (list[list[dict]]): Transforms to be applied to data sampled
            from dataset. ``transforms`` is a list of list, and each list
            element usually represents a series of transforms with the same
            type and different arguments. Data will be processed by each list
            elements sequentially. Each entry may be a config dict or an
            already built callable. Every inner list should be non-empty,
            because an empty inner list yields no transform combination at
            all. See more information in :meth:`transform`.

    Raises:
        TypeError: If an entry of ``transforms`` is neither a dict nor a
            callable.
    """

    def __init__(self, transforms: list):
        # Build into fresh lists instead of assigning back into
        # ``transforms``: the argument usually comes from a config object
        # owned by the caller, and overwriting its dicts with built
        # transforms would silently change that config.
        built_transforms = []
        for transform_list in transforms:
            built_list = []
            for transform in transform_list:
                if isinstance(transform, dict):
                    built_list.append(TRANSFORMS.build(transform))
                elif callable(transform):
                    built_list.append(transform)
                else:
                    raise TypeError(
                        'transform must be callable or a dict, but got'
                        f' {type(transform)}')
            built_transforms.append(built_list)

        # One pipeline per element of the Cartesian product, i.e. one
        # transform picked from every inner list, kept in list order.
        self.subroutines = [
            Compose(subroutine) for subroutine in product(*built_transforms)
        ]

    def transform(self, results: dict) -> dict:
        """Apply all transforms defined in :attr:`transforms` to the results.

        As the example given in :obj:`TestTimeAug`, ``transforms`` consists of
        2 ``Resize``, 2 ``RandomFlip`` and 1 ``PackDetInputs``.
        The data sampled from dataset will be processed as follows:

        1. Data will be processed by 2 ``Resize`` and return a list
           of 2 results.
        2. Each result in list will be further passed to 2
           ``RandomFlip``, and aggregates into a list of 4 results.
        3. Each result will be processed by ``PackDetInputs``, and
           return a list of dict.
        4. Aggregates the same fields of results, and finally returns
           a dict. Each value of the dict represents 4 transformed
           results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: The augmented data, where each value is wrapped
            into a list.
        """
        results_list = []  # type: ignore
        for subroutine in self.subroutines:
            # Every pipeline gets its own deep copy so that one combination
            # cannot leak modifications into the next one.
            result = subroutine(copy.deepcopy(results))
            # The ``None`` check has to run before the type check; in the
            # opposite order ``isinstance`` already rejects ``None`` and the
            # more helpful message below could never be shown.
            assert result is not None, (
                f'Data processed by {subroutine} in `TestTimeAug` should not '
                'be None! Please check your validation dataset and the '
                f'transforms in {subroutine}')
            assert isinstance(result, dict), (
                f'Data processed by {subroutine} must return a dict, but got '
                f'{result}')
            results_list.append(result)

        # The keys of the first result define the output fields; each field
        # collects the corresponding value from every pipeline, in order.
        aug_data_dict = {
            key: [item[key] for item in results_list]  # type: ignore
            for key in results_list[0]  # type: ignore
        }
        return aug_data_dict

    def __repr__(self) -> str:
        """Return the class name followed by the repr of each pipeline."""
        # The exact text (including the missing separator after the class
        # name) is kept unchanged because existing callers compare against it.
        repr_str = self.__class__.__name__
        repr_str += 'transforms=\n'
        repr_str += ''.join(
            f'{repr(subroutine)}\n' for subroutine in self.subroutines)
        return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomChoiceResize(BaseTransform): class RandomChoiceResize(BaseTransform):
"""Resize images & bbox & mask from a list of multiple scales. """Resize images & bbox & mask from a list of multiple scales.
......
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
import mmcv import mmcv
from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip, from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
RandomResize, Resize) RandomResize, Resize, TestTimeAug)
from mmcv.transforms.base import BaseTransform from mmcv.transforms.base import BaseTransform
try: try:
...@@ -926,3 +926,89 @@ class TestRandomResize: ...@@ -926,3 +926,89 @@ class TestRandomResize:
resize_type='Resize', resize_type='Resize',
keep_ratio=True) keep_ratio=True)
results_update = TRANSFORMS.transform(copy.deepcopy(results)) results_update = TRANSFORMS.transform(copy.deepcopy(results))
class TestTestTimeAug:
    """Unit tests for the ``TestTimeAug`` transform."""

    def test_init(self):
        """Config dicts are built into one pipeline per combination."""
        resize_cfgs = [
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
            dict(type='Resize', scale=(1333, 400), keep_ratio=True)
        ]
        flip_cfgs = [
            dict(type='RandomFlip', prob=1.),
            dict(type='RandomFlip', prob=0.)
        ]
        normalize_cfgs = [
            dict(type='Normalize', mean=(0, 0, 0), std=(1, 1, 1))
        ]
        tta_transform = TestTimeAug([resize_cfgs, flip_cfgs, normalize_cfgs])

        pipelines = tta_transform.subroutines
        assert len(pipelines) == 4

        # The first two pipelines must each hold the three stages in order.
        expected_types = (Resize, RandomFlip, Normalize)
        for pipeline in pipelines[:2]:
            for position, expected_type in enumerate(expected_types):
                assert isinstance(pipeline.transforms[position],
                                  expected_type)

    def test_transform(self):
        """Each output entry equals the manually chained transforms."""
        sample = {
            'img': np.random.random((224, 224, 3)),
            'gt_bboxes': np.array([[0, 1, 100, 101]]),
            'gt_keypoints': np.array([[[100, 100, 1.0]]]),
            'gt_seg_map': np.random.random((224, 224, 3))
        }
        reference_input = copy.deepcopy(sample)

        resize_cfgs = [
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
            dict(type='Resize', scale=(1333, 400), keep_ratio=True)
        ]
        flip_cfgs = [
            dict(type='RandomFlip', prob=0.),
            dict(type='RandomFlip', prob=1.)
        ]
        normalize_cfgs = [
            dict(type='Normalize', mean=(0, 0, 0), std=(1, 1, 1))
        ]
        tta_transform = TestTimeAug([resize_cfgs, flip_cfgs, normalize_cfgs])

        augmented = tta_transform.transform(sample)
        assert len(augmented['img']) == 4

        # Pick the individual built transforms back out of the pipelines.
        pipelines = tta_transform.subroutines
        resize1 = pipelines[0].transforms[0]
        resize2 = pipelines[2].transforms[0]
        flip1 = pipelines[0].transforms[1]
        flip2 = pipelines[1].transforms[1]
        normalize = pipelines[0].transforms[2]

        # Same order as the pipelines: the resize varies slowest.
        combinations = [
            (resize1, flip1),
            (resize1, flip2),
            (resize2, flip1),
            (resize2, flip2),
        ]
        target_results = []
        for resize, flip in combinations:
            data = resize.transform(copy.deepcopy(reference_input))
            data = flip.transform(data)
            target_results.append(normalize.transform(data))

        for index, target in enumerate(target_results):
            assert np.allclose(target['img'], augmented['img'][index])

    def test_repr(self):
        """The repr starts with the header and the first pipeline."""
        resize_cfgs = [
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
            dict(type='Resize', scale=(1333, 400), keep_ratio=True)
        ]
        flip_cfgs = [
            dict(type='RandomFlip', prob=0.),
            dict(type='RandomFlip', prob=1.)
        ]
        normalize_cfgs = [
            dict(type='Normalize', mean=(0, 0, 0), std=(1, 1, 1))
        ]
        tta_transform = TestTimeAug([resize_cfgs, flip_cfgs, normalize_cfgs])

        repr_lines = repr(tta_transform).split('\n')
        assert repr_lines[0] == 'TestTimeAugtransforms='
        assert repr_lines[1] == 'Compose('
        # NOTE(review): the leading whitespace inside the expected prefixes
        # below may have been collapsed in the reviewed copy of this file;
        # confirm it against the indentation emitted by ``Compose.__repr__``.
        assert repr_lines[2].startswith(' Resize(scale=(1333, 800)')
        assert repr_lines[3].startswith(' RandomFlip(prob=0.0')
        assert repr_lines[4].startswith(' Normalize(mean=[0. 0. 0.]')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment