Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
59eaefeb
Commit
59eaefeb
authored
May 20, 2022
by
liyining
Committed by
zhouzaida
Jul 19, 2022
Browse files
[Feature] Support partial mapping by manually marking keys as ignored
parent
3b494a13
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
214 additions
and
54 deletions
+214
-54
docs/zh_cn/understand_mmcv/data_transform.md
docs/zh_cn/understand_mmcv/data_transform.md
+27
-1
mmcv/transforms/base.py
mmcv/transforms/base.py
+3
-3
mmcv/transforms/utils.py
mmcv/transforms/utils.py
+7
-6
mmcv/transforms/wrappers.py
mmcv/transforms/wrappers.py
+136
-43
tests/test_transforms/test_transforms_wrapper.py
tests/test_transforms/test_transforms_wrapper.py
+41
-1
No files found.
docs/zh_cn/understand_mmcv/data_transform.md
View file @
59eaefeb
...
...
@@ -160,7 +160,10 @@ pipeline = [
pipeline
=
[
...
dict
(
type
=
'KeyMapper'
,
mapping
=
{
'img'
:
'gt_img'
},
# 将 "gt_img" 字段映射至 "img" 字段
mapping
=
{
'img'
:
'gt_img'
,
# 将 "gt_img" 字段映射至 "img" 字段
'mask'
:
...,
# 不使用原始数据中的 "mask" 字段。即对于被包装的数据变换,数据中不包含 "mask" 字段
},
auto_remap
=
True
,
# 在完成变换后,将 "img" 重映射回 "gt_img" 字段
transforms
=
[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
...
...
@@ -237,6 +240,29 @@ pipeline = [
]
```
在多目标扩展的
`mapping`
设置中,我们同样可以使用
`...`
来忽略指定的原始字段。如以下例子中,被包裹的
`RandomCrop`
会对字段
`"img"`
中的图像进行裁剪,并且在字段
`"img_shape"`
存在时更新剪裁后的图像大小。如果我们希望同时对两个图像字段
`"lq"`
和
`"gt"`
进行相同的随机裁剪,但只更新一次
`"img_shape"`
字段,可以通过例子中的方式实现:
```
python
pipeline
=
[
dict
(
type
=
'TransformBroadcaster'
,
mapping
=
{
'img'
:
[
'lq'
,
'gt'
],
'img_shape'
:
[
'img_shape'
,
...],
},
# 在完成变换后,将 "img" 和 "img_shape" 字段重映射回原先的字段
auto_remap
=
True
,
# 是否在对各目标的变换中共享随机变量
# 更多介绍参加后续章节(随机变量共享)
share_random_params
=
True
,
transforms
=
[
# `RandomCrop` 类中会操作 "img" 和 "img_shape" 字段。若 "img_shape" 空缺,
# 则只操作 "img"
dict
(
type
=
'RandomCrop'
),
])
]
```
2.
应用于一个字段的一组目标
假设我们需要将数据变换应用于
`"images"`
字段,该字段为一个图像组成的 list。
...
...
mmcv/transforms/base.py
View file @
59eaefeb
# Copyright (c) OpenMMLab. All rights reserved.
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
Dict
from
typing
import
Dict
,
Optional
class
BaseTransform
(
metaclass
=
ABCMeta
):
def
__call__
(
self
,
results
:
Dict
)
->
Dict
:
def
__call__
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]
:
return
self
.
transform
(
results
)
@
abstractmethod
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]
:
"""The transform function. All subclass of BaseTransform should
override this method.
...
...
mmcv/transforms/utils.py
View file @
59eaefeb
...
...
@@ -151,7 +151,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# key2counter stores the usage number of each cache_randomness. This is
# used to check that any cache_randomness is invoked once during processing
# on data sample.
key2counter
=
defaultdict
(
int
)
key2counter
:
dict
=
defaultdict
(
int
)
def
_add_invoke_counter
(
obj
,
method_name
):
method
=
getattr
(
obj
,
method_name
)
...
...
@@ -212,7 +212,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Store the original method and init the counter
if
hasattr
(
t
,
'_methods_with_randomness'
):
setattr
(
t
,
'transform'
,
_add_invoke_checker
(
t
,
'transform'
))
for
name
in
t
.
_methods_with_randomness
:
for
name
in
getattr
(
t
,
'
_methods_with_randomness
'
)
:
setattr
(
t
,
name
,
_add_invoke_counter
(
t
,
name
))
def
_end_cache
(
t
:
BaseTransform
):
...
...
@@ -221,20 +221,21 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return
# Remove cache enabled flag
del
t
.
_cache_enabled
del
attr
(
t
,
'
_cache_enabled
'
)
if
hasattr
(
t
,
'_cache'
):
del
t
.
_cache
del
attr
(
t
,
'
_cache
'
)
# Restore the original method
if
hasattr
(
t
,
'_methods_with_randomness'
):
for
name
in
t
.
_methods_with_randomness
:
for
name
in
getattr
(
t
,
'
_methods_with_randomness
'
)
:
key
=
f
'
{
id
(
t
)
}
.
{
name
}
'
setattr
(
t
,
name
,
key2method
[
key
])
key_transform
=
f
'
{
id
(
t
)
}
.transform'
setattr
(
t
,
'transform'
,
key2method
[
key_transform
])
def
_apply
(
t
:
BaseTransform
,
func
:
Callable
[[
BaseTransform
],
None
]):
def
_apply
(
t
:
Union
[
BaseTransform
,
Iterable
],
func
:
Callable
[[
BaseTransform
],
None
]):
if
isinstance
(
t
,
BaseTransform
):
func
(
t
)
if
isinstance
(
t
,
Iterable
):
...
...
mmcv/transforms/wrappers.py
View file @
59eaefeb
...
...
@@ -13,8 +13,14 @@ from .utils import cache_random_params, cache_randomness
# Define type of transform or transform config
Transform
=
Union
[
Dict
,
Callable
[[
Dict
],
Dict
]]
# Indicator for required but missing keys in results
NotInResults
=
object
()
# Indicator of keys marked by KeyMapper._map_input, which means ignoring the
# marked keys in KeyMapper._apply_transform so they will be invisible to
# wrapped transforms.
# This can be 2 possible case:
# 1. The key is required but missing in results
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
# the original value in results should be ignored
IgnoreKey
=
object
()
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation.
...
...
@@ -23,7 +29,7 @@ try:
except
ImportError
:
from
contextlib
import
contextmanager
@
contextmanager
@
contextmanager
# type: ignore
def
nullcontext
(
resource
=
None
):
try
:
yield
resource
...
...
@@ -54,7 +60,7 @@ class Compose(BaseTransform):
if
not
isinstance
(
transforms
,
list
):
transforms
=
[
transforms
]
self
.
transforms
=
[]
self
.
transforms
:
List
=
[]
for
transform
in
transforms
:
if
isinstance
(
transform
,
dict
):
transform
=
TRANSFORMS
.
build
(
transform
)
...
...
@@ -137,6 +143,7 @@ class KeyMapper(BaseTransform):
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 2: Collect and structure multiple items
>>> pipeline = [
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
...
...
@@ -151,6 +158,22 @@ class KeyMapper(BaseTransform):
>>> img_tar='img2')),
>>> transforms=...)
>>> ]
>>> # Example 3: Manually set ignored keys by "..."
>>> pipeline = [
>>> ...
>>> dict(type='KeyMapper',
>>> mapping={
>>> # map outer key "gt_img" to inner key "img"
>>> 'img': 'gt_img',
>>> # ignore outer key "mask"
>>> 'mask': ...,
>>> },
>>> transforms=[
>>> dict(type='RandomFlip'),
>>> ])
>>> ...
>>> ]
"""
def
__init__
(
self
,
...
...
@@ -185,20 +208,25 @@ class KeyMapper(BaseTransform):
"""Allow easy iteration over the transform sequence."""
return
iter
(
self
.
transforms
)
def
map_input
(
self
,
data
:
Dict
,
mapping
:
Dict
)
->
Dict
[
str
,
Any
]:
def
_map_input
(
self
,
data
:
Dict
,
mapping
:
Optional
[
Dict
])
->
Dict
[
str
,
Any
]:
"""KeyMapper inputs for the wrapped transforms by gathering and
renaming data items according to the mapping.
Args:
data (dict): The original input data
mapping (dict): The input key mapping. See the document of
``mmcv.transforms.wrappers.KeyMapper`` for details.
mapping (dict, optional): The input key mapping. See the document
of ``mmcv.transforms.wrappers.KeyMapper`` for details. In
set None, return the input data directly.
Returns:
dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline.
"""
if
mapping
is
None
:
return
data
.
copy
()
def
_map
(
data
,
m
):
if
isinstance
(
m
,
dict
):
# m is a dict {inner_key:outer_key, ...}
...
...
@@ -210,17 +238,17 @@ class KeyMapper(BaseTransform):
# transforms.
return
m
.
__class__
(
_map
(
data
,
e
)
for
e
in
m
)
# allow manually mark a key to be ignored by ...
if
m
is
...:
return
IgnoreKey
# m is an outer_key
if
self
.
allow_nonexist_keys
:
return
data
.
get
(
m
,
NotInResults
)
return
data
.
get
(
m
,
IgnoreKey
)
else
:
return
data
.
get
(
m
)
collected
=
_map
(
data
,
mapping
)
collected
=
{
k
:
v
for
k
,
v
in
collected
.
items
()
if
v
is
not
NotInResults
}
# Retain unmapped items
inputs
=
data
.
copy
()
...
...
@@ -228,19 +256,26 @@ class KeyMapper(BaseTransform):
return
inputs
def
map_output
(
self
,
data
:
Dict
,
remapping
:
Dict
)
->
Dict
[
str
,
Any
]:
def
_map_output
(
self
,
data
:
Dict
,
remapping
:
Optional
[
Dict
])
->
Dict
[
str
,
Any
]:
"""KeyMapper outputs from the wrapped transforms by gathering and
renaming data items according to the remapping.
Args:
data (dict): The output of the wrapped pipeline.
remapping (dict): The output key mapping. See the document of
``mmcv.transforms.wrappers.KeyMapper`` for details.
remapping (dict, optional): The output key mapping. See the
document of ``mmcv.transforms.wrappers.KeyMapper`` for
details. If ``remapping is None``, no key mapping will be
applied but only remove the special token ``IgnoreKey``.
Returns:
dict: The output with remapped keys.
"""
# Remove ``IgnoreKey``
if
remapping
is
None
:
return
{
k
:
v
for
k
,
v
in
data
.
items
()
if
v
is
not
IgnoreKey
}
def
_map
(
data
,
m
):
if
isinstance
(
m
,
dict
):
assert
isinstance
(
data
,
dict
)
...
...
@@ -257,21 +292,44 @@ class KeyMapper(BaseTransform):
results
.
update
(
_map
(
d_i
,
m_i
))
return
results
if
m
is
IgnoreKey
:
return
{}
return
{
m
:
data
}
# Note that unmapped items are not retained, which is different from
# the behavior in map_input. This is to avoid original data items
# the behavior in
_
map_input. This is to avoid original data items
# being overwritten by intermediate namesakes
return
_map
(
data
,
remapping
)
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
inputs
=
results
if
self
.
mapping
:
inputs
=
self
.
map_input
(
inputs
,
self
.
mapping
)
def
_apply_transforms
(
self
,
inputs
:
Dict
)
->
Dict
:
"""Apply ``self.transforms``.
Note that the special token ``IgnoreKey`` will be invisible to
``self.transforms``, but not removed in this method. It will be
eventually removed in :func:``self._map_output``.
"""
results
=
inputs
.
copy
()
inputs
=
{
k
:
v
for
k
,
v
in
inputs
.
items
()
if
v
is
not
IgnoreKey
}
outputs
=
self
.
transforms
(
inputs
)
if
self
.
remapping
:
outputs
=
self
.
map_output
(
outputs
,
self
.
remapping
)
if
outputs
is
None
:
raise
ValueError
(
f
'Transforms wrapped by
{
self
.
__class__
.
__name__
}
should '
'not return None.'
)
results
.
update
(
outputs
)
# type: ignore
return
results
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
"""Apply mapping, wrapped transforms and remapping."""
# Apply mapping
inputs
=
self
.
_map_input
(
results
,
self
.
mapping
)
# Apply wrapped transforms
outputs
=
self
.
_apply_transforms
(
inputs
)
# Apply remapping
outputs
=
self
.
_map_output
(
outputs
,
self
.
remapping
)
results
.
update
(
outputs
)
return
results
...
...
@@ -314,7 +372,8 @@ class TransformBroadcaster(KeyMapper):
example.
Examples:
>>> # Example 1:
>>> # Example 1: Broadcast to enumerated keys, each contains a single
>>> # data element
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
...
...
@@ -333,7 +392,8 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 2:
>>> # Example 2: Broadcast to keys that contains data sequences
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
...
...
@@ -351,6 +411,24 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> Example 3: Set ignored keys in broadcasting
>>> pipeline = [
>>> dict(type='TransformBroadcaster',
>>> # Broadcast the wrapped transforms to multiple images
>>> # 'lq' and 'gt, but only update 'img_shape' once
>>> mapping={
>>> 'img': ['lq', 'gt'],
>>> 'img_shape': ['img_shape', ...],
>>> },
>>> auto_remap=True,
>>> share_random_params=True,
>>> transforms=[
>>> # `RandomCrop` will modify the field "img",
>>> # and optionally update "img_shape" if it exists
>>> dict(type='RandomCrop'),
>>> ])
>>> ]
"""
def
__init__
(
self
,
...
...
@@ -366,17 +444,23 @@ class TransformBroadcaster(KeyMapper):
self
.
share_random_params
=
share_random_params
def
scatter_sequence
(
self
,
data
:
Dict
)
->
List
[
Dict
]:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
"""
# infer split number from input
seq_len
=
None
seq_len
=
0
key_rep
=
None
if
self
.
mapping
:
keys
=
self
.
mapping
.
keys
()
else
:
keys
=
data
.
keys
()
for
key
in
keys
:
assert
isinstance
(
data
[
key
],
Sequence
)
if
seq_len
is
not
None
:
if
seq_len
:
if
len
(
data
[
key
])
!=
seq_len
:
raise
ValueError
(
'Got inconsistent sequence length: '
f
'
{
seq_len
}
(
{
key_rep
}
) vs. '
...
...
@@ -385,6 +469,8 @@ class TransformBroadcaster(KeyMapper):
seq_len
=
len
(
data
[
key
])
key_rep
=
key
assert
seq_len
>
0
,
'Fail to get the number of broadcasting targets'
scatters
=
[]
for
i
in
range
(
seq_len
):
scatter
=
data
.
copy
()
...
...
@@ -394,13 +480,13 @@ class TransformBroadcaster(KeyMapper):
return
scatters
def
transform
(
self
,
results
:
Dict
):
"""Broadcast wrapped transforms to multiple targets."""
# Apply input remapping
inputs
=
results
if
self
.
mapping
:
inputs
=
self
.
map_input
(
inputs
,
self
.
mapping
)
inputs
=
self
.
_map_input
(
results
,
self
.
mapping
)
# Scatter sequential inputs into a list
inputs
=
self
.
scatter_sequence
(
inputs
)
input
_scatter
s
=
self
.
scatter_sequence
(
inputs
)
# Control random parameter sharing with a context manager
if
self
.
share_random_params
:
...
...
@@ -410,20 +496,21 @@ class TransformBroadcaster(KeyMapper):
# by all data items.
ctx
=
cache_random_params
else
:
ctx
=
nullcontext
ctx
=
nullcontext
# type: ignore
with
ctx
(
self
.
transforms
):
outputs
=
[
self
.
transforms
(
_input
)
for
_input
in
inputs
]
output_scatters
=
[
self
.
_apply_transforms
(
_input
)
for
_input
in
input_scatters
]
# Collate output scatters (list of dict to dict of list)
outputs
=
{
key
:
[
_output
[
key
]
for
_output
in
outputs
]
for
key
in
outputs
[
0
]
key
:
[
_output
[
key
]
for
_output
in
output
_scatter
s
]
for
key
in
output
_scatter
s
[
0
]
}
# Apply output remapping
if
self
.
remapping
:
outputs
=
self
.
map_output
(
outputs
,
self
.
remapping
)
# Apply remapping
outputs
=
self
.
_map_output
(
outputs
,
self
.
remapping
)
results
.
update
(
outputs
)
return
results
...
...
@@ -473,11 +560,13 @@ class RandomChoice(BaseTransform):
return
iter
(
self
.
transforms
)
@
cache_randomness
def
random_pipeline_index
(
self
):
def
random_pipeline_index
(
self
)
->
int
:
"""Return a random transform index."""
indices
=
np
.
arange
(
len
(
self
.
transforms
))
return
np
.
random
.
choice
(
indices
,
p
=
self
.
prob
)
def
transform
(
self
,
results
):
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
"""Randomly choose a transform to apply."""
idx
=
self
.
random_pipeline_index
()
return
self
.
transforms
[
idx
](
results
)
...
...
@@ -512,10 +601,14 @@ class RandomApply(BaseTransform):
return
iter
(
self
.
transforms
)
@
cache_randomness
def
random_apply
(
self
):
def
random_apply
(
self
)
->
bool
:
"""Return a random bool value indicating whether apply the transform.
"""
return
np
.
random
.
rand
()
<
self
.
prob
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
"""Randomly apply the transform."""
if
self
.
random_apply
():
results
=
self
.
transforms
(
results
)
return
self
.
transforms
(
results
)
else
:
return
results
tests/test_transforms/test_transforms_wrapper.py
View file @
59eaefeb
...
...
@@ -67,6 +67,10 @@ class SumTwoValues(BaseTransform):
def
transform
(
self
,
results
):
if
'num_1'
in
results
and
'num_2'
in
results
:
results
[
'sum'
]
=
results
[
'num_1'
]
+
results
[
'num_2'
]
elif
'num_1'
in
results
:
results
[
'sum'
]
=
results
[
'num_1'
]
elif
'num_2'
in
results
:
results
[
'sum'
]
=
results
[
'num_2'
]
else
:
results
[
'sum'
]
=
np
.
nan
return
results
...
...
@@ -262,7 +266,7 @@ def test_key_mapper():
np
.
testing
.
assert_equal
(
results
[
'sum'
],
3
)
results
=
pipeline
(
dict
(
a
=
1
))
assert
np
.
isnan
(
results
[
'sum'
])
np
.
testing
.
assert_equal
(
results
[
'sum'
]
,
1
)
# Case 9: use wrapper as a transform
transform
=
KeyMapper
(
mapping
=
dict
(
b
=
'a'
),
auto_remap
=
False
)
...
...
@@ -270,6 +274,17 @@ def test_key_mapper():
# note that the original key 'a' will not be removed
assert
results
==
dict
(
a
=
1
,
b
=
1
)
# Case 10: manually set keys ignored
pipeline
=
KeyMapper
(
transforms
=
[
SumTwoValues
()],
mapping
=
dict
(
num_1
=
'a'
,
num_2
=
...),
# num_2 (b) will be ignored
auto_remap
=
False
,
# allow_nonexist_keys will not affect manually ignored keys
allow_nonexist_keys
=
False
)
results
=
pipeline
(
dict
(
a
=
1
,
b
=
2
))
np
.
testing
.
assert_equal
(
results
[
'sum'
],
1
)
# Test basic functions
pipeline
=
KeyMapper
(
transforms
=
[
AddToValue
(
addend
=
1
)],
...
...
@@ -353,6 +368,31 @@ def test_transform_broadcaster():
np
.
testing
.
assert_equal
(
results
[
'values'
][
0
],
results
[
'values'
][
1
])
# Case 6: partial broadcasting
pipeline
=
TransformBroadcaster
(
transforms
=
[
SumTwoValues
()],
mapping
=
dict
(
num_1
=
[
'a_1'
,
'b_1'
],
num_2
=
[
'a_2'
,
...]),
remapping
=
dict
(
sum
=
[
'a'
,
'b'
]),
auto_remap
=
False
)
results
=
dict
(
a_1
=
1
,
a_2
=
2
,
b_1
=
3
,
b_2
=
4
)
results
=
pipeline
(
results
)
np
.
testing
.
assert_equal
(
results
[
'a'
],
3
)
np
.
testing
.
assert_equal
(
results
[
'b'
],
3
)
pipeline
=
TransformBroadcaster
(
transforms
=
[
SumTwoValues
()],
mapping
=
dict
(
num_1
=
[
'a_1'
,
'b_1'
],
num_2
=
[
'a_2'
,
'b_2'
]),
remapping
=
dict
(
sum
=
[
'a'
,
...]),
auto_remap
=
False
)
results
=
dict
(
a_1
=
1
,
a_2
=
2
,
b_1
=
3
,
b_2
=
4
)
results
=
pipeline
(
results
)
np
.
testing
.
assert_equal
(
results
[
'a'
],
3
)
assert
'b'
not
in
results
# Test repr
_
=
str
(
pipeline
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment