Commit e2ca0733 authored by Yining Li's avatar Yining Li Committed by zhouzaida
Browse files

Refactor base transform (#1830)

* rename cacheable_method to cache_randomness

* refactor transform wrappers and update docs

* fix all_nonexist_keys

* fix lint

* rename transform wrappers
parent 0a5b4125
......@@ -150,18 +150,18 @@ pipeline = [
变换包装是一种特殊的数据变换类,他们本身并不操作数据字典中的图像、标签等信息,而是对其中定义的数据变换的行为进行增强。
### 字段映射(Remap
### 字段映射(KeyMapper
字段映射包装(`Remap`)用于对数据字典中的字段进行映射。例如,一般的图像处理变换都从数据字典中的 `"img"` 字段获得值。但有些时候,我们希望这些变换处理数据字典中其他字段中的图像,比如 `"gt_img"` 字段。
字段映射包装(`KeyMapper`)用于对数据字典中的字段进行映射。例如,一般的图像处理变换都从数据字典中的 `"img"` 字段获得值。但有些时候,我们希望这些变换处理数据字典中其他字段中的图像,比如 `"gt_img"` 字段。
如果配合注册器和配置文件使用的话,在配置文件中数据集的 `pipeline` 中如下例使用字段映射包装:
```python
pipeline = [
...
dict(type='Remap',
input_mapping={'img': 'gt_img'}, # 将 "gt_img" 字段映射至 "img" 字段
inplace=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段
dict(type='KeyMapper',
mapping={'img': 'gt_img'}, # 将 "gt_img" 字段映射至 "img" 字段
auto_remap=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段
transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
dict(type='RandomFlip'),
......@@ -182,7 +182,7 @@ pipeline = [
pipeline = [
...
dict(type='RandomChoice',
pipelines=[
transforms=[
[
dict(type='Posterize', bits=4),
dict(type='Rotate', angle=30.)
......@@ -192,17 +192,17 @@ pipeline = [
dict(type='Rotate', angle=30)
], # 第二种随机变换组合
],
pipeline_probs=[0.4, 0.6] # 两种随机变换组合各自的选用概率
prob=[0.4, 0.6] # 两种随机变换组合各自的选用概率
)
...
]
```
### 多目标扩展(ApplyToMultiple
### 多目标扩展(TransformBroadcaster
通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用 `Remap` 来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(`ApplyToMultiple`)。
通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用 `KeyMapper` 来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(`TransformBroadcaster`)。
多目标扩展包装(`ApplyToMultiple`)有两个用法,一是将数据变换作用于指定的多个字段,二是将数据变换作用于某个字段下的一组目标中。
多目标扩展包装(`TransformBroadcaster`)有两个用法,一是将数据变换作用于指定的多个字段,二是将数据变换作用于某个字段下的一组目标中。
1. 应用于多个字段
......@@ -210,11 +210,11 @@ pipeline = [
```python
pipeline = [
dict(type='ApplyToMultiple',
dict(type='TransformBroadcaster',
# 分别应用于 "lq" 和 "gt" 两个字段,并将二者映射至 "img" 字段
input_mapping={'img': ['lq', 'gt']},
mapping={'img': ['lq', 'gt']},
# 在完成变换后,将 "img" 字段重映射回原先的字段
inplace=True,
auto_remap=True,
# 是否在对各目标的变换中共享随机变量
# 更多介绍参见后续章节(随机变量共享)
share_random_params=True,
......@@ -231,11 +231,11 @@ pipeline = [
```python
pipeline = [
dict(type='ApplyToMultiple',
dict(type='TransformBroadcaster',
# 将 "images" 字段下的每张图片映射至 "img" 字段
input_mapping={'img': 'images'},
mapping={'img': 'images'},
# 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中
inplace=True,
auto_remap=True,
# 是否在对各目标的变换中共享随机变量
share_random_params=True,
transforms=[
......@@ -245,12 +245,12 @@ pipeline = [
]
```
`ApplyToMultiple` 中,我们提供了 `share_random_param` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。
`TransformBroadcaster` 中,我们提供了 `share_random_params` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。
以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转:
```python
from mmcv.transforms.utils import cacheable_method
from mmcv.transforms.utils import cache_randomness
@TRANSFORMS.register_module()
class MyRandomFlip(BaseTransform):
......@@ -259,7 +259,7 @@ class MyRandomFlip(BaseTransform):
self.prob = prob
self.direction = direction
@cacheable_method # 标注该方法的输出为可共享的随机变量
@cache_randomness # 标注该方法的输出为可共享的随机变量
def do_flip(self):
flip = True if random.random() > self.prob else False
return flip
......@@ -271,4 +271,4 @@ class MyRandomFlip(BaseTransform):
return results
```
通过 `cacheable_method` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `ApplyToMultiple` 对多个目标的变换中,这一变量的值都会保持一致。
通过 `cache_randomness` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
......@@ -4,24 +4,24 @@ from .loading import LoadAnnotation, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomFlip, RandomGrayscale, RandomMultiscaleResize,
RandomResize, Resize)
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster
try:
import torch # noqa: F401
except ImportError:
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
else:
from .formatting import ImageToTensor, ToTensor, to_tensor
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip',
'RandomMultiscaleResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize'
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
......@@ -11,17 +11,17 @@ from typing import Callable, Union
from .base import BaseTransform
class cacheable_method:
"""Decorator that marks a method of a transform class as a cacheable
method.
class cache_randomness:
"""Decorator that marks the method with random return value(s) in a
transform class.
This decorator is usually used together with the context-manager
:func`:cache_random_params`. In this context, a cacheable method will
:func:`cache_random_params`. In this context, a decorated method will
cache its return value(s) the first time it is invoked, and always
return the cached values when being invoked again.
.. note::
Only a instance method can be decorated as a cacheable_method.
Only an instance method can be decorated by ``cache_randomness``.
"""
def __init__(self, func):
......@@ -29,12 +29,12 @@ class cacheable_method:
# Check `func` is to be bound as an instance method
if not inspect.isfunction(func):
raise TypeError('Unsupport callable to decorate with'
'@cacheable_method.')
'@cache_randomness.')
func_args = inspect.getfullargspec(func).args
if len(func_args) == 0 or func_args[0] != 'self':
raise TypeError(
'@cacheable_method should only be used to decorate '
'instance methods (the first argument is `self`).')
'@cache_randomness should only be used to decorate '
'instance methods (the first argument is ``self``).')
functools.update_wrapper(self, func)
self.func = func
......@@ -42,24 +42,24 @@ class cacheable_method:
def __set_name__(self, owner, name):
# Maintain a record of decorated methods in the class
if not hasattr(owner, '_cacheable_methods'):
setattr(owner, '_cacheable_methods', [])
owner._cacheable_methods.append(self.__name__)
if not hasattr(owner, '_methods_with_randomness'):
setattr(owner, '_methods_with_randomness', [])
owner._methods_with_randomness.append(self.__name__)
def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated
# by cacheable_method
# by cache_randomness
instance = self.instance_ref()
name = self.__name__
# Check the flag `self._cache_enabled`, which should be
# set by the contextmanagers like `cache_random_parameters`
# Check the flag ``self._cache_enabled``, which should be
# set by the contextmanagers like ``cache_random_params``
cache_enabled = getattr(instance, '_cache_enabled', False)
if cache_enabled:
# Initialize the cache of the transform instances. The flag
# `cache_enabled` is set by contextmanagers like
# `cache_random_params`.
# ``cache_enabled`` is set by contextmanagers like
# ``cache_random_params``.
if not hasattr(instance, '_cache'):
setattr(instance, '_cache', {})
......@@ -81,13 +81,13 @@ class cacheable_method:
@contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of cacheable methods in
transforms.
"""Context-manager that enables the cache of return values of methods
decorated by ``cache_randomness`` in transforms.
In this mode, cacheable methods will cache their return values on the
In this mode, decorated methods will cache their return values on the
first invocation, and always return the cached value afterward. This allows
applying random transforms in a deterministic way. For example, applying the same
transforms on multiple examples. See `cacheable_method` for more
transforms on multiple examples. See ``cache_randomness`` for more
information.
Args:
......@@ -99,12 +99,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# ones. These methods will be restituted when exiting the context.
key2method = dict()
# key2counter stores the usage number of each cacheable_method. This is
# used to check that any cacheable_method is invoked once during processing
# key2counter stores the usage number of each cache_randomness. This is
# used to check that any cache_randomness is invoked once during processing
# on data sample.
key2counter = defaultdict(int)
def _add_counter(obj, method_name):
def _add_invoke_counter(obj, method_name):
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method
......@@ -116,15 +116,43 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return wrapped
def _add_invoke_checker(obj, method_name):
# check that the method in _methods_with_randomness has been
# invoked at most once
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method
@functools.wraps(method)
def wrapped(*args, **kwargs):
# clear counter
for name in obj._methods_with_randomness:
key = f'{id(obj)}.{name}'
key2counter[key] = 0
output = method(*args, **kwargs)
for name in obj._methods_with_randomness:
key = f'{id(obj)}.{name}'
if key2counter[key] > 1:
raise RuntimeError(
'The method decorated by ``cache_randomness`` should '
'be invoked at most once during processing one data '
f'sample. The method {name} of {obj} has been invoked'
f' {key2counter[key]} times.')
return output
return wrapped
def _start_cache(t: BaseTransform):
# Set cache enabled flag
setattr(t, '_cache_enabled', True)
# Store the original method and init the counter
if hasattr(t, '_cacheable_methods'):
setattr(t, 'transform', _add_counter(t, 'transform'))
for name in t._cacheable_methods:
setattr(t, name, _add_counter(t, name))
if hasattr(t, '_methods_with_randomness'):
setattr(t, 'transform', _add_invoke_checker(t, 'transform'))
for name in t._methods_with_randomness:
setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform):
# Remove cache enabled flag
......@@ -133,17 +161,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
del t._cache
# Restore the original method
if hasattr(t, '_cacheable_methods'):
key_transform = f'{id(t)}.transform'
for name in t._cacheable_methods:
if hasattr(t, '_methods_with_randomness'):
for name in t._methods_with_randomness:
key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]:
raise RuntimeError(
'The cacheable method should be called once and only '
f'once during processing one data sample. {t} got '
f'unmatched number of {key2counter[key]} ({name}) vs '
f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key])
key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable],
......@@ -151,7 +174,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice
if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'):
if hasattr(t, '_methods_with_randomness'):
func(t)
if isinstance(t, Iterable):
for _t in t:
......
......@@ -8,7 +8,10 @@ import numpy as np
import mmcv
from .base import BaseTransform
from .builder import TRANSFORMS
from .utils import cache_random_params, cacheable_method
from .utils import cache_random_params, cache_randomness
# Define type of transform or transform config
Transform = Union[Dict, Callable[[Dict], Dict]]
# Indicator for required but missing keys in results
NotInResults = object()
......@@ -46,8 +49,9 @@ class Compose(BaseTransform):
>>> ]
"""
def __init__(self, transforms: List[Union[Dict, Callable[[Dict], Dict]]]):
assert isinstance(transforms, Sequence)
def __init__(self, transforms: Union[Transform, List[Transform]]):
if not isinstance(transforms, list):
transforms = [transforms]
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
......@@ -88,42 +92,42 @@ class Compose(BaseTransform):
@TRANSFORMS.register_module()
class Remap(BaseTransform):
"""A transform wrapper to remap and reorganize the input/output of the
class KeyMapper(BaseTransform):
"""A transform wrapper to map and reorganize the input/output of the
wrapped transforms (or sub-pipeline).
Args:
transforms (list[dict | callable]): Sequence of transform object or
config dict to be wrapped.
input_mapping (dict): A dict that defines the input key mapping.
transforms (list[dict | callable], optional): Sequence of transform
object or config dict to be wrapped.
mapping (dict): A dict that defines the input key mapping.
The keys corresponds to the inner key (i.e., kwargs of the
`transform` method), and should be string type. The values
``transform`` method), and should be string type. The values
corresponds to the outer keys (i.e., the keys of the
data/results), and should have a type of string, list or dict.
None means not applying input mapping. Default: None.
output_mapping (dict): A dict that defines the output key mapping.
remapping (dict): A dict that defines the output key mapping.
The keys and values have the same meanings and rules as in the
`input_mapping`. Default: None.
inplace (bool): If True, an inverse of the input_mapping will be used
as the output_mapping. Note that if inplace is set True,
output_mapping should be None and strict should be True.
``mapping``. Default: None.
auto_remap (bool, optional): If True, an inverse of the mapping will
be used as the remapping. If auto_remap is not given, it will be
automatically set True if 'remapping' is not given, and vice
versa. Default: None.
allow_nonexist_keys (bool): If False, the outer keys in the mapping
must exist in the input data, or an exception will be raised.
Default: False.
strict (bool): If True, the outer keys in the input_mapping must exist
in the input data, or an exception will be raised. If False,
the missing keys will be assigned a special value `NotInResults`
during input remapping. Default: True.
Examples:
>>> # Example 1: Remap 'gt_img' to 'img'
>>> # Example 1: Map 'gt_img' to 'img'
>>> pipeline = [
>>> # Use Remap to convert outer (original) field name 'gt_img'
>>> # to inner (used by inner transforms) filed name 'img'
>>> dict(type='Remap',
>>> input_mapping=dict(img='gt_img'),
>>> # inplace=True means output key mapping is the revert of
>>> # Use KeyMapper to convert outer (original) field name
>>> # 'gt_img' to inner (used by inner transforms) field name
>>> # 'img'
>>> dict(type='KeyMapper',
>>> mapping=dict(img='gt_img'),
>>> # auto_remap=True means output key mapping is the revert of
>>> # the input key mapping, e.g. inner 'img' will be mapped
>>> # back to outer 'gt_img'
>>> inplace=True,
>>> auto_remap=True,
>>> transforms=[
>>> # In all transforms' implementation just use 'img'
>>> # as a standard field name
......@@ -136,10 +140,10 @@ class Remap(BaseTransform):
>>> # The inner field 'imgs' will be a dict with keys 'img_src'
>>> # and 'img_tar', whose values are outer fields 'img1' and
>>> # 'img2' respectively.
>>> dict(type='Remap',
>>> dict(type='KeyMapper',
>>> dict(
>>> type='Remap',
>>> input_mapping=dict(
>>> type='KeyMapper',
>>> mapping=dict(
>>> imgs=dict(
>>> img_src='img1',
>>> img_tar='img2')),
......@@ -148,66 +152,67 @@ class Remap(BaseTransform):
"""
def __init__(self,
transforms: List[Union[Dict, Callable[[Dict], Dict]]],
input_mapping: Optional[Dict] = None,
output_mapping: Optional[Dict] = None,
inplace: bool = False,
strict: bool = True):
self.inplace = inplace
self.strict = strict
self.input_mapping = input_mapping
if self.inplace:
if not self.strict:
raise ValueError('Remap: `strict` must be set True if'
'`inplace` is set True.')
if output_mapping is not None:
raise ValueError('Remap: `output_mapping` must be None if'
'`inplace` is set True.')
self.output_mapping = input_mapping
transforms: Union[Transform, List[Transform]] = None,
mapping: Optional[Dict] = None,
remapping: Optional[Dict] = None,
auto_remap: Optional[bool] = None,
allow_nonexist_keys: bool = False):
self.allow_nonexist_keys = allow_nonexist_keys
self.mapping = mapping
if auto_remap is None:
auto_remap = remapping is None
self.auto_remap = auto_remap
if self.auto_remap:
if remapping is not None:
raise ValueError('KeyMapper: ``remapping`` must be None if'
'`auto_remap` is set True.')
self.remapping = mapping
else:
self.output_mapping = output_mapping
self.remapping = remapping
if transforms is None:
transforms = []
self.transforms = Compose(transforms)
def __iter__(self):
"""Allow easy iteration over the transform sequence."""
return iter(self.transforms)
def remap_input(self, data: Dict, input_mapping: Dict) -> Dict[str, Any]:
"""Remap inputs for the wrapped transforms by gathering and renaming
data items according to the input_mapping.
def map_input(self, data: Dict, mapping: Dict) -> Dict[str, Any]:
"""KeyMapper inputs for the wrapped transforms by gathering and
renaming data items according to the mapping.
Args:
data (dict): The original input data
input_mapping (dict): The input key mapping. See the document of
`mmcv.transforms.wrappers.Remap` for details.
mapping (dict): The input key mapping. See the document of
``mmcv.transforms.wrappers.KeyMapper`` for details.
Returns:
dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline.
"""
def _remap(data, m):
def _map(data, m):
if isinstance(m, dict):
# m is a dict {inner_key:outer_key, ...}
return {k_in: _remap(data, k_out) for k_in, k_out in m.items()}
return {k_in: _map(data, k_out) for k_in, k_out in m.items()}
if isinstance(m, (tuple, list)):
# m is a list or tuple [outer_key1, outer_key2, ...]
# This is the case when we collect items from the original
# data to form a list or tuple to feed to the wrapped
# transforms.
return m.__class__(_remap(data, e) for e in m)
return m.__class__(_map(data, e) for e in m)
# m is an outer_key
if self.strict:
return data.get(m)
else:
if self.allow_nonexist_keys:
return data.get(m, NotInResults)
else:
return data.get(m)
collected = _remap(data, input_mapping)
collected = _map(data, mapping)
collected = {
k: v
for k, v in collected.items() if v is not NotInResults
......@@ -219,79 +224,78 @@ class Remap(BaseTransform):
return inputs
def remap_output(self, data: Dict, output_mapping: Dict) -> Dict[str, Any]:
"""Remap outputs from the wrapped transforms by gathering and renaming
data items according to the output_mapping.
def map_output(self, data: Dict, remapping: Dict) -> Dict[str, Any]:
"""KeyMapper outputs from the wrapped transforms by gathering and
renaming data items according to the remapping.
Args:
data (dict): The output of the wrapped pipeline.
output_mapping (dict): The output key mapping. See the document of
`mmcv.transforms.wrappers.Remap` for details.
remapping (dict): The output key mapping. See the document of
``mmcv.transforms.wrappers.KeyMapper`` for details.
Returns:
dict: The output with remapped keys.
"""
def _remap(data, m):
def _map(data, m):
if isinstance(m, dict):
assert isinstance(data, dict)
results = {}
for k_in, k_out in m.items():
assert k_in in data
results.update(_remap(data[k_in], k_out))
results.update(_map(data[k_in], k_out))
return results
if isinstance(m, (list, tuple)):
assert isinstance(data, (list, tuple))
assert len(data) == len(m)
results = {}
for m_i, d_i in zip(m, data):
results.update(_remap(d_i, m_i))
results.update(_map(d_i, m_i))
return results
return {m: data}
# Note that unmapped items are not retained, which is different from
# the behavior in remap_input. This is to avoid original data items
# the behavior in map_input. This is to avoid original data items
# being overwritten by intermediate namesakes
return _remap(data, output_mapping)
return _map(data, remapping)
def transform(self, results: Dict) -> Dict:
inputs = self.remap_input(results, self.input_mapping)
inputs = self.map_input(results, self.mapping)
outputs = self.transforms(inputs)
if self.output_mapping:
outputs = self.remap_output(outputs, self.output_mapping)
if self.remapping:
outputs = self.map_output(outputs, self.remapping)
results.update(outputs)
return results
@TRANSFORMS.register_module()
class ApplyToMultiple(Remap):
class TransformBroadcaster(KeyMapper):
"""A transform wrapper to apply the wrapped transforms to multiple data
items. For example, apply Resize to multiple images.
Args:
transforms (list[dict | callable]): Sequence of transform object or
config dict to be wrapped.
input_mapping (dict): A dict that defines the input key mapping.
mapping (dict): A dict that defines the input key mapping.
Note that to apply the transforms to multiple data items, the
outer keys of the target items should be remapped as a list with
the standard inner key (The key required by the wrapped transform).
See the following example and the document of
`mmcv.transforms.wrappers.Remap` for details.
output_mapping (dict): A dict that defines the output key mapping.
``mmcv.transforms.wrappers.KeyMapper`` for details.
remapping (dict): A dict that defines the output key mapping.
The keys and values have the same meanings and rules as in the
`input_mapping`. Default: None.
inplace (bool): If True, an inverse of the input_mapping will be used
as the output_mapping. Note that if inplace is set True,
output_mapping should be None and strict should be True.
``mapping``. Default: None.
auto_remap (bool, optional): If True, an inverse of the mapping will
be used as the remapping. If auto_remap is not given, it will be
automatically set True if 'remapping' is not given, and vice
versa. Default: None.
allow_nonexist_keys (bool): If False, the outer keys in the mapping
must exist in the input data, or an exception will be raised.
Default: False.
strict (bool): If True, the outer keys in the input_mapping must exist
in the input data, or an exception will be raised. If False,
the missing keys will be assigned a special value `NotInResults`
during input remapping. Default: True.
share_random_params (bool): If True, the random transform
(e.g., RandomFlip) will be conducted in a deterministic way and
have the same behavior on all data items. For example, to randomly
......@@ -300,7 +304,7 @@ class ApplyToMultiple(Remap):
.. note::
To apply the transforms to each elements of a list or tuple, instead
of separating data items, you can remap the outer key of the target
of separating data items, you can map the outer key of the target
sequence to the standard inner key. See example 2.
example.
......@@ -309,13 +313,13 @@ class ApplyToMultiple(Remap):
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> # ApplyToMultiple maps multiple outer fields to standard the
>>> # inner field and process them with wrapped transforms
>>> # TransformBroadcaster maps multiple outer fields to the
>>> # standard inner field and processes them with wrapped transforms
>>> # respectively
>>> dict(type='ApplyToMultiple',
>>> dict(type='TransformBroadcaster',
>>> # case 1: from multiple outer fields
>>> input_mapping=dict(img=['lq', 'gt']),
>>> inplace=True,
>>> mapping=dict(img=['lq', 'gt']),
>>> auto_remap=True,
>>> # share_random_params=True means using identical random
>>> # parameters in every processing
>>> share_random_params=True,
......@@ -328,14 +332,14 @@ class ApplyToMultiple(Remap):
>>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> # ApplyToMultiple maps multiple outer fields to standard the
>>> # inner field and process them with wrapped transforms
>>> # TransformBroadcaster maps multiple outer fields to the
>>> # standard inner field and processes them with wrapped transforms
>>> # respectively
>>> dict(type='ApplyToMultiple',
>>> dict(type='TransformBroadcaster',
>>> # case 2: from one outer field that contains multiple
>>> # data elements (e.g. a list)
>>> # input_mapping=dict(img='images'),
>>> inplace=True,
>>> # mapping=dict(img='images'),
>>> auto_remap=True,
>>> share_random_params=True,
>>> transforms=[
>>> dict(type='Crop', crop_size=(384, 384)),
......@@ -346,13 +350,13 @@ class ApplyToMultiple(Remap):
def __init__(self,
transforms: List[Union[Dict, Callable[[Dict], Dict]]],
input_mapping: Optional[Dict] = None,
output_mapping: Optional[Dict] = None,
inplace: bool = False,
strict: bool = True,
mapping: Optional[Dict] = None,
remapping: Optional[Dict] = None,
auto_remap: Optional[bool] = None,
allow_nonexist_keys: bool = False,
share_random_params: bool = False):
super().__init__(transforms, input_mapping, output_mapping, inplace,
strict)
super().__init__(transforms, mapping, remapping, auto_remap,
allow_nonexist_keys)
self.share_random_params = share_random_params
......@@ -360,7 +364,7 @@ class ApplyToMultiple(Remap):
# infer split number from input
seq_len = None
key_rep = None
for key in self.input_mapping:
for key in self.mapping:
assert isinstance(data[key], Sequence)
if seq_len is not None:
......@@ -375,14 +379,14 @@ class ApplyToMultiple(Remap):
scatters = []
for i in range(seq_len):
scatter = data.copy()
for key in self.input_mapping:
for key in self.mapping:
scatter[key] = data[key][i]
scatters.append(scatter)
return scatters
def transform(self, results: Dict):
# Apply input remapping
inputs = self.remap_input(results, self.input_mapping)
inputs = self.map_input(results, self.mapping)
# Scatter sequential inputs into a list
inputs = self.scatter_sequence(inputs)
......@@ -407,8 +411,8 @@ class ApplyToMultiple(Remap):
}
# Apply output remapping
if self.output_mapping:
outputs = self.remap_output(outputs, self.output_mapping)
if self.remapping:
outputs = self.map_output(outputs, self.remapping)
results.update(outputs)
return results
......@@ -419,9 +423,9 @@ class RandomChoice(BaseTransform):
"""Process data with a randomly chosen pipeline from given candidates.
Args:
pipelines (list[list]): A list of pipeline candidates, each is a
transforms (list[list]): A list of pipeline candidates, each is a
sequence of transforms.
pipeline_probs (list[float], optional): The probabilities associated
prob (list[float], optional): The probabilities associated
with each pipeline. The length should be equal to the pipeline
number and the sum should be 1. If not given, a uniform
distribution will be assumed.
......@@ -430,7 +434,7 @@ class RandomChoice(BaseTransform):
>>> # config
>>> pipeline = [
>>> dict(type='RandomChoice',
>>> pipelines=[
>>> transforms=[
>>> [dict(type='RandomHorizontalFlip')], # subpipeline 1
>>> [dict(type='RandomRotate')], # subpipeline 2
>>> ]
......@@ -439,27 +443,27 @@ class RandomChoice(BaseTransform):
"""
def __init__(self,
pipelines: List[List[Union[Dict, Callable[[Dict], Dict]]]],
pipeline_probs: Optional[List[float]] = None):
transforms: List[Union[Transform, List[Transform]]],
prob: Optional[List[float]] = None):
if pipeline_probs is not None:
assert mmcv.is_seq_of(pipeline_probs, float)
assert len(pipelines) == len(pipeline_probs), \
'`pipelines` and `pipeline_probs` must have same lengths. ' \
f'Got {len(pipelines)} vs {len(pipeline_probs)}.'
assert sum(pipeline_probs) == 1
if prob is not None:
assert mmcv.is_seq_of(prob, float)
assert len(transforms) == len(prob), \
'``transforms`` and ``prob`` must have same lengths. ' \
f'Got {len(transforms)} vs {len(prob)}.'
assert sum(prob) == 1
self.pipeline_probs = pipeline_probs
self.pipelines = [Compose(transforms) for transforms in pipelines]
self.prob = prob
self.transforms = [Compose(transforms) for transforms in transforms]
def __iter__(self):
return iter(self.pipelines)
return iter(self.transforms)
@cacheable_method
@cache_randomness
def random_pipeline_index(self):
indices = np.arange(len(self.pipelines))
return np.random.choice(indices, p=self.pipeline_probs)
indices = np.arange(len(self.transforms))
return np.random.choice(indices, p=self.prob)
def transform(self, results):
idx = self.random_pipeline_index()
return self.pipelines[idx](results)
return self.transforms[idx](results)
......@@ -6,9 +6,9 @@ import pytest
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cacheable_method
from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice,
Remap)
from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomChoice,
TransformBroadcaster)
@TRANSFORMS.register_module()
......@@ -45,15 +45,18 @@ class AddToValue(BaseTransform):
class RandomAddToValue(AddToValue):
"""Dummy transform to add a random addend to results['value']"""
def __init__(self) -> None:
def __init__(self, repeat=1) -> None:
super().__init__(addend=None)
self.repeat = repeat
@cacheable_method
@cache_randomness
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
return self.add(results, addend=self.get_random_addend())
for _ in range(self.repeat):
results = self.add(results, addend=self.get_random_addend())
return results
@TRANSFORMS.register_module()
......@@ -100,8 +103,8 @@ def test_cache_random_parameters():
transform = RandomAddToValue()
# Case 1: cache random parameters
assert hasattr(RandomAddToValue, '_cacheable_methods')
assert 'get_random_addend' in RandomAddToValue._cacheable_methods
assert hasattr(RandomAddToValue, '_methods_with_randomness')
assert 'get_random_addend' in RandomAddToValue._methods_with_randomness
with cache_random_params(transform):
results_1 = transform(dict(value=0))
......@@ -114,12 +117,18 @@ def test_cache_random_parameters():
with pytest.raises(AssertionError):
np.testing.assert_equal(results_1['value'], results_2['value'])
# Case 3: invalid use of cacheable methods
# Case 3: allow to invoke random method 0 times
transform = RandomAddToValue(repeat=0)
with cache_random_params(transform):
_ = transform(dict(value=0))
# Case 4: NOT allow to invoke random method >1 times
transform = RandomAddToValue(repeat=2)
with pytest.raises(RuntimeError):
with cache_random_params(transform):
_ = transform.get_random_addend()
_ = transform(dict(value=0))
# Case 4: apply on nested transforms
# Case 5: apply on nested transforms
transform = Compose([RandomAddToValue()])
with cache_random_params(transform):
results_1 = transform(dict(value=0))
......@@ -127,13 +136,13 @@ def test_cache_random_parameters():
np.testing.assert_equal(results_1['value'], results_2['value'])
def test_remap():
def test_apply_to_mapped():
# Case 1: simple remap
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
mapping=dict(value='v_in'),
remapping=dict(value='v_out'))
results = dict(value=0, v_in=1)
results = pipeline(results)
......@@ -143,10 +152,10 @@ def test_remap():
np.testing.assert_equal(results['v_out'], 2)
# Case 2: collecting list
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
output_mapping=dict(value=['v_out_1', 'v_out_2']))
mapping=dict(value=['v_in_1', 'v_in_2']),
remapping=dict(value=['v_out_1', 'v_out_2']))
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
......@@ -159,10 +168,10 @@ def test_remap():
np.testing.assert_equal(results['v_out_2'], 4)
# Case 3: collecting dict
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
remapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
......@@ -174,11 +183,11 @@ def test_remap():
np.testing.assert_equal(results['v_out_1'], 3)
np.testing.assert_equal(results['v_out_2'], 4)
# Case 4: collecting list with inplace mode
pipeline = Remap(
# Case 4: collecting list with auto_remap mode
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True)
mapping=dict(value=['v_in_1', 'v_in_2']),
auto_remap=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
......@@ -188,11 +197,11 @@ def test_remap():
np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4)
# Case 5: collecting dict with inplace mode
pipeline = Remap(
# Case 5: collecting dict with auto_remap mode
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
inplace=True)
mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
auto_remap=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
......@@ -202,11 +211,11 @@ def test_remap():
np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4)
# Case 6: nested collection with inplace mode
pipeline = Remap(
# Case 6: nested collection with auto_remap mode
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
inplace=True)
mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
auto_remap=True)
results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
with pytest.warns(UserWarning, match='value is a list'):
......@@ -218,27 +227,20 @@ def test_remap():
np.testing.assert_equal(results['v22'], 5)
np.testing.assert_equal(results['v3'], 6)
# Case 7: `strict` must be True if `inplace` is set True
    # Case 7: `remapping` must be None if `auto_remap` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True,
strict=False)
# Case 8: output_map must be None if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'),
inplace=True)
mapping=dict(value='v_in'),
remapping=dict(value='v_out'),
auto_remap=True)
# Case 9: non-strict input mapping
pipeline = Remap(
    # Case 8: allow_nonexist_keys
pipeline = KeyMapper(
transforms=[SumTwoValues()],
input_mapping=dict(num_1='a', num_2='b'),
strict=False)
mapping=dict(num_1='a', num_2='b'),
auto_remap=False,
allow_nonexist_keys=True)
results = pipeline(dict(a=1, b=2))
np.testing.assert_equal(results['sum'], 3)
......@@ -246,11 +248,17 @@ def test_remap():
results = pipeline(dict(a=1))
assert np.isnan(results['sum'])
# Case 9: use wrapper as a transform
transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
results = transform(dict(a=1))
# note that the original key 'a' will not be removed
assert results == dict(a=1, b=1)
# Test basic functions
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
mapping=dict(value='v_in'),
remapping=dict(value='v_out'))
# __iter__
for _ in pipeline:
......@@ -263,10 +271,10 @@ def test_remap():
def test_apply_to_multiple():
# Case 1: apply to list in results
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='values'),
inplace=True)
mapping=dict(value='values'),
auto_remap=True)
results = dict(values=[1, 2])
results = pipeline(results)
......@@ -274,10 +282,10 @@ def test_apply_to_multiple():
np.testing.assert_equal(results['values'], [2, 3])
# Case 2: apply to multiple keys
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value=['v_1', 'v_2']),
inplace=True)
mapping=dict(value=['v_1', 'v_2']),
auto_remap=True)
results = dict(v_1=1, v_2=2)
results = pipeline(results)
......@@ -286,10 +294,11 @@ def test_apply_to_multiple():
np.testing.assert_equal(results['v_2'], 3)
# Case 3: apply to multiple groups of keys
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
input_mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
output_mapping=dict(sum=['a', 'b']))
mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
remapping=dict(sum=['a', 'b']),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
......@@ -299,18 +308,19 @@ def test_apply_to_multiple():
# Case 4: inconsistent sequence length
with pytest.raises(ValueError):
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
input_mapping=dict(num_1='list_1', num_2='list_2'))
mapping=dict(num_1='list_1', num_2='list_2'),
auto_remap=False)
results = dict(list_1=[1, 2], list_2=[1, 2, 3])
_ = pipeline(results)
# Case 5: share random parameter
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[RandomAddToValue()],
input_mapping=dict(value='values'),
inplace=True,
mapping=dict(value='values'),
auto_remap=True,
share_random_params=True)
results = dict(values=[0, 0])
......@@ -326,27 +336,27 @@ def test_randomchoice():
# Case 1: given probability
pipeline = RandomChoice(
pipelines=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
pipeline_probs=[1.0, 0.0])
transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
prob=[1.0, 0.0])
results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0)
# Case 2: default probability
pipeline = RandomChoice(pipelines=[[AddToValue(
pipeline = RandomChoice(transforms=[[AddToValue(
addend=1.0)], [AddToValue(addend=2.0)]])
_ = pipeline(dict(value=1))
# Case 3: nested RandomChoice in ApplyToMultiple
pipeline = ApplyToMultiple(
# Case 3: nested RandomChoice in TransformBroadcaster
pipeline = TransformBroadcaster(
transforms=[
RandomChoice(
pipelines=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ),
transforms=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ),
],
input_mapping=dict(value='values'),
inplace=True,
mapping=dict(value='values'),
auto_remap=True,
share_random_params=True)
results = dict(values=[0 for _ in range(10)])
......@@ -357,10 +367,10 @@ def test_randomchoice():
def test_utils():
# Test cacheable_method: normal case
# Test cache_randomness: normal case
class DummyTransform(BaseTransform):
@cacheable_method
@cache_randomness
def func(self):
return np.random.rand()
......@@ -373,21 +383,21 @@ def test_utils():
with cache_random_params(transform):
_ = transform({})
# Test cacheable_method: invalid function type
# Test cache_randomness: invalid function type
with pytest.raises(TypeError):
class DummyTransform():
@cacheable_method
@cache_randomness
@staticmethod
def func():
return np.random.rand()
# Test cacheable_method: invalid function argument list
# Test cache_randomness: invalid function argument list
with pytest.raises(TypeError):
class DummyTransform():
@cacheable_method
@cache_randomness
def func(cls):
return np.random.rand()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment