Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
ff0dfb74
Commit
ff0dfb74
authored
Apr 20, 2022
by
Yining Li
Committed by
zhouzaida
Jul 19, 2022
Browse files
add RandomApply (#1863)
parent
46cb4b10
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
114 additions
and
20 deletions
+114
-20
docs/zh_cn/understand_mmcv/data_transform.md
docs/zh_cn/understand_mmcv/data_transform.md
+13
-1
mmcv/transforms/__init__.py
mmcv/transforms/__init__.py
+5
-3
mmcv/transforms/wrappers.py
mmcv/transforms/wrappers.py
+50
-5
tests/test_transforms/test_transforms_wrapper.py
tests/test_transforms/test_transforms_wrapper.py
+46
-11
No files found.
docs/zh_cn/understand_mmcv/data_transform.md
View file @
ff0dfb74
...
@@ -172,7 +172,7 @@ pipeline = [
...
@@ -172,7 +172,7 @@ pipeline = [
利用字段映射包装,我们在实现数据变换类时,不需要考虑在
`transform`
方法中考虑各种可能的输入字段名,只需要处理默认的字段即可。
利用字段映射包装,我们在实现数据变换类时,不需要考虑在
`transform`
方法中考虑各种可能的输入字段名,只需要处理默认的字段即可。
### 随机选择(RandomChoice)
### 随机选择(RandomChoice)
和随机执行(RandomApply)
随机选择包装(
`RandomChoice`
)用于从一系列数据变换组合中随机应用一个数据变换组合。利用这一包装,我们可以简单地实现一些数据增强功能,比如 AutoAugment。
随机选择包装(
`RandomChoice`
)用于从一系列数据变换组合中随机应用一个数据变换组合。利用这一包装,我们可以简单地实现一些数据增强功能,比如 AutoAugment。
...
@@ -198,6 +198,18 @@ pipeline = [
...
@@ -198,6 +198,18 @@ pipeline = [
]
]
```
```
随机执行包装(
`RandomApply`
)用于以指定概率随机执行数据变换组合。例如:
```
python
pipeline
=
[
...
dict
(
type
=
'RandomApply'
,
transforms
=
[
dict
(
type
=
'Rotate'
,
angle
=
30.
)],
prob
=
0.3
)
# 以 0.3 的概率执行被包装的数据变换
...
]
```
### 多目标扩展(TransformBroadcaster)
### 多目标扩展(TransformBroadcaster)
通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用
`KeyMapper`
来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(
`TransformBroadcaster`
)。
通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用
`KeyMapper`
来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(
`TransformBroadcaster`
)。
...
...
mmcv/transforms/__init__.py
View file @
ff0dfb74
...
@@ -5,7 +5,8 @@ from .loading import LoadAnnotations, LoadImageFromFile
...
@@ -5,7 +5,8 @@ from .loading import LoadAnnotations, LoadImageFromFile
from
.processing
import
(
CenterCrop
,
MultiScaleFlipAug
,
Normalize
,
Pad
,
from
.processing
import
(
CenterCrop
,
MultiScaleFlipAug
,
Normalize
,
Pad
,
RandomChoiceResize
,
RandomFlip
,
RandomGrayscale
,
RandomChoiceResize
,
RandomFlip
,
RandomGrayscale
,
RandomResize
,
Resize
)
RandomResize
,
Resize
)
from
.wrappers
import
Compose
,
KeyMapper
,
RandomChoice
,
TransformBroadcaster
from
.wrappers
import
(
Compose
,
KeyMapper
,
RandomApply
,
RandomChoice
,
TransformBroadcaster
)
try
:
try
:
import
torch
# noqa: F401
import
torch
# noqa: F401
...
@@ -14,7 +15,8 @@ except ImportError:
...
@@ -14,7 +15,8 @@ except ImportError:
'BaseTransform'
,
'TRANSFORMS'
,
'TransformBroadcaster'
,
'Compose'
,
'BaseTransform'
,
'TRANSFORMS'
,
'TransformBroadcaster'
,
'Compose'
,
'RandomChoice'
,
'KeyMapper'
,
'LoadImageFromFile'
,
'LoadAnnotations'
,
'RandomChoice'
,
'KeyMapper'
,
'LoadImageFromFile'
,
'LoadAnnotations'
,
'Normalize'
,
'Resize'
,
'Pad'
,
'RandomFlip'
,
'RandomChoiceResize'
,
'Normalize'
,
'Resize'
,
'Pad'
,
'RandomFlip'
,
'RandomChoiceResize'
,
'CenterCrop'
,
'RandomGrayscale'
,
'MultiScaleFlipAug'
,
'RandomResize'
'CenterCrop'
,
'RandomGrayscale'
,
'MultiScaleFlipAug'
,
'RandomResize'
,
'RandomApply'
]
]
else
:
else
:
from
.formatting
import
ImageToTensor
,
ToTensor
,
to_tensor
from
.formatting
import
ImageToTensor
,
ToTensor
,
to_tensor
...
@@ -24,5 +26,5 @@ else:
...
@@ -24,5 +26,5 @@ else:
'RandomChoice'
,
'KeyMapper'
,
'LoadImageFromFile'
,
'LoadAnnotations'
,
'RandomChoice'
,
'KeyMapper'
,
'LoadImageFromFile'
,
'LoadAnnotations'
,
'Normalize'
,
'Resize'
,
'Pad'
,
'ToTensor'
,
'to_tensor'
,
'ImageToTensor'
,
'Normalize'
,
'Resize'
,
'Pad'
,
'ToTensor'
,
'to_tensor'
,
'ImageToTensor'
,
'RandomFlip'
,
'RandomChoiceResize'
,
'CenterCrop'
,
'RandomGrayscale'
,
'RandomFlip'
,
'RandomChoiceResize'
,
'CenterCrop'
,
'RandomGrayscale'
,
'MultiScaleFlipAug'
,
'RandomResize'
'MultiScaleFlipAug'
,
'RandomResize'
,
'RandomApply'
]
]
mmcv/transforms/wrappers.py
View file @
ff0dfb74
...
@@ -50,6 +50,8 @@ class Compose(BaseTransform):
...
@@ -50,6 +50,8 @@ class Compose(BaseTransform):
"""
"""
def
__init__
(
self
,
transforms
:
Union
[
Transform
,
List
[
Transform
]]):
def
__init__
(
self
,
transforms
:
Union
[
Transform
,
List
[
Transform
]]):
super
().
__init__
()
if
not
isinstance
(
transforms
,
list
):
if
not
isinstance
(
transforms
,
list
):
transforms
=
[
transforms
]
transforms
=
[
transforms
]
self
.
transforms
=
[]
self
.
transforms
=
[]
...
@@ -123,7 +125,7 @@ class KeyMapper(BaseTransform):
...
@@ -123,7 +125,7 @@ class KeyMapper(BaseTransform):
>>> # 'gt_img' to inner (used by inner transforms) filed name
>>> # 'gt_img' to inner (used by inner transforms) filed name
>>> # 'img'
>>> # 'img'
>>> dict(type='KeyMapper',
>>> dict(type='KeyMapper',
>>> mapping=
dict(img=
'gt_img'
)
,
>>> mapping=
{'img':
'gt_img'
}
,
>>> # auto_remap=True means output key mapping is the revert of
>>> # auto_remap=True means output key mapping is the revert of
>>> # the input key mapping, e.g. inner 'img' will be mapped
>>> # the input key mapping, e.g. inner 'img' will be mapped
>>> # back to outer 'gt_img'
>>> # back to outer 'gt_img'
...
@@ -158,6 +160,8 @@ class KeyMapper(BaseTransform):
...
@@ -158,6 +160,8 @@ class KeyMapper(BaseTransform):
auto_remap
:
Optional
[
bool
]
=
None
,
auto_remap
:
Optional
[
bool
]
=
None
,
allow_nonexist_keys
:
bool
=
False
):
allow_nonexist_keys
:
bool
=
False
):
super
().
__init__
()
self
.
allow_nonexist_keys
=
allow_nonexist_keys
self
.
allow_nonexist_keys
=
allow_nonexist_keys
self
.
mapping
=
mapping
self
.
mapping
=
mapping
...
@@ -318,7 +322,7 @@ class TransformBroadcaster(KeyMapper):
...
@@ -318,7 +322,7 @@ class TransformBroadcaster(KeyMapper):
>>> # respectively
>>> # respectively
>>> dict(type='TransformBroadcaster',
>>> dict(type='TransformBroadcaster',
>>> # case 1: from multiple outer fields
>>> # case 1: from multiple outer fields
>>> mapping=
dict(img=
['lq', 'gt']
)
,
>>> mapping=
{'img':
['lq', 'gt']
}
,
>>> auto_remap=True,
>>> auto_remap=True,
>>> # share_random_param=True means using identical random
>>> # share_random_param=True means using identical random
>>> # parameters in every processing
>>> # parameters in every processing
...
@@ -338,7 +342,7 @@ class TransformBroadcaster(KeyMapper):
...
@@ -338,7 +342,7 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='TransformBroadcaster',
>>> dict(type='TransformBroadcaster',
>>> # case 2: from one outer field that contains multiple
>>> # case 2: from one outer field that contains multiple
>>> # data elements (e.g. a list)
>>> # data elements (e.g. a list)
>>> # mapping=
dict(img=
'images'
)
,
>>> # mapping=
{'img':
'images'
}
,
>>> auto_remap=True,
>>> auto_remap=True,
>>> share_random_param=True,
>>> share_random_param=True,
>>> transforms=[
>>> transforms=[
...
@@ -420,10 +424,10 @@ class TransformBroadcaster(KeyMapper):
...
@@ -420,10 +424,10 @@ class TransformBroadcaster(KeyMapper):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
RandomChoice
(
BaseTransform
):
class
RandomChoice
(
BaseTransform
):
"""Process data with a randomly chosen
pipeline
from given candidates.
"""Process data with a randomly chosen
transform
from given candidates.
Args:
Args:
transforms (list[list]): A list of
pipeline
candidates, each is a
transforms (list[list]): A list of
transform
candidates, each is a
sequence of transforms.
sequence of transforms.
prob (list[float], optional): The probabilities associated
prob (list[float], optional): The probabilities associated
with each pipeline. The length should be equal to the pipeline
with each pipeline. The length should be equal to the pipeline
...
@@ -446,6 +450,8 @@ class RandomChoice(BaseTransform):
...
@@ -446,6 +450,8 @@ class RandomChoice(BaseTransform):
transforms
:
List
[
Union
[
Transform
,
List
[
Transform
]]],
transforms
:
List
[
Union
[
Transform
,
List
[
Transform
]]],
prob
:
Optional
[
List
[
float
]]
=
None
):
prob
:
Optional
[
List
[
float
]]
=
None
):
super
().
__init__
()
if
prob
is
not
None
:
if
prob
is
not
None
:
assert
mmcv
.
is_seq_of
(
prob
,
float
)
assert
mmcv
.
is_seq_of
(
prob
,
float
)
assert
len
(
transforms
)
==
len
(
prob
),
\
assert
len
(
transforms
)
==
len
(
prob
),
\
...
@@ -467,3 +473,42 @@ class RandomChoice(BaseTransform):
...
@@ -467,3 +473,42 @@ class RandomChoice(BaseTransform):
def
transform
(
self
,
results
):
def
transform
(
self
,
results
):
idx
=
self
.
random_pipeline_index
()
idx
=
self
.
random_pipeline_index
()
return
self
.
transforms
[
idx
](
results
)
return
self
.
transforms
[
idx
](
results
)
@
TRANSFORMS
.
register_module
()
class
RandomApply
(
BaseTransform
):
"""Apply transforms randomly with a given probability.
Args:
transforms (list[dict | callable]): The transform or transform list
to randomly apply.
prob (float): The probability to apply transforms. Default: 0.5
Examples:
>>> # config
>>> pipeline = [
>>> dict(type='RandomApply',
>>> transforms=[dict(type='HorizontalFlip')],
>>> prob=0.3)
>>> ]
"""
def
__init__
(
self
,
transforms
:
Union
[
Transform
,
List
[
Transform
]],
prob
:
float
=
0.5
):
super
().
__init__
()
self
.
prob
=
prob
self
.
transforms
=
Compose
(
transforms
)
def
__iter__
(
self
):
return
iter
(
self
.
transforms
)
@
cache_randomness
def
random_apply
(
self
):
return
np
.
random
.
rand
()
<
self
.
prob
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
if
self
.
random_apply
():
results
=
self
.
transforms
(
results
)
return
results
tests/test_transforms/test_transforms_wrapper.py
View file @
ff0dfb74
...
@@ -7,8 +7,8 @@ import pytest
...
@@ -7,8 +7,8 @@ import pytest
from
mmcv.transforms.base
import
BaseTransform
from
mmcv.transforms.base
import
BaseTransform
from
mmcv.transforms.builder
import
TRANSFORMS
from
mmcv.transforms.builder
import
TRANSFORMS
from
mmcv.transforms.utils
import
cache_random_params
,
cache_randomness
from
mmcv.transforms.utils
import
cache_random_params
,
cache_randomness
from
mmcv.transforms.wrappers
import
(
Compose
,
KeyMapper
,
Random
Choice
,
from
mmcv.transforms.wrappers
import
(
Compose
,
KeyMapper
,
Random
Apply
,
TransformBroadcaster
)
RandomChoice
,
TransformBroadcaster
)
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
...
@@ -136,13 +136,13 @@ def test_cache_random_parameters():
...
@@ -136,13 +136,13 @@ def test_cache_random_parameters():
np
.
testing
.
assert_equal
(
results_1
[
'value'
],
results_2
[
'value'
])
np
.
testing
.
assert_equal
(
results_1
[
'value'
],
results_2
[
'value'
])
def
test_
apply_to
_mappe
d
():
def
test_
key
_mappe
r
():
# Case 1: simple remap
# Case 1: simple remap
pipeline
=
KeyMapper
(
pipeline
=
KeyMapper
(
transforms
=
[
AddToValue
(
addend
=
1
)],
transforms
=
[
AddToValue
(
addend
=
1
)],
mapping
=
dict
(
value
=
'v_in'
)
,
mapping
=
{
'
value
'
:
'v_in'
}
,
remapping
=
dict
(
value
=
'v_out'
)
)
remapping
=
{
'
value
'
:
'v_out'
}
)
results
=
dict
(
value
=
0
,
v_in
=
1
)
results
=
dict
(
value
=
0
,
v_in
=
1
)
results
=
pipeline
(
results
)
results
=
pipeline
(
results
)
...
@@ -154,8 +154,8 @@ def test_apply_to_mapped():
...
@@ -154,8 +154,8 @@ def test_apply_to_mapped():
# Case 2: collecting list
# Case 2: collecting list
pipeline
=
KeyMapper
(
pipeline
=
KeyMapper
(
transforms
=
[
AddToValue
(
addend
=
2
)],
transforms
=
[
AddToValue
(
addend
=
2
)],
mapping
=
dict
(
value
=
[
'v_in_1'
,
'v_in_2'
]
)
,
mapping
=
{
'
value
'
:
[
'v_in_1'
,
'v_in_2'
]
}
,
remapping
=
dict
(
value
=
[
'v_out_1'
,
'v_out_2'
]
)
)
remapping
=
{
'
value
'
:
[
'v_out_1'
,
'v_out_2'
]
}
)
results
=
dict
(
value
=
0
,
v_in_1
=
1
,
v_in_2
=
2
)
results
=
dict
(
value
=
0
,
v_in_1
=
1
,
v_in_2
=
2
)
with
pytest
.
warns
(
UserWarning
,
match
=
'value is a list'
):
with
pytest
.
warns
(
UserWarning
,
match
=
'value is a list'
):
...
@@ -170,8 +170,14 @@ def test_apply_to_mapped():
...
@@ -170,8 +170,14 @@ def test_apply_to_mapped():
# Case 3: collecting dict
# Case 3: collecting dict
pipeline
=
KeyMapper
(
pipeline
=
KeyMapper
(
transforms
=
[
AddToValue
(
addend
=
2
)],
transforms
=
[
AddToValue
(
addend
=
2
)],
mapping
=
dict
(
value
=
dict
(
v1
=
'v_in_1'
,
v2
=
'v_in_2'
)),
mapping
=
{
'value'
:
{
remapping
=
dict
(
value
=
dict
(
v1
=
'v_out_1'
,
v2
=
'v_out_2'
)))
'v1'
:
'v_in_1'
,
'v2'
:
'v_in_2'
}},
remapping
=
{
'value'
:
{
'v1'
:
'v_out_1'
,
'v2'
:
'v_out_2'
}})
results
=
dict
(
value
=
0
,
v_in_1
=
1
,
v_in_2
=
2
)
results
=
dict
(
value
=
0
,
v_in_1
=
1
,
v_in_2
=
2
)
with
pytest
.
warns
(
UserWarning
,
match
=
'value is a dict'
):
with
pytest
.
warns
(
UserWarning
,
match
=
'value is a dict'
):
...
@@ -332,7 +338,7 @@ def test_apply_to_multiple():
...
@@ -332,7 +338,7 @@ def test_apply_to_multiple():
_
=
str
(
pipeline
)
_
=
str
(
pipeline
)
def
test_randomchoice
():
def
test_random
_
choice
():
# Case 1: given probability
# Case 1: given probability
pipeline
=
RandomChoice
(
pipeline
=
RandomChoice
(
...
@@ -355,7 +361,32 @@ def test_randomchoice():
...
@@ -355,7 +361,32 @@ def test_randomchoice():
transforms
=
[[
AddToValue
(
addend
=
1.0
)],
transforms
=
[[
AddToValue
(
addend
=
1.0
)],
[
AddToValue
(
addend
=
2.0
)]],
),
[
AddToValue
(
addend
=
2.0
)]],
),
],
],
mapping
=
dict
(
value
=
'values'
),
mapping
=
{
'value'
:
'values'
},
auto_remap
=
True
,
share_random_params
=
True
)
results
=
dict
(
values
=
[
0
for
_
in
range
(
10
)])
results
=
pipeline
(
results
)
# check share_random_params=True works so that all values are same
values
=
results
[
'values'
]
assert
all
(
map
(
lambda
x
:
x
==
values
[
0
],
values
))
def
test_random_apply
():
# Case 1: simple use
pipeline
=
RandomApply
(
transforms
=
[
AddToValue
(
addend
=
1.0
)],
prob
=
1.0
)
results
=
pipeline
(
dict
(
value
=
1
))
np
.
testing
.
assert_equal
(
results
[
'value'
],
2.0
)
pipeline
=
RandomApply
(
transforms
=
[
AddToValue
(
addend
=
1.0
)],
prob
=
0.0
)
results
=
pipeline
(
dict
(
value
=
1
))
np
.
testing
.
assert_equal
(
results
[
'value'
],
1.0
)
# Case 2: nested RandomApply in TransformBroadcaster
pipeline
=
TransformBroadcaster
(
transforms
=
[
RandomApply
(
transforms
=
[
AddToValue
(
addend
=
1
)],
prob
=
0.5
)],
mapping
=
{
'value'
:
'values'
},
auto_remap
=
True
,
auto_remap
=
True
,
share_random_params
=
True
)
share_random_params
=
True
)
...
@@ -365,6 +396,10 @@ def test_randomchoice():
...
@@ -365,6 +396,10 @@ def test_randomchoice():
values
=
results
[
'values'
]
values
=
results
[
'values'
]
assert
all
(
map
(
lambda
x
:
x
==
values
[
0
],
values
))
assert
all
(
map
(
lambda
x
:
x
==
values
[
0
],
values
))
# __iter__
for
_
in
pipeline
:
pass
def
test_utils
():
def
test_utils
():
# Test cache_randomness: normal case
# Test cache_randomness: normal case
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment