Commit ff0dfb74 authored by Yining Li's avatar Yining Li Committed by zhouzaida
Browse files

add RandomApply (#1863)

parent 46cb4b10
......@@ -172,7 +172,7 @@ pipeline = [
利用字段映射包装,我们在实现数据变换类时,不需要考虑在 `transform` 方法中考虑各种可能的输入字段名,只需要处理默认的字段即可。
### 随机选择(RandomChoice)
### 随机选择(RandomChoice)和随机执行(RandomApply)
随机选择包装(`RandomChoice`)用于从一系列数据变换组合中随机应用一个数据变换组合。利用这一包装,我们可以简单地实现一些数据增强功能,比如 AutoAugment。
......@@ -198,6 +198,18 @@ pipeline = [
]
```
随机执行包装(`RandomApply`)用于以指定概率随机执行数据变换组合。例如:
```python
pipeline = [
...
dict(type='RandomApply',
transforms=[dict(type='Rotate', angle=30.)],
prob=0.3) # 以 0.3 的概率执行被包装的数据变换
...
]
```
### 多目标扩展(TransformBroadcaster)
通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用 `KeyMapper` 来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(`TransformBroadcaster`)。
......
......@@ -5,7 +5,8 @@ from .loading import LoadAnnotations, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomChoiceResize, RandomFlip, RandomGrayscale,
RandomResize, Resize)
from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster
from .wrappers import (Compose, KeyMapper, RandomApply, RandomChoice,
TransformBroadcaster)
try:
import torch # noqa: F401
......@@ -14,7 +15,8 @@ except ImportError:
'BaseTransform', 'TRANSFORMS', 'TransformBroadcaster', 'Compose',
'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations',
'Normalize', 'Resize', 'Pad', 'RandomFlip', 'RandomChoiceResize',
'CenterCrop', 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
'CenterCrop', 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize',
'RandomApply'
]
else:
from .formatting import ImageToTensor, ToTensor, to_tensor
......@@ -24,5 +26,5 @@ else:
'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations',
'Normalize', 'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomChoiceResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize'
'MultiScaleFlipAug', 'RandomResize', 'RandomApply'
]
......@@ -50,6 +50,8 @@ class Compose(BaseTransform):
"""
def __init__(self, transforms: Union[Transform, List[Transform]]):
super().__init__()
if not isinstance(transforms, list):
transforms = [transforms]
self.transforms = []
......@@ -123,7 +125,7 @@ class KeyMapper(BaseTransform):
>>> # 'gt_img' to inner (used by inner transforms) filed name
>>> # 'img'
>>> dict(type='KeyMapper',
>>> mapping=dict(img='gt_img'),
>>> mapping={'img': 'gt_img'},
>>> # auto_remap=True means output key mapping is the revert of
>>> # the input key mapping, e.g. inner 'img' will be mapped
>>> # back to outer 'gt_img'
......@@ -158,6 +160,8 @@ class KeyMapper(BaseTransform):
auto_remap: Optional[bool] = None,
allow_nonexist_keys: bool = False):
super().__init__()
self.allow_nonexist_keys = allow_nonexist_keys
self.mapping = mapping
......@@ -318,7 +322,7 @@ class TransformBroadcaster(KeyMapper):
>>> # respectively
>>> dict(type='TransformBroadcaster',
>>> # case 1: from multiple outer fields
>>> mapping=dict(img=['lq', 'gt']),
>>> mapping={'img': ['lq', 'gt']},
>>> auto_remap=True,
>>> # share_random_param=True means using identical random
>>> # parameters in every processing
......@@ -338,7 +342,7 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='TransformBroadcaster',
>>> # case 2: from one outer field that contains multiple
>>> # data elements (e.g. a list)
>>> # mapping=dict(img='images'),
>>> # mapping={'img': 'images'},
>>> auto_remap=True,
>>> share_random_param=True,
>>> transforms=[
......@@ -420,10 +424,10 @@ class TransformBroadcaster(KeyMapper):
@TRANSFORMS.register_module()
class RandomChoice(BaseTransform):
"""Process data with a randomly chosen pipeline from given candidates.
"""Process data with a randomly chosen transform from given candidates.
Args:
transforms (list[list]): A list of pipeline candidates, each is a
transforms (list[list]): A list of transform candidates, each is a
sequence of transforms.
prob (list[float], optional): The probabilities associated
with each pipeline. The length should be equal to the pipeline
......@@ -446,6 +450,8 @@ class RandomChoice(BaseTransform):
transforms: List[Union[Transform, List[Transform]]],
prob: Optional[List[float]] = None):
super().__init__()
if prob is not None:
assert mmcv.is_seq_of(prob, float)
assert len(transforms) == len(prob), \
......@@ -467,3 +473,42 @@ class RandomChoice(BaseTransform):
def transform(self, results):
idx = self.random_pipeline_index()
return self.transforms[idx](results)
@TRANSFORMS.register_module()
class RandomApply(BaseTransform):
"""Apply transforms randomly with a given probability.
Args:
transforms (list[dict | callable]): The transform or transform list
to randomly apply.
prob (float): The probability to apply transforms. Default: 0.5
Examples:
>>> # config
>>> pipeline = [
>>> dict(type='RandomApply',
>>> transforms=[dict(type='HorizontalFlip')],
>>> prob=0.3)
>>> ]
"""
def __init__(self,
transforms: Union[Transform, List[Transform]],
prob: float = 0.5):
super().__init__()
self.prob = prob
self.transforms = Compose(transforms)
def __iter__(self):
return iter(self.transforms)
@cache_randomness
def random_apply(self):
return np.random.rand() < self.prob
def transform(self, results: Dict) -> Dict:
if self.random_apply():
results = self.transforms(results)
return results
......@@ -7,8 +7,8 @@ import pytest
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomChoice,
TransformBroadcaster)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
RandomChoice, TransformBroadcaster)
@TRANSFORMS.register_module()
......@@ -136,13 +136,13 @@ def test_cache_random_parameters():
np.testing.assert_equal(results_1['value'], results_2['value'])
def test_apply_to_mapped():
def test_key_mapper():
# Case 1: simple remap
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
mapping=dict(value='v_in'),
remapping=dict(value='v_out'))
mapping={'value': 'v_in'},
remapping={'value': 'v_out'})
results = dict(value=0, v_in=1)
results = pipeline(results)
......@@ -154,8 +154,8 @@ def test_apply_to_mapped():
# Case 2: collecting list
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
mapping=dict(value=['v_in_1', 'v_in_2']),
remapping=dict(value=['v_out_1', 'v_out_2']))
mapping={'value': ['v_in_1', 'v_in_2']},
remapping={'value': ['v_out_1', 'v_out_2']})
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
......@@ -170,8 +170,14 @@ def test_apply_to_mapped():
# Case 3: collecting dict
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
remapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
mapping={'value': {
'v1': 'v_in_1',
'v2': 'v_in_2'
}},
remapping={'value': {
'v1': 'v_out_1',
'v2': 'v_out_2'
}})
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
......@@ -332,7 +338,7 @@ def test_apply_to_multiple():
_ = str(pipeline)
def test_randomchoice():
def test_random_choice():
# Case 1: given probability
pipeline = RandomChoice(
......@@ -355,7 +361,32 @@ def test_randomchoice():
transforms=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ),
],
mapping=dict(value='values'),
mapping={'value': 'values'},
auto_remap=True,
share_random_params=True)
results = dict(values=[0 for _ in range(10)])
results = pipeline(results)
# check share_random_params=True works so that all values are same
values = results['values']
assert all(map(lambda x: x == values[0], values))
def test_random_apply():
# Case 1: simple use
pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0)
pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 1.0)
# Case 2: nested RandomApply in TransformBroadcaster
pipeline = TransformBroadcaster(
transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
mapping={'value': 'values'},
auto_remap=True,
share_random_params=True)
......@@ -365,6 +396,10 @@ def test_randomchoice():
values = results['values']
assert all(map(lambda x: x == values[0], values))
# __iter__
for _ in pipeline:
pass
def test_utils():
# Test cache_randomness: normal case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment