Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
59eaefeb
"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "a5dd01bb74d0d5775b6af177a8d077f7fb634947"
Commit
59eaefeb
authored
May 20, 2022
by
liyining
Committed by
zhouzaida
Jul 19, 2022
Browse files
[Feature] Support partial mapping by manually marking keys as ignored
parent
3b494a13
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
214 additions
and
54 deletions
+214
-54
docs/zh_cn/understand_mmcv/data_transform.md
docs/zh_cn/understand_mmcv/data_transform.md
+27
-1
mmcv/transforms/base.py
mmcv/transforms/base.py
+3
-3
mmcv/transforms/utils.py
mmcv/transforms/utils.py
+7
-6
mmcv/transforms/wrappers.py
mmcv/transforms/wrappers.py
+136
-43
tests/test_transforms/test_transforms_wrapper.py
tests/test_transforms/test_transforms_wrapper.py
+41
-1
No files found.
docs/zh_cn/understand_mmcv/data_transform.md
View file @
59eaefeb
...
@@ -160,7 +160,10 @@ pipeline = [
...
@@ -160,7 +160,10 @@ pipeline = [
pipeline
=
[
pipeline
=
[
...
...
dict
(
type
=
'KeyMapper'
,
dict
(
type
=
'KeyMapper'
,
mapping
=
{
'img'
:
'gt_img'
},
# 将 "gt_img" 字段映射至 "img" 字段
mapping
=
{
'img'
:
'gt_img'
,
# 将 "gt_img" 字段映射至 "img" 字段
'mask'
:
...,
# 不使用原始数据中的 "mask" 字段。即对于被包装的数据变换,数据中不包含 "mask" 字段
},
auto_remap
=
True
,
# 在完成变换后,将 "img" 重映射回 "gt_img" 字段
auto_remap
=
True
,
# 在完成变换后,将 "img" 重映射回 "gt_img" 字段
transforms
=
[
transforms
=
[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
...
@@ -237,6 +240,29 @@ pipeline = [
...
@@ -237,6 +240,29 @@ pipeline = [
]
]
```
```
在多目标扩展的
`mapping`
设置中,我们同样可以使用
`...`
来忽略指定的原始字段。如以下例子中,被包裹的
`RandomCrop`
会对字段
`"img"`
中的图像进行裁剪,并且在字段
`"img_shape"`
存在时更新剪裁后的图像大小。如果我们希望同时对两个图像字段
`"lq"`
和
`"gt"`
进行相同的随机裁剪,但只更新一次
`"img_shape"`
字段,可以通过例子中的方式实现:
```
python
pipeline
=
[
dict
(
type
=
'TransformBroadcaster'
,
mapping
=
{
'img'
:
[
'lq'
,
'gt'
],
'img_shape'
:
[
'img_shape'
,
...],
},
# 在完成变换后,将 "img" 和 "img_shape" 字段重映射回原先的字段
auto_remap
=
True
,
# 是否在对各目标的变换中共享随机变量
# 更多介绍参见后续章节(随机变量共享)
share_random_params
=
True
,
transforms
=
[
# `RandomCrop` 类中会操作 "img" 和 "img_shape" 字段。若 "img_shape" 空缺,
# 则只操作 "img"
dict
(
type
=
'RandomCrop'
),
])
]
```
2.
应用于一个字段的一组目标
2.
应用于一个字段的一组目标
假设我们需要将数据变换应用于
`"images"`
字段,该字段为一个图像组成的 list。
假设我们需要将数据变换应用于
`"images"`
字段,该字段为一个图像组成的 list。
...
...
mmcv/transforms/base.py
View file @
59eaefeb
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
Dict
from
typing
import
Dict
,
Optional
class
BaseTransform
(
metaclass
=
ABCMeta
):
class
BaseTransform
(
metaclass
=
ABCMeta
):
def
__call__
(
self
,
results
:
Dict
)
->
Dict
:
def
__call__
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]
:
return
self
.
transform
(
results
)
return
self
.
transform
(
results
)
@
abstractmethod
@
abstractmethod
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]
:
"""The transform function. All subclass of BaseTransform should
"""The transform function. All subclass of BaseTransform should
override this method.
override this method.
...
...
mmcv/transforms/utils.py
View file @
59eaefeb
...
@@ -151,7 +151,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
...
@@ -151,7 +151,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# key2counter stores the usage number of each cache_randomness. This is
# key2counter stores the usage number of each cache_randomness. This is
# used to check that any cache_randomness is invoked once during processing
# used to check that any cache_randomness is invoked once during processing
# on data sample.
# on data sample.
key2counter
=
defaultdict
(
int
)
key2counter
:
dict
=
defaultdict
(
int
)
def
_add_invoke_counter
(
obj
,
method_name
):
def
_add_invoke_counter
(
obj
,
method_name
):
method
=
getattr
(
obj
,
method_name
)
method
=
getattr
(
obj
,
method_name
)
...
@@ -212,7 +212,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
...
@@ -212,7 +212,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Store the original method and init the counter
# Store the original method and init the counter
if
hasattr
(
t
,
'_methods_with_randomness'
):
if
hasattr
(
t
,
'_methods_with_randomness'
):
setattr
(
t
,
'transform'
,
_add_invoke_checker
(
t
,
'transform'
))
setattr
(
t
,
'transform'
,
_add_invoke_checker
(
t
,
'transform'
))
for
name
in
t
.
_methods_with_randomness
:
for
name
in
getattr
(
t
,
'
_methods_with_randomness
'
)
:
setattr
(
t
,
name
,
_add_invoke_counter
(
t
,
name
))
setattr
(
t
,
name
,
_add_invoke_counter
(
t
,
name
))
def
_end_cache
(
t
:
BaseTransform
):
def
_end_cache
(
t
:
BaseTransform
):
...
@@ -221,20 +221,21 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
...
@@ -221,20 +221,21 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return
return
# Remove cache enabled flag
# Remove cache enabled flag
del
t
.
_cache_enabled
del
attr
(
t
,
'
_cache_enabled
'
)
if
hasattr
(
t
,
'_cache'
):
if
hasattr
(
t
,
'_cache'
):
del
t
.
_cache
del
attr
(
t
,
'
_cache
'
)
# Restore the original method
# Restore the original method
if
hasattr
(
t
,
'_methods_with_randomness'
):
if
hasattr
(
t
,
'_methods_with_randomness'
):
for
name
in
t
.
_methods_with_randomness
:
for
name
in
getattr
(
t
,
'
_methods_with_randomness
'
)
:
key
=
f
'
{
id
(
t
)
}
.
{
name
}
'
key
=
f
'
{
id
(
t
)
}
.
{
name
}
'
setattr
(
t
,
name
,
key2method
[
key
])
setattr
(
t
,
name
,
key2method
[
key
])
key_transform
=
f
'
{
id
(
t
)
}
.transform'
key_transform
=
f
'
{
id
(
t
)
}
.transform'
setattr
(
t
,
'transform'
,
key2method
[
key_transform
])
setattr
(
t
,
'transform'
,
key2method
[
key_transform
])
def
_apply
(
t
:
BaseTransform
,
func
:
Callable
[[
BaseTransform
],
None
]):
def
_apply
(
t
:
Union
[
BaseTransform
,
Iterable
],
func
:
Callable
[[
BaseTransform
],
None
]):
if
isinstance
(
t
,
BaseTransform
):
if
isinstance
(
t
,
BaseTransform
):
func
(
t
)
func
(
t
)
if
isinstance
(
t
,
Iterable
):
if
isinstance
(
t
,
Iterable
):
...
...
mmcv/transforms/wrappers.py
View file @
59eaefeb
...
@@ -13,8 +13,14 @@ from .utils import cache_random_params, cache_randomness
...
@@ -13,8 +13,14 @@ from .utils import cache_random_params, cache_randomness
# Define type of transform or transform config
# Define type of transform or transform config
Transform
=
Union
[
Dict
,
Callable
[[
Dict
],
Dict
]]
Transform
=
Union
[
Dict
,
Callable
[[
Dict
],
Dict
]]
# Indicator for required but missing keys in results
# Indicator of keys marked by KeyMapper._map_input, which means ignoring the
NotInResults
=
object
()
# marked keys in KeyMapper._apply_transform so they will be invisible to
# wrapped transforms.
# This can be one of 2 possible cases:
# 1. The key is required but missing in results
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
# the original value in results should be ignored
IgnoreKey
=
object
()
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation.
# implementation.
...
@@ -23,7 +29,7 @@ try:
...
@@ -23,7 +29,7 @@ try:
except
ImportError
:
except
ImportError
:
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
@
contextmanager
@
contextmanager
# type: ignore
def
nullcontext
(
resource
=
None
):
def
nullcontext
(
resource
=
None
):
try
:
try
:
yield
resource
yield
resource
...
@@ -54,7 +60,7 @@ class Compose(BaseTransform):
...
@@ -54,7 +60,7 @@ class Compose(BaseTransform):
if
not
isinstance
(
transforms
,
list
):
if
not
isinstance
(
transforms
,
list
):
transforms
=
[
transforms
]
transforms
=
[
transforms
]
self
.
transforms
=
[]
self
.
transforms
:
List
=
[]
for
transform
in
transforms
:
for
transform
in
transforms
:
if
isinstance
(
transform
,
dict
):
if
isinstance
(
transform
,
dict
):
transform
=
TRANSFORMS
.
build
(
transform
)
transform
=
TRANSFORMS
.
build
(
transform
)
...
@@ -137,6 +143,7 @@ class KeyMapper(BaseTransform):
...
@@ -137,6 +143,7 @@ class KeyMapper(BaseTransform):
>>> dict(type='Normalize'),
>>> dict(type='Normalize'),
>>> ])
>>> ])
>>> ]
>>> ]
>>> # Example 2: Collect and structure multiple items
>>> # Example 2: Collect and structure multiple items
>>> pipeline = [
>>> pipeline = [
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
...
@@ -151,6 +158,22 @@ class KeyMapper(BaseTransform):
...
@@ -151,6 +158,22 @@ class KeyMapper(BaseTransform):
>>> img_tar='img2')),
>>> img_tar='img2')),
>>> transforms=...)
>>> transforms=...)
>>> ]
>>> ]
>>> # Example 3: Manually set ignored keys by "..."
>>> pipeline = [
>>> ...
>>> dict(type='KeyMapper',
>>> mapping={
>>> # map outer key "gt_img" to inner key "img"
>>> 'img': 'gt_img',
>>> # ignore outer key "mask"
>>> 'mask': ...,
>>> },
>>> transforms=[
>>> dict(type='RandomFlip'),
>>> ])
>>> ...
>>> ]
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -185,20 +208,25 @@ class KeyMapper(BaseTransform):
...
@@ -185,20 +208,25 @@ class KeyMapper(BaseTransform):
"""Allow easy iteration over the transform sequence."""
"""Allow easy iteration over the transform sequence."""
return
iter
(
self
.
transforms
)
return
iter
(
self
.
transforms
)
def
map_input
(
self
,
data
:
Dict
,
mapping
:
Dict
)
->
Dict
[
str
,
Any
]:
def
_map_input
(
self
,
data
:
Dict
,
mapping
:
Optional
[
Dict
])
->
Dict
[
str
,
Any
]:
"""KeyMapper inputs for the wrapped transforms by gathering and
"""KeyMapper inputs for the wrapped transforms by gathering and
renaming data items according to the mapping.
renaming data items according to the mapping.
Args:
Args:
data (dict): The original input data
data (dict): The original input data
mapping (dict): The input key mapping. See the document of
mapping (dict, optional): The input key mapping. See the document
``mmcv.transforms.wrappers.KeyMapper`` for details.
of ``mmcv.transforms.wrappers.KeyMapper`` for details. If
set to None, a shallow copy of the input data is returned.
Returns:
Returns:
dict: The input data with remapped keys. This will be the actual
dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline.
input of the wrapped pipeline.
"""
"""
if
mapping
is
None
:
return
data
.
copy
()
def
_map
(
data
,
m
):
def
_map
(
data
,
m
):
if
isinstance
(
m
,
dict
):
if
isinstance
(
m
,
dict
):
# m is a dict {inner_key:outer_key, ...}
# m is a dict {inner_key:outer_key, ...}
...
@@ -210,17 +238,17 @@ class KeyMapper(BaseTransform):
...
@@ -210,17 +238,17 @@ class KeyMapper(BaseTransform):
# transforms.
# transforms.
return
m
.
__class__
(
_map
(
data
,
e
)
for
e
in
m
)
return
m
.
__class__
(
_map
(
data
,
e
)
for
e
in
m
)
# allow manually marking a key as ignored via ... (Ellipsis)
if
m
is
...:
return
IgnoreKey
# m is an outer_key
# m is an outer_key
if
self
.
allow_nonexist_keys
:
if
self
.
allow_nonexist_keys
:
return
data
.
get
(
m
,
NotInResults
)
return
data
.
get
(
m
,
IgnoreKey
)
else
:
else
:
return
data
.
get
(
m
)
return
data
.
get
(
m
)
collected
=
_map
(
data
,
mapping
)
collected
=
_map
(
data
,
mapping
)
collected
=
{
k
:
v
for
k
,
v
in
collected
.
items
()
if
v
is
not
NotInResults
}
# Retain unmapped items
# Retain unmapped items
inputs
=
data
.
copy
()
inputs
=
data
.
copy
()
...
@@ -228,19 +256,26 @@ class KeyMapper(BaseTransform):
...
@@ -228,19 +256,26 @@ class KeyMapper(BaseTransform):
return
inputs
return
inputs
def
map_output
(
self
,
data
:
Dict
,
remapping
:
Dict
)
->
Dict
[
str
,
Any
]:
def
_map_output
(
self
,
data
:
Dict
,
remapping
:
Optional
[
Dict
])
->
Dict
[
str
,
Any
]:
"""KeyMapper outputs from the wrapped transforms by gathering and
"""KeyMapper outputs from the wrapped transforms by gathering and
renaming data items according to the remapping.
renaming data items according to the remapping.
Args:
Args:
data (dict): The output of the wrapped pipeline.
data (dict): The output of the wrapped pipeline.
remapping (dict): The output key mapping. See the document of
remapping (dict, optional): The output key mapping. See the
``mmcv.transforms.wrappers.KeyMapper`` for details.
document of ``mmcv.transforms.wrappers.KeyMapper`` for
details. If ``remapping is None``, no key mapping will be
applied but only remove the special token ``IgnoreKey``.
Returns:
Returns:
dict: The output with remapped keys.
dict: The output with remapped keys.
"""
"""
# Remove ``IgnoreKey``
if
remapping
is
None
:
return
{
k
:
v
for
k
,
v
in
data
.
items
()
if
v
is
not
IgnoreKey
}
def
_map
(
data
,
m
):
def
_map
(
data
,
m
):
if
isinstance
(
m
,
dict
):
if
isinstance
(
m
,
dict
):
assert
isinstance
(
data
,
dict
)
assert
isinstance
(
data
,
dict
)
...
@@ -257,21 +292,44 @@ class KeyMapper(BaseTransform):
...
@@ -257,21 +292,44 @@ class KeyMapper(BaseTransform):
results
.
update
(
_map
(
d_i
,
m_i
))
results
.
update
(
_map
(
d_i
,
m_i
))
return
results
return
results
if
m
is
IgnoreKey
:
return
{}
return
{
m
:
data
}
return
{
m
:
data
}
# Note that unmapped items are not retained, which is different from
# Note that unmapped items are not retained, which is different from
# the behavior in map_input. This is to avoid original data items
# the behavior in
_
map_input. This is to avoid original data items
# being overwritten by intermediate namesakes
# being overwritten by intermediate namesakes
return
_map
(
data
,
remapping
)
return
_map
(
data
,
remapping
)
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
def
_apply_transforms
(
self
,
inputs
:
Dict
)
->
Dict
:
inputs
=
results
"""Apply ``self.transforms``.
if
self
.
mapping
:
inputs
=
self
.
map_input
(
inputs
,
self
.
mapping
)
Note that the special token ``IgnoreKey`` will be invisible to
``self.transforms``, but not removed in this method. It will be
eventually removed in :func:`self._map_output`.
"""
results
=
inputs
.
copy
()
inputs
=
{
k
:
v
for
k
,
v
in
inputs
.
items
()
if
v
is
not
IgnoreKey
}
outputs
=
self
.
transforms
(
inputs
)
outputs
=
self
.
transforms
(
inputs
)
if
self
.
remapping
:
if
outputs
is
None
:
outputs
=
self
.
map_output
(
outputs
,
self
.
remapping
)
raise
ValueError
(
f
'Transforms wrapped by
{
self
.
__class__
.
__name__
}
should '
'not return None.'
)
results
.
update
(
outputs
)
# type: ignore
return
results
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
"""Apply mapping, wrapped transforms and remapping."""
# Apply mapping
inputs
=
self
.
_map_input
(
results
,
self
.
mapping
)
# Apply wrapped transforms
outputs
=
self
.
_apply_transforms
(
inputs
)
# Apply remapping
outputs
=
self
.
_map_output
(
outputs
,
self
.
remapping
)
results
.
update
(
outputs
)
results
.
update
(
outputs
)
return
results
return
results
...
@@ -314,7 +372,8 @@ class TransformBroadcaster(KeyMapper):
...
@@ -314,7 +372,8 @@ class TransformBroadcaster(KeyMapper):
example.
example.
Examples:
Examples:
>>> # Example 1:
>>> # Example 1: Broadcast to enumerated keys, each containing a single
>>> # data element
>>> pipeline = [
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
...
@@ -333,7 +392,8 @@ class TransformBroadcaster(KeyMapper):
...
@@ -333,7 +392,8 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'),
>>> dict(type='Normalize'),
>>> ])
>>> ])
>>> ]
>>> ]
>>> # Example 2:
>>> # Example 2: Broadcast to keys that contain data sequences
>>> pipeline = [
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
...
@@ -351,6 +411,24 @@ class TransformBroadcaster(KeyMapper):
...
@@ -351,6 +411,24 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'),
>>> dict(type='Normalize'),
>>> ])
>>> ])
>>> ]
>>> ]
>>> # Example 3: Set ignored keys in broadcasting
>>> pipeline = [
>>> dict(type='TransformBroadcaster',
>>> # Broadcast the wrapped transforms to multiple images
>>> # 'lq' and 'gt', but only update 'img_shape' once
>>> mapping={
>>> 'img': ['lq', 'gt'],
>>> 'img_shape': ['img_shape', ...],
>>> },
>>> auto_remap=True,
>>> share_random_params=True,
>>> transforms=[
>>> # `RandomCrop` will modify the field "img",
>>> # and optionally update "img_shape" if it exists
>>> dict(type='RandomCrop'),
>>> ])
>>> ]
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -366,17 +444,23 @@ class TransformBroadcaster(KeyMapper):
...
@@ -366,17 +444,23 @@ class TransformBroadcaster(KeyMapper):
self
.
share_random_params
=
share_random_params
self
.
share_random_params
=
share_random_params
def
scatter_sequence
(
self
,
data
:
Dict
)
->
List
[
Dict
]:
def
scatter_sequence
(
self
,
data
:
Dict
)
->
List
[
Dict
]:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
"""
# infer split number from input
# infer split number from input
seq_len
=
None
seq_len
=
0
key_rep
=
None
key_rep
=
None
if
self
.
mapping
:
if
self
.
mapping
:
keys
=
self
.
mapping
.
keys
()
keys
=
self
.
mapping
.
keys
()
else
:
else
:
keys
=
data
.
keys
()
keys
=
data
.
keys
()
for
key
in
keys
:
for
key
in
keys
:
assert
isinstance
(
data
[
key
],
Sequence
)
assert
isinstance
(
data
[
key
],
Sequence
)
if
seq_len
is
not
None
:
if
seq_len
:
if
len
(
data
[
key
])
!=
seq_len
:
if
len
(
data
[
key
])
!=
seq_len
:
raise
ValueError
(
'Got inconsistent sequence length: '
raise
ValueError
(
'Got inconsistent sequence length: '
f
'
{
seq_len
}
(
{
key_rep
}
) vs. '
f
'
{
seq_len
}
(
{
key_rep
}
) vs. '
...
@@ -385,6 +469,8 @@ class TransformBroadcaster(KeyMapper):
...
@@ -385,6 +469,8 @@ class TransformBroadcaster(KeyMapper):
seq_len
=
len
(
data
[
key
])
seq_len
=
len
(
data
[
key
])
key_rep
=
key
key_rep
=
key
assert
seq_len
>
0
,
'Fail to get the number of broadcasting targets'
scatters
=
[]
scatters
=
[]
for
i
in
range
(
seq_len
):
for
i
in
range
(
seq_len
):
scatter
=
data
.
copy
()
scatter
=
data
.
copy
()
...
@@ -394,13 +480,13 @@ class TransformBroadcaster(KeyMapper):
...
@@ -394,13 +480,13 @@ class TransformBroadcaster(KeyMapper):
return
scatters
return
scatters
def
transform
(
self
,
results
:
Dict
):
def
transform
(
self
,
results
:
Dict
):
"""Broadcast wrapped transforms to multiple targets."""
# Apply input remapping
# Apply input remapping
inputs
=
results
inputs
=
self
.
_map_input
(
results
,
self
.
mapping
)
if
self
.
mapping
:
inputs
=
self
.
map_input
(
inputs
,
self
.
mapping
)
# Scatter sequential inputs into a list
# Scatter sequential inputs into a list
inputs
=
self
.
scatter_sequence
(
inputs
)
input
_scatter
s
=
self
.
scatter_sequence
(
inputs
)
# Control random parameter sharing with a context manager
# Control random parameter sharing with a context manager
if
self
.
share_random_params
:
if
self
.
share_random_params
:
...
@@ -410,20 +496,21 @@ class TransformBroadcaster(KeyMapper):
...
@@ -410,20 +496,21 @@ class TransformBroadcaster(KeyMapper):
# by all data items.
# by all data items.
ctx
=
cache_random_params
ctx
=
cache_random_params
else
:
else
:
ctx
=
nullcontext
ctx
=
nullcontext
# type: ignore
with
ctx
(
self
.
transforms
):
with
ctx
(
self
.
transforms
):
outputs
=
[
self
.
transforms
(
_input
)
for
_input
in
inputs
]
output_scatters
=
[
self
.
_apply_transforms
(
_input
)
for
_input
in
input_scatters
]
# Collate output scatters (list of dict to dict of list)
# Collate output scatters (list of dict to dict of list)
outputs
=
{
outputs
=
{
key
:
[
_output
[
key
]
for
_output
in
outputs
]
key
:
[
_output
[
key
]
for
_output
in
output
_scatter
s
]
for
key
in
outputs
[
0
]
for
key
in
output
_scatter
s
[
0
]
}
}
# Apply output remapping
# Apply remapping
if
self
.
remapping
:
outputs
=
self
.
_map_output
(
outputs
,
self
.
remapping
)
outputs
=
self
.
map_output
(
outputs
,
self
.
remapping
)
results
.
update
(
outputs
)
results
.
update
(
outputs
)
return
results
return
results
...
@@ -473,11 +560,13 @@ class RandomChoice(BaseTransform):
...
@@ -473,11 +560,13 @@ class RandomChoice(BaseTransform):
return
iter
(
self
.
transforms
)
return
iter
(
self
.
transforms
)
@
cache_randomness
@
cache_randomness
def
random_pipeline_index
(
self
):
def
random_pipeline_index
(
self
)
->
int
:
"""Return a random transform index."""
indices
=
np
.
arange
(
len
(
self
.
transforms
))
indices
=
np
.
arange
(
len
(
self
.
transforms
))
return
np
.
random
.
choice
(
indices
,
p
=
self
.
prob
)
return
np
.
random
.
choice
(
indices
,
p
=
self
.
prob
)
def
transform
(
self
,
results
):
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
"""Randomly choose a transform to apply."""
idx
=
self
.
random_pipeline_index
()
idx
=
self
.
random_pipeline_index
()
return
self
.
transforms
[
idx
](
results
)
return
self
.
transforms
[
idx
](
results
)
...
@@ -512,10 +601,14 @@ class RandomApply(BaseTransform):
...
@@ -512,10 +601,14 @@ class RandomApply(BaseTransform):
return
iter
(
self
.
transforms
)
return
iter
(
self
.
transforms
)
@
cache_randomness
@
cache_randomness
def
random_apply
(
self
):
def
random_apply
(
self
)
->
bool
:
"""Return a random bool value indicating whether apply the transform.
"""
return
np
.
random
.
rand
()
<
self
.
prob
return
np
.
random
.
rand
()
<
self
.
prob
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
"""Randomly apply the transform."""
if
self
.
random_apply
():
if
self
.
random_apply
():
results
=
self
.
transforms
(
results
)
return
self
.
transforms
(
results
)
return
results
else
:
return
results
tests/test_transforms/test_transforms_wrapper.py
View file @
59eaefeb
...
@@ -67,6 +67,10 @@ class SumTwoValues(BaseTransform):
...
@@ -67,6 +67,10 @@ class SumTwoValues(BaseTransform):
def
transform
(
self
,
results
):
def
transform
(
self
,
results
):
if
'num_1'
in
results
and
'num_2'
in
results
:
if
'num_1'
in
results
and
'num_2'
in
results
:
results
[
'sum'
]
=
results
[
'num_1'
]
+
results
[
'num_2'
]
results
[
'sum'
]
=
results
[
'num_1'
]
+
results
[
'num_2'
]
elif
'num_1'
in
results
:
results
[
'sum'
]
=
results
[
'num_1'
]
elif
'num_2'
in
results
:
results
[
'sum'
]
=
results
[
'num_2'
]
else
:
else
:
results
[
'sum'
]
=
np
.
nan
results
[
'sum'
]
=
np
.
nan
return
results
return
results
...
@@ -262,7 +266,7 @@ def test_key_mapper():
...
@@ -262,7 +266,7 @@ def test_key_mapper():
np
.
testing
.
assert_equal
(
results
[
'sum'
],
3
)
np
.
testing
.
assert_equal
(
results
[
'sum'
],
3
)
results
=
pipeline
(
dict
(
a
=
1
))
results
=
pipeline
(
dict
(
a
=
1
))
assert
np
.
isnan
(
results
[
'sum'
])
np
.
testing
.
assert_equal
(
results
[
'sum'
]
,
1
)
# Case 9: use wrapper as a transform
# Case 9: use wrapper as a transform
transform
=
KeyMapper
(
mapping
=
dict
(
b
=
'a'
),
auto_remap
=
False
)
transform
=
KeyMapper
(
mapping
=
dict
(
b
=
'a'
),
auto_remap
=
False
)
...
@@ -270,6 +274,17 @@ def test_key_mapper():
...
@@ -270,6 +274,17 @@ def test_key_mapper():
# note that the original key 'a' will not be removed
# note that the original key 'a' will not be removed
assert
results
==
dict
(
a
=
1
,
b
=
1
)
assert
results
==
dict
(
a
=
1
,
b
=
1
)
# Case 10: manually set keys ignored
pipeline
=
KeyMapper
(
transforms
=
[
SumTwoValues
()],
mapping
=
dict
(
num_1
=
'a'
,
num_2
=
...),
# num_2 (b) will be ignored
auto_remap
=
False
,
# allow_nonexist_keys will not affect manually ignored keys
allow_nonexist_keys
=
False
)
results
=
pipeline
(
dict
(
a
=
1
,
b
=
2
))
np
.
testing
.
assert_equal
(
results
[
'sum'
],
1
)
# Test basic functions
# Test basic functions
pipeline
=
KeyMapper
(
pipeline
=
KeyMapper
(
transforms
=
[
AddToValue
(
addend
=
1
)],
transforms
=
[
AddToValue
(
addend
=
1
)],
...
@@ -353,6 +368,31 @@ def test_transform_broadcaster():
...
@@ -353,6 +368,31 @@ def test_transform_broadcaster():
np
.
testing
.
assert_equal
(
results
[
'values'
][
0
],
results
[
'values'
][
1
])
np
.
testing
.
assert_equal
(
results
[
'values'
][
0
],
results
[
'values'
][
1
])
# Case 6: partial broadcasting
pipeline
=
TransformBroadcaster
(
transforms
=
[
SumTwoValues
()],
mapping
=
dict
(
num_1
=
[
'a_1'
,
'b_1'
],
num_2
=
[
'a_2'
,
...]),
remapping
=
dict
(
sum
=
[
'a'
,
'b'
]),
auto_remap
=
False
)
results
=
dict
(
a_1
=
1
,
a_2
=
2
,
b_1
=
3
,
b_2
=
4
)
results
=
pipeline
(
results
)
np
.
testing
.
assert_equal
(
results
[
'a'
],
3
)
np
.
testing
.
assert_equal
(
results
[
'b'
],
3
)
pipeline
=
TransformBroadcaster
(
transforms
=
[
SumTwoValues
()],
mapping
=
dict
(
num_1
=
[
'a_1'
,
'b_1'
],
num_2
=
[
'a_2'
,
'b_2'
]),
remapping
=
dict
(
sum
=
[
'a'
,
...]),
auto_remap
=
False
)
results
=
dict
(
a_1
=
1
,
a_2
=
2
,
b_1
=
3
,
b_2
=
4
)
results
=
pipeline
(
results
)
np
.
testing
.
assert_equal
(
results
[
'a'
],
3
)
assert
'b'
not
in
results
# Test repr
# Test repr
_
=
str
(
pipeline
)
_
=
str
(
pipeline
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment