Commit 59eaefeb authored by liyining's avatar liyining Committed by zhouzaida
Browse files

[Feature] Support partial mapping by manually marking keys as ignored

parent 3b494a13
......@@ -160,7 +160,10 @@ pipeline = [
pipeline = [
...
dict(type='KeyMapper',
mapping={'img': 'gt_img'}, # 将 "gt_img" 字段映射至 "img" 字段
mapping={
'img': 'gt_img', # 将 "gt_img" 字段映射至 "img" 字段
'mask': ..., # 不使用原始数据中的 "mask" 字段。即对于被包装的数据变换,数据中不包含 "mask" 字段
},
auto_remap=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段
transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
......@@ -237,6 +240,29 @@ pipeline = [
]
```
在多目标扩展的 `mapping` 设置中,我们同样可以使用 `...` 来忽略指定的原始字段。如以下例子中,被包裹的 `RandomCrop` 会对字段 `"img"` 中的图像进行裁剪,并且在字段 `"img_shape"` 存在时更新剪裁后的图像大小。如果我们希望同时对两个图像字段 `"lq"` 和 `"gt"` 进行相同的随机裁剪,但只更新一次 `"img_shape"` 字段,可以通过例子中的方式实现:
```python
pipeline = [
dict(type='TransformBroadcaster',
mapping={
'img': ['lq', 'gt'],
'img_shape': ['img_shape', ...],
},
# 在完成变换后,将 "img" 和 "img_shape" 字段重映射回原先的字段
auto_remap=True,
# 是否在对各目标的变换中共享随机变量
# 更多介绍参见后续章节(随机变量共享)
share_random_params=True,
transforms=[
# `RandomCrop` 类中会操作 "img" 和 "img_shape" 字段。若 "img_shape" 空缺,
# 则只操作 "img"
dict(type='RandomCrop'),
])
]
```
2. 应用于一个字段的一组目标
假设我们需要将数据变换应用于 `"images"` 字段,该字段为一个图像组成的 list。
......
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict
from typing import Dict, Optional
class BaseTransform(metaclass=ABCMeta):
def __call__(self, results: Dict) -> Dict:
def __call__(self, results: Dict) -> Optional[Dict]:
return self.transform(results)
@abstractmethod
def transform(self, results: Dict) -> Dict:
def transform(self, results: Dict) -> Optional[Dict]:
"""The transform function. All subclass of BaseTransform should
override this method.
......
......@@ -151,7 +151,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# key2counter stores the usage number of each cache_randomness. This is
# used to check that any cache_randomness is invoked once during processing
# on data sample.
key2counter = defaultdict(int)
key2counter: dict = defaultdict(int)
def _add_invoke_counter(obj, method_name):
method = getattr(obj, method_name)
......@@ -212,7 +212,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Store the original method and init the counter
if hasattr(t, '_methods_with_randomness'):
setattr(t, 'transform', _add_invoke_checker(t, 'transform'))
for name in t._methods_with_randomness:
for name in getattr(t, '_methods_with_randomness'):
setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform):
......@@ -221,20 +221,21 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return
# Remove cache enabled flag
del t._cache_enabled
delattr(t, '_cache_enabled')
if hasattr(t, '_cache'):
del t._cache
delattr(t, '_cache')
# Restore the original method
if hasattr(t, '_methods_with_randomness'):
for name in t._methods_with_randomness:
for name in getattr(t, '_methods_with_randomness'):
key = f'{id(t)}.{name}'
setattr(t, name, key2method[key])
key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform])
def _apply(t: BaseTransform, func: Callable[[BaseTransform], None]):
def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]):
if isinstance(t, BaseTransform):
func(t)
if isinstance(t, Iterable):
......
......@@ -13,8 +13,14 @@ from .utils import cache_random_params, cache_randomness
# Define type of transform or transform config
Transform = Union[Dict, Callable[[Dict], Dict]]
# Indicator for required but missing keys in results
NotInResults = object()
# Indicator of keys marked by KeyMapper._map_input, which means ignoring the
# marked keys in KeyMapper._apply_transform so they will be invisible to
# wrapped transforms.
# There are 2 possible cases:
# 1. The key is required but missing in results
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
# the original value in results should be ignored
IgnoreKey = object()
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation.
......@@ -23,7 +29,7 @@ try:
except ImportError:
from contextlib import contextmanager
@contextmanager
@contextmanager # type: ignore
def nullcontext(resource=None):
try:
yield resource
......@@ -54,7 +60,7 @@ class Compose(BaseTransform):
if not isinstance(transforms, list):
transforms = [transforms]
self.transforms = []
self.transforms: List = []
for transform in transforms:
if isinstance(transform, dict):
transform = TRANSFORMS.build(transform)
......@@ -137,6 +143,7 @@ class KeyMapper(BaseTransform):
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 2: Collect and structure multiple items
>>> pipeline = [
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
......@@ -151,6 +158,22 @@ class KeyMapper(BaseTransform):
>>> img_tar='img2')),
>>> transforms=...)
>>> ]
>>> # Example 3: Manually set ignored keys by "..."
>>> pipeline = [
>>> ...
>>> dict(type='KeyMapper',
>>> mapping={
>>> # map outer key "gt_img" to inner key "img"
>>> 'img': 'gt_img',
>>> # ignore outer key "mask"
>>> 'mask': ...,
>>> },
>>> transforms=[
>>> dict(type='RandomFlip'),
>>> ])
>>> ...
>>> ]
"""
def __init__(self,
......@@ -185,20 +208,25 @@ class KeyMapper(BaseTransform):
"""Allow easy iteration over the transform sequence."""
return iter(self.transforms)
def map_input(self, data: Dict, mapping: Dict) -> Dict[str, Any]:
def _map_input(self, data: Dict,
mapping: Optional[Dict]) -> Dict[str, Any]:
"""KeyMapper inputs for the wrapped transforms by gathering and
renaming data items according to the mapping.
Args:
data (dict): The original input data
mapping (dict): The input key mapping. See the document of
``mmcv.transforms.wrappers.KeyMapper`` for details.
mapping (dict, optional): The input key mapping. See the document
of ``mmcv.transforms.wrappers.KeyMapper`` for details. If
set to None, return the input data directly.
Returns:
dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline.
"""
if mapping is None:
return data.copy()
def _map(data, m):
if isinstance(m, dict):
# m is a dict {inner_key:outer_key, ...}
......@@ -210,17 +238,17 @@ class KeyMapper(BaseTransform):
# transforms.
return m.__class__(_map(data, e) for e in m)
# allow manually mark a key to be ignored by ...
if m is ...:
return IgnoreKey
# m is an outer_key
if self.allow_nonexist_keys:
return data.get(m, NotInResults)
return data.get(m, IgnoreKey)
else:
return data.get(m)
collected = _map(data, mapping)
collected = {
k: v
for k, v in collected.items() if v is not NotInResults
}
# Retain unmapped items
inputs = data.copy()
......@@ -228,19 +256,26 @@ class KeyMapper(BaseTransform):
return inputs
def map_output(self, data: Dict, remapping: Dict) -> Dict[str, Any]:
def _map_output(self, data: Dict,
remapping: Optional[Dict]) -> Dict[str, Any]:
"""KeyMapper outputs from the wrapped transforms by gathering and
renaming data items according to the remapping.
Args:
data (dict): The output of the wrapped pipeline.
remapping (dict): The output key mapping. See the document of
``mmcv.transforms.wrappers.KeyMapper`` for details.
remapping (dict, optional): The output key mapping. See the
document of ``mmcv.transforms.wrappers.KeyMapper`` for
details. If ``remapping is None``, no key mapping will be
applied but only remove the special token ``IgnoreKey``.
Returns:
dict: The output with remapped keys.
"""
# Remove ``IgnoreKey``
if remapping is None:
return {k: v for k, v in data.items() if v is not IgnoreKey}
def _map(data, m):
if isinstance(m, dict):
assert isinstance(data, dict)
......@@ -257,21 +292,44 @@ class KeyMapper(BaseTransform):
results.update(_map(d_i, m_i))
return results
if m is IgnoreKey:
return {}
return {m: data}
# Note that unmapped items are not retained, which is different from
# the behavior in map_input. This is to avoid original data items
# the behavior in _map_input. This is to avoid original data items
# being overwritten by intermediate namesakes
return _map(data, remapping)
def transform(self, results: Dict) -> Dict:
inputs = results
if self.mapping:
inputs = self.map_input(inputs, self.mapping)
def _apply_transforms(self, inputs: Dict) -> Dict:
"""Apply ``self.transforms``.
Note that the special token ``IgnoreKey`` will be invisible to
``self.transforms``, but not removed in this method. It will be
eventually removed in :meth:`KeyMapper._map_output`.
"""
results = inputs.copy()
inputs = {k: v for k, v in inputs.items() if v is not IgnoreKey}
outputs = self.transforms(inputs)
if self.remapping:
outputs = self.map_output(outputs, self.remapping)
if outputs is None:
raise ValueError(
f'Transforms wrapped by {self.__class__.__name__} should '
'not return None.')
results.update(outputs) # type: ignore
return results
def transform(self, results: Dict) -> Dict:
    """Apply input mapping, the wrapped transforms, and output remapping.

    Args:
        results (dict): The original result dict.

    Returns:
        dict: ``results`` updated in place with the (remapped) outputs
        of the wrapped transforms.
    """
    # Gather/rename input items according to ``self.mapping``. Keys mapped
    # to ``...`` (and, when allowed, missing keys) are replaced by the
    # ``IgnoreKey`` sentinel.
    inputs = self._map_input(results, self.mapping)
    # Run the wrapped transforms; ``IgnoreKey``-marked items are hidden
    # from them but kept in the returned dict.
    outputs = self._apply_transforms(inputs)
    # Rename outputs back according to ``self.remapping`` and drop the
    # ``IgnoreKey`` sentinel values.
    outputs = self._map_output(outputs, self.remapping)
    results.update(outputs)
    return results
......@@ -314,7 +372,8 @@ class TransformBroadcaster(KeyMapper):
example.
Examples:
>>> # Example 1:
>>> # Example 1: Broadcast to enumerated keys, each contains a single
>>> # data element
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
......@@ -333,7 +392,8 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 2:
>>> # Example 2: Broadcast to keys that contains data sequences
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
......@@ -351,6 +411,24 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'),
>>> ])
>>> ]
>>> # Example 3: Set ignored keys in broadcasting
>>> pipeline = [
>>> dict(type='TransformBroadcaster',
>>> # Broadcast the wrapped transforms to multiple images
>>> # 'lq' and 'gt', but only update 'img_shape' once
>>> mapping={
>>> 'img': ['lq', 'gt'],
>>> 'img_shape': ['img_shape', ...],
>>> },
>>> auto_remap=True,
>>> share_random_params=True,
>>> transforms=[
>>> # `RandomCrop` will modify the field "img",
>>> # and optionally update "img_shape" if it exists
>>> dict(type='RandomCrop'),
>>> ])
>>> ]
"""
def __init__(self,
......@@ -366,17 +444,23 @@ class TransformBroadcaster(KeyMapper):
self.share_random_params = share_random_params
def scatter_sequence(self, data: Dict) -> List[Dict]:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
"""
# infer split number from input
seq_len = None
seq_len = 0
key_rep = None
if self.mapping:
keys = self.mapping.keys()
else:
keys = data.keys()
for key in keys:
assert isinstance(data[key], Sequence)
if seq_len is not None:
if seq_len:
if len(data[key]) != seq_len:
raise ValueError('Got inconsistent sequence length: '
f'{seq_len} ({key_rep}) vs. '
......@@ -385,6 +469,8 @@ class TransformBroadcaster(KeyMapper):
seq_len = len(data[key])
key_rep = key
assert seq_len > 0, 'Fail to get the number of broadcasting targets'
scatters = []
for i in range(seq_len):
scatter = data.copy()
......@@ -394,13 +480,13 @@ class TransformBroadcaster(KeyMapper):
return scatters
def transform(self, results: Dict):
"""Broadcast wrapped transforms to multiple targets."""
# Apply input remapping
inputs = results
if self.mapping:
inputs = self.map_input(inputs, self.mapping)
inputs = self._map_input(results, self.mapping)
# Scatter sequential inputs into a list
inputs = self.scatter_sequence(inputs)
input_scatters = self.scatter_sequence(inputs)
# Control random parameter sharing with a context manager
if self.share_random_params:
......@@ -410,20 +496,21 @@ class TransformBroadcaster(KeyMapper):
# by all data items.
ctx = cache_random_params
else:
ctx = nullcontext
ctx = nullcontext # type: ignore
with ctx(self.transforms):
outputs = [self.transforms(_input) for _input in inputs]
output_scatters = [
self._apply_transforms(_input) for _input in input_scatters
]
# Collate output scatters (list of dict to dict of list)
outputs = {
key: [_output[key] for _output in outputs]
for key in outputs[0]
key: [_output[key] for _output in output_scatters]
for key in output_scatters[0]
}
# Apply output remapping
if self.remapping:
outputs = self.map_output(outputs, self.remapping)
# Apply remapping
outputs = self._map_output(outputs, self.remapping)
results.update(outputs)
return results
......@@ -473,11 +560,13 @@ class RandomChoice(BaseTransform):
return iter(self.transforms)
@cache_randomness
def random_pipeline_index(self):
def random_pipeline_index(self) -> int:
"""Return a random transform index."""
indices = np.arange(len(self.transforms))
return np.random.choice(indices, p=self.prob)
def transform(self, results):
def transform(self, results: Dict) -> Optional[Dict]:
"""Randomly choose a transform to apply."""
idx = self.random_pipeline_index()
return self.transforms[idx](results)
......@@ -512,10 +601,14 @@ class RandomApply(BaseTransform):
return iter(self.transforms)
@cache_randomness
def random_apply(self):
def random_apply(self) -> bool:
"""Return a random bool value indicating whether apply the transform.
"""
return np.random.rand() < self.prob
def transform(self, results: Dict) -> Dict:
def transform(self, results: Dict) -> Optional[Dict]:
"""Randomly apply the transform."""
if self.random_apply():
results = self.transforms(results)
return self.transforms(results)
else:
return results
......@@ -67,6 +67,10 @@ class SumTwoValues(BaseTransform):
def transform(self, results):
    """Store the sum of ``num_1`` and ``num_2`` in ``results['sum']``.

    A missing key is simply skipped: if only one of the two keys is
    present, its value is used as-is; if both are absent, the sum is
    recorded as NaN.
    """
    operands = [results[key] for key in ('num_1', 'num_2') if key in results]
    if not operands:
        results['sum'] = np.nan
        return results
    total = operands[0]
    for extra in operands[1:]:
        total = total + extra
    results['sum'] = total
    return results
......@@ -262,7 +266,7 @@ def test_key_mapper():
np.testing.assert_equal(results['sum'], 3)
results = pipeline(dict(a=1))
assert np.isnan(results['sum'])
np.testing.assert_equal(results['sum'], 1)
# Case 9: use wrapper as a transform
transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
......@@ -270,6 +274,17 @@ def test_key_mapper():
# note that the original key 'a' will not be removed
assert results == dict(a=1, b=1)
# Case 10: manually set keys ignored
pipeline = KeyMapper(
transforms=[SumTwoValues()],
mapping=dict(num_1='a', num_2=...), # num_2 (b) will be ignored
auto_remap=False,
# allow_nonexist_keys will not affect manually ignored keys
allow_nonexist_keys=False)
results = pipeline(dict(a=1, b=2))
np.testing.assert_equal(results['sum'], 1)
# Test basic functions
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
......@@ -353,6 +368,31 @@ def test_transform_broadcaster():
np.testing.assert_equal(results['values'][0], results['values'][1])
# Case 6: partial broadcasting
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', ...]),
remapping=dict(sum=['a', 'b']),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
np.testing.assert_equal(results['a'], 3)
np.testing.assert_equal(results['b'], 3)
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
remapping=dict(sum=['a', ...]),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
np.testing.assert_equal(results['a'], 3)
assert 'b' not in results
# Test repr
_ = str(pipeline)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment