"src/vscode:/vscode.git/clone" did not exist on "314a75f367a0d41e158edfc2edbb54eb0a1ae235"
Commit e2ca0733 authored by Yining Li's avatar Yining Li Committed by zhouzaida
Browse files

Refactor base transform (#1830)

* rename cacheable_method to cache_randomness

* refactor transform wrappers and update docs

* fix all_nonexist_keys

* fix lint

* rename transform wrappers
parent 0a5b4125
...@@ -150,18 +150,18 @@ pipeline = [ ...@@ -150,18 +150,18 @@ pipeline = [
变换包装是一种特殊的数据变换类,他们本身并不操作数据字典中的图像、标签等信息,而是对其中定义的数据变换的行为进行增强。 变换包装是一种特殊的数据变换类,他们本身并不操作数据字典中的图像、标签等信息,而是对其中定义的数据变换的行为进行增强。
### 字段映射(Remap ### 字段映射(KeyMapper
字段映射包装(`Remap`)用于对数据字典中的字段进行映射。例如,一般的图像处理变换都从数据字典中的 `"img"` 字段获得值。但有些时候,我们希望这些变换处理数据字典中其他字段中的图像,比如 `"gt_img"` 字段。 字段映射包装(`KeyMapper`)用于对数据字典中的字段进行映射。例如,一般的图像处理变换都从数据字典中的 `"img"` 字段获得值。但有些时候,我们希望这些变换处理数据字典中其他字段中的图像,比如 `"gt_img"` 字段。
如果配合注册器和配置文件使用的话,在配置文件中数据集的 `pipeline` 中如下例使用字段映射包装: 如果配合注册器和配置文件使用的话,在配置文件中数据集的 `pipeline` 中如下例使用字段映射包装:
```python ```python
pipeline = [ pipeline = [
... ...
dict(type='Remap', dict(type='KeyMapper',
input_mapping={'img': 'gt_img'}, # 将 "gt_img" 字段映射至 "img" 字段 mapping={'img': 'gt_img'}, # 将 "gt_img" 字段映射至 "img" 字段
inplace=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段 auto_remap=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段
transforms=[ transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可 # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
dict(type='RandomFlip'), dict(type='RandomFlip'),
...@@ -182,7 +182,7 @@ pipeline = [ ...@@ -182,7 +182,7 @@ pipeline = [
pipeline = [ pipeline = [
... ...
dict(type='RandomChoice', dict(type='RandomChoice',
pipelines=[ transforms=[
[ [
dict(type='Posterize', bits=4), dict(type='Posterize', bits=4),
dict(type='Rotate', angle=30.) dict(type='Rotate', angle=30.)
...@@ -192,17 +192,17 @@ pipeline = [ ...@@ -192,17 +192,17 @@ pipeline = [
dict(type='Rotate', angle=30) dict(type='Rotate', angle=30)
], # 第二种随机变换组合 ], # 第二种随机变换组合
], ],
pipeline_probs=[0.4, 0.6] # 两种随机变换组合各自的选用概率 prob=[0.4, 0.6] # 两种随机变换组合各自的选用概率
) )
... ...
] ]
``` ```
### 多目标扩展(ApplyToMultiple ### 多目标扩展(TransformBroadcaster
通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用 `Remap` 来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(`ApplyToMultiple`)。 通常,一个数据变换类只会从一个固定的字段读取操作目标。虽然我们也可以使用 `KeyMapper` 来改变读取的字段,但无法将变换一次性应用于多个字段的数据。为了实现这一功能,我们需要借助多目标扩展包装(`TransformBroadcaster`)。
多目标扩展包装(`ApplyToMultiple`)有两个用法,一是将数据变换作用于指定的多个字段,二是将数据变换作用于某个字段下的一组目标中。 多目标扩展包装(`TransformBroadcaster`)有两个用法,一是将数据变换作用于指定的多个字段,二是将数据变换作用于某个字段下的一组目标中。
1. 应用于多个字段 1. 应用于多个字段
...@@ -210,11 +210,11 @@ pipeline = [ ...@@ -210,11 +210,11 @@ pipeline = [
```python ```python
pipeline = [ pipeline = [
dict(type='ApplyToMultiple', dict(type='TransformBroadcaster',
# 分别应用于 "lq" 和 "gt" 两个字段,并将二者应设置 "img" 字段 # 分别应用于 "lq" 和 "gt" 两个字段,并将二者应设置 "img" 字段
input_mapping={'img': ['lq', 'gt']}, mapping={'img': ['lq', 'gt']},
# 在完成变换后,将 "img" 字段重映射回原先的字段 # 在完成变换后,将 "img" 字段重映射回原先的字段
inplace=True, auto_remap=True,
# 是否在对各目标的变换中共享随机变量 # 是否在对各目标的变换中共享随机变量
# 更多介绍参加后续章节(随机变量共享) # 更多介绍参加后续章节(随机变量共享)
share_random_param=True, share_random_param=True,
...@@ -231,11 +231,11 @@ pipeline = [ ...@@ -231,11 +231,11 @@ pipeline = [
```python ```python
pipeline = [ pipeline = [
dict(type='ApplyToMultiple', dict(type='TransformBroadcaster',
# 将 "images" 字段下的每张图片映射至 "img" 字段 # 将 "images" 字段下的每张图片映射至 "img" 字段
input_mapping={'img': 'images'}, mapping={'img': 'images'},
# 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中 # 在完成变换后,将 "img" 字段下的图片重映射回 "images" 字段的列表中
inplace=True, auto_remap=True,
# 是否在对各目标的变换中共享随机变量 # 是否在对各目标的变换中共享随机变量
share_random_param=True, share_random_param=True,
transforms=[ transforms=[
...@@ -245,12 +245,12 @@ pipeline = [ ...@@ -245,12 +245,12 @@ pipeline = [
] ]
``` ```
`ApplyToMultiple` 中,我们提供了 `share_random_param` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。 `TransformBroadcaster` 中,我们提供了 `share_random_param` 选项来支持在多次数据变换中共享随机状态。例如,在超分辨率任务中,我们希望将随机变换**同步**作用于低分辨率图像和原始图像。如果我们希望在自定义的数据变换类中使用这一功能,我们需要在类中标注哪些随机变量是支持共享的。
以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转: 以上文中的 `MyFlip` 为例,我们希望以一定的概率随机执行翻转:
```python ```python
from mmcv.transforms.utils import cacheable_method from mmcv.transforms.utils import cache_randomness
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class MyRandomFlip(BaseTransform): class MyRandomFlip(BaseTransform):
...@@ -259,7 +259,7 @@ class MyRandomFlip(BaseTransform): ...@@ -259,7 +259,7 @@ class MyRandomFlip(BaseTransform):
self.prob = prob self.prob = prob
self.direction = direction self.direction = direction
@cacheable_method # 标注该方法的输出为可共享的随机变量 @cache_randomness # 标注该方法的输出为可共享的随机变量
def do_flip(self): def do_flip(self):
flip = True if random.random() > self.prob else False flip = True if random.random() > self.prob else False
return flip return flip
...@@ -271,4 +271,4 @@ class MyRandomFlip(BaseTransform): ...@@ -271,4 +271,4 @@ class MyRandomFlip(BaseTransform):
return results return results
``` ```
通过 `cacheable_method` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `ApplyToMultiple` 对多个目标的变换中,这一变量的值都会保持一致。 通过 `cache_randomness` 装饰器,方法返回值 `flip` 被标注为一个支持共享的随机变量。进而,在 `TransformBroadcaster` 对多个目标的变换中,这一变量的值都会保持一致。
...@@ -4,24 +4,24 @@ from .loading import LoadAnnotation, LoadImageFromFile ...@@ -4,24 +4,24 @@ from .loading import LoadAnnotation, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad, from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomFlip, RandomGrayscale, RandomMultiscaleResize, RandomFlip, RandomGrayscale, RandomMultiscaleResize,
RandomResize, Resize) RandomResize, Resize)
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster
try: try:
import torch # noqa: F401 import torch # noqa: F401
except ImportError: except ImportError:
__all__ = [ __all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap', 'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop', 'Resize', 'Pad', 'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize' 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
] ]
else: else:
from .formatting import ImageToTensor, ToTensor, to_tensor from .formatting import ImageToTensor, ToTensor, to_tensor
__all__ = [ __all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap', 'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip', 'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomMultiscaleResize', 'CenterCrop', 'RandomGrayscale', 'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'MultiScaleFlipAug', 'RandomResize' 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
] ]
...@@ -11,17 +11,17 @@ from typing import Callable, Union ...@@ -11,17 +11,17 @@ from typing import Callable, Union
from .base import BaseTransform from .base import BaseTransform
class cacheable_method: class cache_randomness:
"""Decorator that marks a method of a transform class as a cacheable """Decorator that marks the method with random return value(s) in a
method. transform class.
This decorator is usually used together with the context-manager This decorator is usually used together with the context-manager
:func`:cache_random_params`. In this context, a cacheable method will :func`:cache_random_params`. In this context, a decorated method will
cache its return value(s) at the first time of being invoked, and always cache its return value(s) at the first time of being invoked, and always
return the cached values when being invoked again. return the cached values when being invoked again.
.. note:: .. note::
Only a instance method can be decorated as a cacheable_method. Only an instance method can be decorated by ``cache_randomness``.
""" """
def __init__(self, func): def __init__(self, func):
...@@ -29,12 +29,12 @@ class cacheable_method: ...@@ -29,12 +29,12 @@ class cacheable_method:
# Check `func` is to be bound as an instance method # Check `func` is to be bound as an instance method
if not inspect.isfunction(func): if not inspect.isfunction(func):
raise TypeError('Unsupport callable to decorate with' raise TypeError('Unsupport callable to decorate with'
'@cacheable_method.') '@cache_randomness.')
func_args = inspect.getfullargspec(func).args func_args = inspect.getfullargspec(func).args
if len(func_args) == 0 or func_args[0] != 'self': if len(func_args) == 0 or func_args[0] != 'self':
raise TypeError( raise TypeError(
'@cacheable_method should only be used to decorate ' '@cache_randomness should only be used to decorate '
'instance methods (the first argument is `self`).') 'instance methods (the first argument is ``self``).')
functools.update_wrapper(self, func) functools.update_wrapper(self, func)
self.func = func self.func = func
...@@ -42,24 +42,24 @@ class cacheable_method: ...@@ -42,24 +42,24 @@ class cacheable_method:
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
# Maintain a record of decorated methods in the class # Maintain a record of decorated methods in the class
if not hasattr(owner, '_cacheable_methods'): if not hasattr(owner, '_methods_with_randomness'):
setattr(owner, '_cacheable_methods', []) setattr(owner, '_methods_with_randomness', [])
owner._cacheable_methods.append(self.__name__) owner._methods_with_randomness.append(self.__name__)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated # Get the transform instance whose method is decorated
# by cacheable_method # by cache_randomness
instance = self.instance_ref() instance = self.instance_ref()
name = self.__name__ name = self.__name__
# Check the flag `self._cache_enabled`, which should be # Check the flag ``self._cache_enabled``, which should be
# set by the contextmanagers like `cache_random_parameters` # set by the contextmanagers like ``cache_random_parameters```
cache_enabled = getattr(instance, '_cache_enabled', False) cache_enabled = getattr(instance, '_cache_enabled', False)
if cache_enabled: if cache_enabled:
# Initialize the cache of the transform instances. The flag # Initialize the cache of the transform instances. The flag
# `cache_enabled` is set by contextmanagers like # ``cache_enabled``` is set by contextmanagers like
# `cache_random_params`. # ``cache_random_params```.
if not hasattr(instance, '_cache'): if not hasattr(instance, '_cache'):
setattr(instance, '_cache', {}) setattr(instance, '_cache', {})
...@@ -81,13 +81,13 @@ class cacheable_method: ...@@ -81,13 +81,13 @@ class cacheable_method:
@contextmanager @contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]): def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of cacheable methods in """Context-manager that enables the cache of return values of methods
transforms. decorated by ``cache_randomness`` in transforms.
In this mode, cacheable methods will cache their return values on the In this mode, decorated methods will cache their return values on the
first invoking, and always return the cached value afterward. This allow first invoking, and always return the cached value afterward. This allow
to apply random transforms in a deterministic way. For example, apply same to apply random transforms in a deterministic way. For example, apply same
transforms on multiple examples. See `cacheable_method` for more transforms on multiple examples. See ``cache_randomness`` for more
information. information.
Args: Args:
...@@ -99,12 +99,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -99,12 +99,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# ones. These methods will be restituted when exiting the context. # ones. These methods will be restituted when exiting the context.
key2method = dict() key2method = dict()
# key2counter stores the usage number of each cacheable_method. This is # key2counter stores the usage number of each cache_randomness. This is
# used to check that any cacheable_method is invoked once during processing # used to check that any cache_randomness is invoked once during processing
# on data sample. # on data sample.
key2counter = defaultdict(int) key2counter = defaultdict(int)
def _add_counter(obj, method_name): def _add_invoke_counter(obj, method_name):
method = getattr(obj, method_name) method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}' key = f'{id(obj)}.{method_name}'
key2method[key] = method key2method[key] = method
...@@ -116,15 +116,43 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -116,15 +116,43 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return wrapped return wrapped
def _add_invoke_checker(obj, method_name):
# check that the method in _methods_with_randomness has been
# invoked at most once
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method
@functools.wraps(method)
def wrapped(*args, **kwargs):
# clear counter
for name in obj._methods_with_randomness:
key = f'{id(obj)}.{name}'
key2counter[key] = 0
output = method(*args, **kwargs)
for name in obj._methods_with_randomness:
key = f'{id(obj)}.{name}'
if key2counter[key] > 1:
raise RuntimeError(
'The method decorated by ``cache_randomness`` should '
'be invoked at most once during processing one data '
f'sample. The method {name} of {obj} has been invoked'
f' {key2counter[key]} times.')
return output
return wrapped
def _start_cache(t: BaseTransform): def _start_cache(t: BaseTransform):
# Set cache enabled flag # Set cache enabled flag
setattr(t, '_cache_enabled', True) setattr(t, '_cache_enabled', True)
# Store the original method and init the counter # Store the original method and init the counter
if hasattr(t, '_cacheable_methods'): if hasattr(t, '_methods_with_randomness'):
setattr(t, 'transform', _add_counter(t, 'transform')) setattr(t, 'transform', _add_invoke_checker(t, 'transform'))
for name in t._cacheable_methods: for name in t._methods_with_randomness:
setattr(t, name, _add_counter(t, name)) setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform): def _end_cache(t: BaseTransform):
# Remove cache enabled flag # Remove cache enabled flag
...@@ -133,17 +161,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -133,17 +161,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
del t._cache del t._cache
# Restore the original method # Restore the original method
if hasattr(t, '_cacheable_methods'): if hasattr(t, '_methods_with_randomness'):
key_transform = f'{id(t)}.transform' for name in t._methods_with_randomness:
for name in t._cacheable_methods:
key = f'{id(t)}.{name}' key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]:
raise RuntimeError(
'The cacheable method should be called once and only '
f'once during processing one data sample. {t} got '
f'unmatched number of {key2counter[key]} ({name}) vs '
f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key]) setattr(t, name, key2method[key])
key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform]) setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable], def _apply(t: Union[BaseTransform, Iterable],
...@@ -151,7 +174,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -151,7 +174,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Note that BaseTransform and Iterable are not mutually exclusive, # Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice # e.g. Compose, RandomChoice
if isinstance(t, BaseTransform): if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'): if hasattr(t, '_methods_with_randomness'):
func(t) func(t)
if isinstance(t, Iterable): if isinstance(t, Iterable):
for _t in t: for _t in t:
......
...@@ -8,7 +8,10 @@ import numpy as np ...@@ -8,7 +8,10 @@ import numpy as np
import mmcv import mmcv
from .base import BaseTransform from .base import BaseTransform
from .builder import TRANSFORMS from .builder import TRANSFORMS
from .utils import cache_random_params, cacheable_method from .utils import cache_random_params, cache_randomness
# Define type of transform or transform config
Transform = Union[Dict, Callable[[Dict], Dict]]
# Indicator for required but missing keys in results # Indicator for required but missing keys in results
NotInResults = object() NotInResults = object()
...@@ -46,8 +49,9 @@ class Compose(BaseTransform): ...@@ -46,8 +49,9 @@ class Compose(BaseTransform):
>>> ] >>> ]
""" """
def __init__(self, transforms: List[Union[Dict, Callable[[Dict], Dict]]]): def __init__(self, transforms: Union[Transform, List[Transform]]):
assert isinstance(transforms, Sequence) if not isinstance(transforms, list):
transforms = [transforms]
self.transforms = [] self.transforms = []
for transform in transforms: for transform in transforms:
if isinstance(transform, dict): if isinstance(transform, dict):
...@@ -88,42 +92,42 @@ class Compose(BaseTransform): ...@@ -88,42 +92,42 @@ class Compose(BaseTransform):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class Remap(BaseTransform): class KeyMapper(BaseTransform):
"""A transform wrapper to remap and reorganize the input/output of the """A transform wrapper to map and reorganize the input/output of the
wrapped transforms (or sub-pipeline). wrapped transforms (or sub-pipeline).
Args: Args:
transforms (list[dict | callable]): Sequence of transform object or transforms (list[dict | callable], optional): Sequence of transform
config dict to be wrapped. object or config dict to be wrapped.
input_mapping (dict): A dict that defines the input key mapping. mapping (dict): A dict that defines the input key mapping.
The keys corresponds to the inner key (i.e., kwargs of the The keys corresponds to the inner key (i.e., kwargs of the
`transform` method), and should be string type. The values ``transform`` method), and should be string type. The values
corresponds to the outer keys (i.e., the keys of the corresponds to the outer keys (i.e., the keys of the
data/results), and should have a type of string, list or dict. data/results), and should have a type of string, list or dict.
None means not applying input mapping. Default: None. None means not applying input mapping. Default: None.
output_mapping (dict): A dict that defines the output key mapping. remapping (dict): A dict that defines the output key mapping.
The keys and values have the same meanings and rules as in the The keys and values have the same meanings and rules as in the
`input_mapping`. Default: None. ``mapping``. Default: None.
inplace (bool): If True, an inverse of the input_mapping will be used auto_remap (bool, optional): If True, an inverse of the mapping will
as the output_mapping. Note that if inplace is set True, be used as the remapping. If auto_remap is not given, it will be
output_mapping should be None and strict should be True. automatically set True if 'remapping' is not given, and vice
versa. Default: None.
allow_nonexist_keys (bool): If False, the outer keys in the mapping
must exist in the input data, or an exception will be raised.
Default: False. Default: False.
strict (bool): If True, the outer keys in the input_mapping must exist
in the input data, or an exception will be raised. If False,
the missing keys will be assigned a special value `NotInResults`
during input remapping. Default: True.
Examples: Examples:
>>> # Example 1: Remap 'gt_img' to 'img' >>> # Example 1: KeyMapper 'gt_img' to 'img'
>>> pipeline = [ >>> pipeline = [
>>> # Use Remap to convert outer (original) field name 'gt_img' >>> # Use KeyMapper to convert outer (original) field name
>>> # to inner (used by inner transforms) filed name 'img' >>> # 'gt_img' to inner (used by inner transforms) filed name
>>> dict(type='Remap', >>> # 'img'
>>> input_mapping=dict(img='gt_img'), >>> dict(type='KeyMapper',
>>> # inplace=True means output key mapping is the revert of >>> mapping=dict(img='gt_img'),
>>> # auto_remap=True means output key mapping is the revert of
>>> # the input key mapping, e.g. inner 'img' will be mapped >>> # the input key mapping, e.g. inner 'img' will be mapped
>>> # back to outer 'gt_img' >>> # back to outer 'gt_img'
>>> inplace=True, >>> auto_remap=True,
>>> transforms=[ >>> transforms=[
>>> # In all transforms' implementation just use 'img' >>> # In all transforms' implementation just use 'img'
>>> # as a standard field name >>> # as a standard field name
...@@ -136,10 +140,10 @@ class Remap(BaseTransform): ...@@ -136,10 +140,10 @@ class Remap(BaseTransform):
>>> # The inner field 'imgs' will be a dict with keys 'img_src' >>> # The inner field 'imgs' will be a dict with keys 'img_src'
>>> # and 'img_tar', whose values are outer fields 'img1' and >>> # and 'img_tar', whose values are outer fields 'img1' and
>>> # 'img2' respectively. >>> # 'img2' respectively.
>>> dict(type='Remap', >>> dict(type='KeyMapper',
>>> dict( >>> dict(
>>> type='Remap', >>> type='KeyMapper',
>>> input_mapping=dict( >>> mapping=dict(
>>> imgs=dict( >>> imgs=dict(
>>> img_src='img1', >>> img_src='img1',
>>> img_tar='img2')), >>> img_tar='img2')),
...@@ -148,66 +152,67 @@ class Remap(BaseTransform): ...@@ -148,66 +152,67 @@ class Remap(BaseTransform):
""" """
def __init__(self, def __init__(self,
transforms: List[Union[Dict, Callable[[Dict], Dict]]], transforms: Union[Transform, List[Transform]] = None,
input_mapping: Optional[Dict] = None, mapping: Optional[Dict] = None,
output_mapping: Optional[Dict] = None, remapping: Optional[Dict] = None,
inplace: bool = False, auto_remap: Optional[bool] = None,
strict: bool = True): allow_nonexist_keys: bool = False):
self.inplace = inplace self.allow_nonexist_keys = allow_nonexist_keys
self.strict = strict self.mapping = mapping
self.input_mapping = input_mapping
if auto_remap is None:
if self.inplace: auto_remap = remapping is None
if not self.strict: self.auto_remap = auto_remap
raise ValueError('Remap: `strict` must be set True if'
'`inplace` is set True.') if self.auto_remap:
if remapping is not None:
if output_mapping is not None: raise ValueError('KeyMapper: ``remapping`` must be None if'
raise ValueError('Remap: `output_mapping` must be None if' '`auto_remap` is set True.')
'`inplace` is set True.') self.remapping = mapping
self.output_mapping = input_mapping
else: else:
self.output_mapping = output_mapping self.remapping = remapping
if transforms is None:
transforms = []
self.transforms = Compose(transforms) self.transforms = Compose(transforms)
def __iter__(self): def __iter__(self):
"""Allow easy iteration over the transform sequence.""" """Allow easy iteration over the transform sequence."""
return iter(self.transforms) return iter(self.transforms)
def remap_input(self, data: Dict, input_mapping: Dict) -> Dict[str, Any]: def map_input(self, data: Dict, mapping: Dict) -> Dict[str, Any]:
"""Remap inputs for the wrapped transforms by gathering and renaming """KeyMapper inputs for the wrapped transforms by gathering and
data items according to the input_mapping. renaming data items according to the mapping.
Args: Args:
data (dict): The original input data data (dict): The original input data
input_mapping (dict): The input key mapping. See the document of mapping (dict): The input key mapping. See the document of
`mmcv.transforms.wrappers.Remap` for details. ``mmcv.transforms.wrappers.KeyMapper`` for details.
Returns: Returns:
dict: The input data with remapped keys. This will be the actual dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline. input of the wrapped pipeline.
""" """
def _remap(data, m): def _map(data, m):
if isinstance(m, dict): if isinstance(m, dict):
# m is a dict {inner_key:outer_key, ...} # m is a dict {inner_key:outer_key, ...}
return {k_in: _remap(data, k_out) for k_in, k_out in m.items()} return {k_in: _map(data, k_out) for k_in, k_out in m.items()}
if isinstance(m, (tuple, list)): if isinstance(m, (tuple, list)):
# m is a list or tuple [outer_key1, outer_key2, ...] # m is a list or tuple [outer_key1, outer_key2, ...]
# This is the case when we collect items from the original # This is the case when we collect items from the original
# data to form a list or tuple to feed to the wrapped # data to form a list or tuple to feed to the wrapped
# transforms. # transforms.
return m.__class__(_remap(data, e) for e in m) return m.__class__(_map(data, e) for e in m)
# m is an outer_key # m is an outer_key
if self.strict: if self.allow_nonexist_keys:
return data.get(m)
else:
return data.get(m, NotInResults) return data.get(m, NotInResults)
else:
return data.get(m)
collected = _remap(data, input_mapping) collected = _map(data, mapping)
collected = { collected = {
k: v k: v
for k, v in collected.items() if v is not NotInResults for k, v in collected.items() if v is not NotInResults
...@@ -219,79 +224,78 @@ class Remap(BaseTransform): ...@@ -219,79 +224,78 @@ class Remap(BaseTransform):
return inputs return inputs
def remap_output(self, data: Dict, output_mapping: Dict) -> Dict[str, Any]: def map_output(self, data: Dict, remapping: Dict) -> Dict[str, Any]:
"""Remap outputs from the wrapped transforms by gathering and renaming """KeyMapper outputs from the wrapped transforms by gathering and
data items according to the output_mapping. renaming data items according to the remapping.
Args: Args:
data (dict): The output of the wrapped pipeline. data (dict): The output of the wrapped pipeline.
output_mapping (dict): The output key mapping. See the document of remapping (dict): The output key mapping. See the document of
`mmcv.transforms.wrappers.Remap` for details. ``mmcv.transforms.wrappers.KeyMapper`` for details.
Returns: Returns:
dict: The output with remapped keys. dict: The output with remapped keys.
""" """
def _remap(data, m): def _map(data, m):
if isinstance(m, dict): if isinstance(m, dict):
assert isinstance(data, dict) assert isinstance(data, dict)
results = {} results = {}
for k_in, k_out in m.items(): for k_in, k_out in m.items():
assert k_in in data assert k_in in data
results.update(_remap(data[k_in], k_out)) results.update(_map(data[k_in], k_out))
return results return results
if isinstance(m, (list, tuple)): if isinstance(m, (list, tuple)):
assert isinstance(data, (list, tuple)) assert isinstance(data, (list, tuple))
assert len(data) == len(m) assert len(data) == len(m)
results = {} results = {}
for m_i, d_i in zip(m, data): for m_i, d_i in zip(m, data):
results.update(_remap(d_i, m_i)) results.update(_map(d_i, m_i))
return results return results
return {m: data} return {m: data}
# Note that unmapped items are not retained, which is different from # Note that unmapped items are not retained, which is different from
# the behavior in remap_input. This is to avoid original data items # the behavior in map_input. This is to avoid original data items
# being overwritten by intermediate namesakes # being overwritten by intermediate namesakes
return _remap(data, output_mapping) return _map(data, remapping)
def transform(self, results: Dict) -> Dict: def transform(self, results: Dict) -> Dict:
inputs = self.remap_input(results, self.input_mapping) inputs = self.map_input(results, self.mapping)
outputs = self.transforms(inputs) outputs = self.transforms(inputs)
if self.output_mapping: if self.remapping:
outputs = self.remap_output(outputs, self.output_mapping) outputs = self.map_output(outputs, self.remapping)
results.update(outputs) results.update(outputs)
return results return results
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class ApplyToMultiple(Remap): class TransformBroadcaster(KeyMapper):
"""A transform wrapper to apply the wrapped transforms to multiple data """A transform wrapper to apply the wrapped transforms to multiple data
items. For example, apply Resize to multiple images. items. For example, apply Resize to multiple images.
Args: Args:
transforms (list[dict | callable]): Sequence of transform object or transforms (list[dict | callable]): Sequence of transform object or
config dict to be wrapped. config dict to be wrapped.
input_mapping (dict): A dict that defines the input key mapping. mapping (dict): A dict that defines the input key mapping.
Note that to apply the transforms to multiple data items, the Note that to apply the transforms to multiple data items, the
outer keys of the target items should be remapped as a list with outer keys of the target items should be remapped as a list with
the standard inner key (The key required by the wrapped transform). the standard inner key (The key required by the wrapped transform).
See the following example and the document of See the following example and the document of
`mmcv.transforms.wrappers.Remap` for details. ``mmcv.transforms.wrappers.KeyMapper`` for details.
output_mapping (dict): A dict that defines the output key mapping. remapping (dict): A dict that defines the output key mapping.
The keys and values have the same meanings and rules as in the The keys and values have the same meanings and rules as in the
`input_mapping`. Default: None. ``mapping``. Default: None.
inplace (bool): If True, an inverse of the input_mapping will be used auto_remap (bool, optional): If True, an inverse of the mapping will
as the output_mapping. Note that if inplace is set True, be used as the remapping. If auto_remap is not given, it will be
output_mapping should be None and strict should be True. automatically set True if 'remapping' is not given, and vice
versa. Default: None.
allow_nonexist_keys (bool): If False, the outer keys in the mapping
must exist in the input data, or an exception will be raised.
Default: False. Default: False.
strict (bool): If True, the outer keys in the input_mapping must exist
in the input data, or an exception will be raised. If False,
the missing keys will be assigned a special value `NotInResults`
during input remapping. Default: True.
share_random_params (bool): If True, the random transform share_random_params (bool): If True, the random transform
(e.g., RandomFlip) will be conducted in a deterministic way and (e.g., RandomFlip) will be conducted in a deterministic way and
have the same behavior on all data items. For example, to randomly have the same behavior on all data items. For example, to randomly
...@@ -300,7 +304,7 @@ class ApplyToMultiple(Remap): ...@@ -300,7 +304,7 @@ class ApplyToMultiple(Remap):
.. note:: .. note::
To apply the transforms to each elements of a list or tuple, instead To apply the transforms to each elements of a list or tuple, instead
of separating data items, you can remap the outer key of the target of separating data items, you can map the outer key of the target
sequence to the standard inner key. See example 2. sequence to the standard inner key. See example 2.
example. example.
...@@ -309,13 +313,13 @@ class ApplyToMultiple(Remap): ...@@ -309,13 +313,13 @@ class ApplyToMultiple(Remap):
>>> pipeline = [ >>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img >>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img >>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> # ApplyToMultiple maps multiple outer fields to standard the >>> # TransformBroadcaster maps multiple outer fields to standard
>>> # inner field and process them with wrapped transforms >>> # the inner field and process them with wrapped transforms
>>> # respectively >>> # respectively
>>> dict(type='ApplyToMultiple', >>> dict(type='TransformBroadcaster',
>>> # case 1: from multiple outer fields >>> # case 1: from multiple outer fields
>>> input_mapping=dict(img=['lq', 'gt']), >>> mapping=dict(img=['lq', 'gt']),
>>> inplace=True, >>> auto_remap=True,
>>> # share_random_param=True means using identical random >>> # share_random_param=True means using identical random
>>> # parameters in every processing >>> # parameters in every processing
>>> share_random_param=True, >>> share_random_param=True,
...@@ -328,14 +332,14 @@ class ApplyToMultiple(Remap): ...@@ -328,14 +332,14 @@ class ApplyToMultiple(Remap):
>>> pipeline = [ >>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img >>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img >>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
>>> # ApplyToMultiple maps multiple outer fields to standard the >>> # TransformBroadcaster maps multiple outer fields to standard
>>> # inner field and process them with wrapped transforms >>> # the inner field and process them with wrapped transforms
>>> # respectively >>> # respectively
>>> dict(type='ApplyToMultiple', >>> dict(type='TransformBroadcaster',
>>> # case 2: from one outer field that contains multiple >>> # case 2: from one outer field that contains multiple
>>> # data elements (e.g. a list) >>> # data elements (e.g. a list)
>>> # input_mapping=dict(img='images'), >>> # mapping=dict(img='images'),
>>> inplace=True, >>> auto_remap=True,
>>> share_random_param=True, >>> share_random_param=True,
>>> transforms=[ >>> transforms=[
>>> dict(type='Crop', crop_size=(384, 384)), >>> dict(type='Crop', crop_size=(384, 384)),
...@@ -346,13 +350,13 @@ class ApplyToMultiple(Remap): ...@@ -346,13 +350,13 @@ class ApplyToMultiple(Remap):
def __init__(self, def __init__(self,
transforms: List[Union[Dict, Callable[[Dict], Dict]]], transforms: List[Union[Dict, Callable[[Dict], Dict]]],
input_mapping: Optional[Dict] = None, mapping: Optional[Dict] = None,
output_mapping: Optional[Dict] = None, remapping: Optional[Dict] = None,
inplace: bool = False, auto_remap: Optional[bool] = None,
strict: bool = True, allow_nonexist_keys: bool = False,
share_random_params: bool = False): share_random_params: bool = False):
super().__init__(transforms, input_mapping, output_mapping, inplace, super().__init__(transforms, mapping, remapping, auto_remap,
strict) allow_nonexist_keys)
self.share_random_params = share_random_params self.share_random_params = share_random_params
...@@ -360,7 +364,7 @@ class ApplyToMultiple(Remap): ...@@ -360,7 +364,7 @@ class ApplyToMultiple(Remap):
# infer split number from input # infer split number from input
seq_len = None seq_len = None
key_rep = None key_rep = None
for key in self.input_mapping: for key in self.mapping:
assert isinstance(data[key], Sequence) assert isinstance(data[key], Sequence)
if seq_len is not None: if seq_len is not None:
...@@ -375,14 +379,14 @@ class ApplyToMultiple(Remap): ...@@ -375,14 +379,14 @@ class ApplyToMultiple(Remap):
scatters = [] scatters = []
for i in range(seq_len): for i in range(seq_len):
scatter = data.copy() scatter = data.copy()
for key in self.input_mapping: for key in self.mapping:
scatter[key] = data[key][i] scatter[key] = data[key][i]
scatters.append(scatter) scatters.append(scatter)
return scatters return scatters
def transform(self, results: Dict): def transform(self, results: Dict):
# Apply input remapping # Apply input remapping
inputs = self.remap_input(results, self.input_mapping) inputs = self.map_input(results, self.mapping)
# Scatter sequential inputs into a list # Scatter sequential inputs into a list
inputs = self.scatter_sequence(inputs) inputs = self.scatter_sequence(inputs)
...@@ -407,8 +411,8 @@ class ApplyToMultiple(Remap): ...@@ -407,8 +411,8 @@ class ApplyToMultiple(Remap):
} }
# Apply output remapping # Apply output remapping
if self.output_mapping: if self.remapping:
outputs = self.remap_output(outputs, self.output_mapping) outputs = self.map_output(outputs, self.remapping)
results.update(outputs) results.update(outputs)
return results return results
...@@ -419,9 +423,9 @@ class RandomChoice(BaseTransform): ...@@ -419,9 +423,9 @@ class RandomChoice(BaseTransform):
"""Process data with a randomly chosen pipeline from given candidates. """Process data with a randomly chosen pipeline from given candidates.
Args: Args:
pipelines (list[list]): A list of pipeline candidates, each is a transforms (list[list]): A list of pipeline candidates, each is a
sequence of transforms. sequence of transforms.
pipeline_probs (list[float], optional): The probabilities associated prob (list[float], optional): The probabilities associated
with each pipeline. The length should be equal to the pipeline with each pipeline. The length should be equal to the pipeline
number and the sum should be 1. If not given, a uniform number and the sum should be 1. If not given, a uniform
distribution will be assumed. distribution will be assumed.
...@@ -430,7 +434,7 @@ class RandomChoice(BaseTransform): ...@@ -430,7 +434,7 @@ class RandomChoice(BaseTransform):
>>> # config >>> # config
>>> pipeline = [ >>> pipeline = [
>>> dict(type='RandomChoice', >>> dict(type='RandomChoice',
>>> pipelines=[ >>> transforms=[
>>> [dict(type='RandomHorizontalFlip')], # subpipeline 1 >>> [dict(type='RandomHorizontalFlip')], # subpipeline 1
>>> [dict(type='RandomRotate')], # subpipeline 2 >>> [dict(type='RandomRotate')], # subpipeline 2
>>> ] >>> ]
...@@ -439,27 +443,27 @@ class RandomChoice(BaseTransform): ...@@ -439,27 +443,27 @@ class RandomChoice(BaseTransform):
""" """
def __init__(self, def __init__(self,
pipelines: List[List[Union[Dict, Callable[[Dict], Dict]]]], transforms: List[Union[Transform, List[Transform]]],
pipeline_probs: Optional[List[float]] = None): prob: Optional[List[float]] = None):
if pipeline_probs is not None: if prob is not None:
assert mmcv.is_seq_of(pipeline_probs, float) assert mmcv.is_seq_of(prob, float)
assert len(pipelines) == len(pipeline_probs), \ assert len(transforms) == len(prob), \
'`pipelines` and `pipeline_probs` must have same lengths. ' \ '``transforms`` and ``prob`` must have same lengths. ' \
f'Got {len(pipelines)} vs {len(pipeline_probs)}.' f'Got {len(transforms)} vs {len(prob)}.'
assert sum(pipeline_probs) == 1 assert sum(prob) == 1
self.pipeline_probs = pipeline_probs self.prob = prob
self.pipelines = [Compose(transforms) for transforms in pipelines] self.transforms = [Compose(transforms) for transforms in transforms]
def __iter__(self): def __iter__(self):
return iter(self.pipelines) return iter(self.transforms)
@cacheable_method @cache_randomness
def random_pipeline_index(self): def random_pipeline_index(self):
indices = np.arange(len(self.pipelines)) indices = np.arange(len(self.transforms))
return np.random.choice(indices, p=self.pipeline_probs) return np.random.choice(indices, p=self.prob)
def transform(self, results): def transform(self, results):
idx = self.random_pipeline_index() idx = self.random_pipeline_index()
return self.pipelines[idx](results) return self.transforms[idx](results)
...@@ -6,9 +6,9 @@ import pytest ...@@ -6,9 +6,9 @@ import pytest
from mmcv.transforms.base import BaseTransform from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cacheable_method from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice, from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomChoice,
Remap) TransformBroadcaster)
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
...@@ -45,15 +45,18 @@ class AddToValue(BaseTransform): ...@@ -45,15 +45,18 @@ class AddToValue(BaseTransform):
class RandomAddToValue(AddToValue): class RandomAddToValue(AddToValue):
"""Dummy transform to add a random addend to results['value']""" """Dummy transform to add a random addend to results['value']"""
def __init__(self) -> None: def __init__(self, repeat=1) -> None:
super().__init__(addend=None) super().__init__(addend=None)
self.repeat = repeat
@cacheable_method @cache_randomness
def get_random_addend(self): def get_random_addend(self):
return np.random.rand() return np.random.rand()
def transform(self, results): def transform(self, results):
return self.add(results, addend=self.get_random_addend()) for _ in range(self.repeat):
results = self.add(results, addend=self.get_random_addend())
return results
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
...@@ -100,8 +103,8 @@ def test_cache_random_parameters(): ...@@ -100,8 +103,8 @@ def test_cache_random_parameters():
transform = RandomAddToValue() transform = RandomAddToValue()
# Case 1: cache random parameters # Case 1: cache random parameters
assert hasattr(RandomAddToValue, '_cacheable_methods') assert hasattr(RandomAddToValue, '_methods_with_randomness')
assert 'get_random_addend' in RandomAddToValue._cacheable_methods assert 'get_random_addend' in RandomAddToValue._methods_with_randomness
with cache_random_params(transform): with cache_random_params(transform):
results_1 = transform(dict(value=0)) results_1 = transform(dict(value=0))
...@@ -114,12 +117,18 @@ def test_cache_random_parameters(): ...@@ -114,12 +117,18 @@ def test_cache_random_parameters():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
np.testing.assert_equal(results_1['value'], results_2['value']) np.testing.assert_equal(results_1['value'], results_2['value'])
# Case 3: invalid use of cacheable methods # Case 3: allow to invoke random method 0 times
transform = RandomAddToValue(repeat=0)
with cache_random_params(transform):
_ = transform(dict(value=0))
# Case 4: NOT allow to invoke random method >1 times
transform = RandomAddToValue(repeat=2)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
with cache_random_params(transform): with cache_random_params(transform):
_ = transform.get_random_addend() _ = transform(dict(value=0))
# Case 4: apply on nested transforms # Case 5: apply on nested transforms
transform = Compose([RandomAddToValue()]) transform = Compose([RandomAddToValue()])
with cache_random_params(transform): with cache_random_params(transform):
results_1 = transform(dict(value=0)) results_1 = transform(dict(value=0))
...@@ -127,13 +136,13 @@ def test_cache_random_parameters(): ...@@ -127,13 +136,13 @@ def test_cache_random_parameters():
np.testing.assert_equal(results_1['value'], results_2['value']) np.testing.assert_equal(results_1['value'], results_2['value'])
def test_remap(): def test_apply_to_mapped():
# Case 1: simple remap # Case 1: simple remap
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'), mapping=dict(value='v_in'),
output_mapping=dict(value='v_out')) remapping=dict(value='v_out'))
results = dict(value=0, v_in=1) results = dict(value=0, v_in=1)
results = pipeline(results) results = pipeline(results)
...@@ -143,10 +152,10 @@ def test_remap(): ...@@ -143,10 +152,10 @@ def test_remap():
np.testing.assert_equal(results['v_out'], 2) np.testing.assert_equal(results['v_out'], 2)
# Case 2: collecting list # Case 2: collecting list
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']), mapping=dict(value=['v_in_1', 'v_in_2']),
output_mapping=dict(value=['v_out_1', 'v_out_2'])) remapping=dict(value=['v_out_1', 'v_out_2']))
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'): with pytest.warns(UserWarning, match='value is a list'):
...@@ -159,10 +168,10 @@ def test_remap(): ...@@ -159,10 +168,10 @@ def test_remap():
np.testing.assert_equal(results['v_out_2'], 4) np.testing.assert_equal(results['v_out_2'], 4)
# Case 3: collecting dict # Case 3: collecting dict
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')), mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2'))) remapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'): with pytest.warns(UserWarning, match='value is a dict'):
...@@ -174,11 +183,11 @@ def test_remap(): ...@@ -174,11 +183,11 @@ def test_remap():
np.testing.assert_equal(results['v_out_1'], 3) np.testing.assert_equal(results['v_out_1'], 3)
np.testing.assert_equal(results['v_out_2'], 4) np.testing.assert_equal(results['v_out_2'], 4)
# Case 4: collecting list with inplace mode # Case 4: collecting list with auto_remap mode
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']), mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True) auto_remap=True)
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'): with pytest.warns(UserWarning, match='value is a list'):
...@@ -188,11 +197,11 @@ def test_remap(): ...@@ -188,11 +197,11 @@ def test_remap():
np.testing.assert_equal(results['v_in_1'], 3) np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4) np.testing.assert_equal(results['v_in_2'], 4)
# Case 5: collecting dict with inplace mode # Case 5: collecting dict with auto_remap mode
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')), mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
inplace=True) auto_remap=True)
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'): with pytest.warns(UserWarning, match='value is a dict'):
...@@ -202,11 +211,11 @@ def test_remap(): ...@@ -202,11 +211,11 @@ def test_remap():
np.testing.assert_equal(results['v_in_1'], 3) np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4) np.testing.assert_equal(results['v_in_2'], 4)
# Case 6: nested collection with inplace mode # Case 6: nested collection with auto_remap mode
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]), mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
inplace=True) auto_remap=True)
results = dict(value=0, v1=1, v21=2, v22=3, v3=4) results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
with pytest.warns(UserWarning, match='value is a list'): with pytest.warns(UserWarning, match='value is a list'):
...@@ -218,27 +227,20 @@ def test_remap(): ...@@ -218,27 +227,20 @@ def test_remap():
np.testing.assert_equal(results['v22'], 5) np.testing.assert_equal(results['v22'], 5)
np.testing.assert_equal(results['v3'], 6) np.testing.assert_equal(results['v3'], 6)
# Case 7: `strict` must be True if `inplace` is set True # Case 7: output_map must be None if `auto_remap` is set True
with pytest.raises(ValueError): with pytest.raises(ValueError):
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True,
strict=False)
# Case 8: output_map must be None if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'), mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'), remapping=dict(value='v_out'),
inplace=True) auto_remap=True)
# Case 9: non-strict input mapping # Case 8: allow_nonexist_keys8
pipeline = Remap( pipeline = KeyMapper(
transforms=[SumTwoValues()], transforms=[SumTwoValues()],
input_mapping=dict(num_1='a', num_2='b'), mapping=dict(num_1='a', num_2='b'),
strict=False) auto_remap=False,
allow_nonexist_keys=True)
results = pipeline(dict(a=1, b=2)) results = pipeline(dict(a=1, b=2))
np.testing.assert_equal(results['sum'], 3) np.testing.assert_equal(results['sum'], 3)
...@@ -246,11 +248,17 @@ def test_remap(): ...@@ -246,11 +248,17 @@ def test_remap():
results = pipeline(dict(a=1)) results = pipeline(dict(a=1))
assert np.isnan(results['sum']) assert np.isnan(results['sum'])
# Case 9: use wrapper as a transform
transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
results = transform(dict(a=1))
# note that the original key 'a' will not be removed
assert results == dict(a=1, b=1)
# Test basic functions # Test basic functions
pipeline = Remap( pipeline = KeyMapper(
transforms=[AddToValue(addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'), mapping=dict(value='v_in'),
output_mapping=dict(value='v_out')) remapping=dict(value='v_out'))
# __iter__ # __iter__
for _ in pipeline: for _ in pipeline:
...@@ -263,10 +271,10 @@ def test_remap(): ...@@ -263,10 +271,10 @@ def test_remap():
def test_apply_to_multiple(): def test_apply_to_multiple():
# Case 1: apply to list in results # Case 1: apply to list in results
pipeline = ApplyToMultiple( pipeline = TransformBroadcaster(
transforms=[AddToValue(addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='values'), mapping=dict(value='values'),
inplace=True) auto_remap=True)
results = dict(values=[1, 2]) results = dict(values=[1, 2])
results = pipeline(results) results = pipeline(results)
...@@ -274,10 +282,10 @@ def test_apply_to_multiple(): ...@@ -274,10 +282,10 @@ def test_apply_to_multiple():
np.testing.assert_equal(results['values'], [2, 3]) np.testing.assert_equal(results['values'], [2, 3])
# Case 2: apply to multiple keys # Case 2: apply to multiple keys
pipeline = ApplyToMultiple( pipeline = TransformBroadcaster(
transforms=[AddToValue(addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value=['v_1', 'v_2']), mapping=dict(value=['v_1', 'v_2']),
inplace=True) auto_remap=True)
results = dict(v_1=1, v_2=2) results = dict(v_1=1, v_2=2)
results = pipeline(results) results = pipeline(results)
...@@ -286,10 +294,11 @@ def test_apply_to_multiple(): ...@@ -286,10 +294,11 @@ def test_apply_to_multiple():
np.testing.assert_equal(results['v_2'], 3) np.testing.assert_equal(results['v_2'], 3)
# Case 3: apply to multiple groups of keys # Case 3: apply to multiple groups of keys
pipeline = ApplyToMultiple( pipeline = TransformBroadcaster(
transforms=[SumTwoValues()], transforms=[SumTwoValues()],
input_mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']), mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
output_mapping=dict(sum=['a', 'b'])) remapping=dict(sum=['a', 'b']),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4) results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results) results = pipeline(results)
...@@ -299,18 +308,19 @@ def test_apply_to_multiple(): ...@@ -299,18 +308,19 @@ def test_apply_to_multiple():
# Case 4: inconsistent sequence length # Case 4: inconsistent sequence length
with pytest.raises(ValueError): with pytest.raises(ValueError):
pipeline = ApplyToMultiple( pipeline = TransformBroadcaster(
transforms=[SumTwoValues()], transforms=[SumTwoValues()],
input_mapping=dict(num_1='list_1', num_2='list_2')) mapping=dict(num_1='list_1', num_2='list_2'),
auto_remap=False)
results = dict(list_1=[1, 2], list_2=[1, 2, 3]) results = dict(list_1=[1, 2], list_2=[1, 2, 3])
_ = pipeline(results) _ = pipeline(results)
# Case 5: share random parameter # Case 5: share random parameter
pipeline = ApplyToMultiple( pipeline = TransformBroadcaster(
transforms=[RandomAddToValue()], transforms=[RandomAddToValue()],
input_mapping=dict(value='values'), mapping=dict(value='values'),
inplace=True, auto_remap=True,
share_random_params=True) share_random_params=True)
results = dict(values=[0, 0]) results = dict(values=[0, 0])
...@@ -326,27 +336,27 @@ def test_randomchoice(): ...@@ -326,27 +336,27 @@ def test_randomchoice():
# Case 1: given probability # Case 1: given probability
pipeline = RandomChoice( pipeline = RandomChoice(
pipelines=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]], transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
pipeline_probs=[1.0, 0.0]) prob=[1.0, 0.0])
results = pipeline(dict(value=1)) results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0) np.testing.assert_equal(results['value'], 2.0)
# Case 2: default probability # Case 2: default probability
pipeline = RandomChoice(pipelines=[[AddToValue( pipeline = RandomChoice(transforms=[[AddToValue(
addend=1.0)], [AddToValue(addend=2.0)]]) addend=1.0)], [AddToValue(addend=2.0)]])
_ = pipeline(dict(value=1)) _ = pipeline(dict(value=1))
# Case 3: nested RandomChoice in ApplyToMultiple # Case 3: nested RandomChoice in TransformBroadcaster
pipeline = ApplyToMultiple( pipeline = TransformBroadcaster(
transforms=[ transforms=[
RandomChoice( RandomChoice(
pipelines=[[AddToValue(addend=1.0)], transforms=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ), [AddToValue(addend=2.0)]], ),
], ],
input_mapping=dict(value='values'), mapping=dict(value='values'),
inplace=True, auto_remap=True,
share_random_params=True) share_random_params=True)
results = dict(values=[0 for _ in range(10)]) results = dict(values=[0 for _ in range(10)])
...@@ -357,10 +367,10 @@ def test_randomchoice(): ...@@ -357,10 +367,10 @@ def test_randomchoice():
def test_utils(): def test_utils():
# Test cacheable_method: normal case # Test cache_randomness: normal case
class DummyTransform(BaseTransform): class DummyTransform(BaseTransform):
@cacheable_method @cache_randomness
def func(self): def func(self):
return np.random.rand() return np.random.rand()
...@@ -373,21 +383,21 @@ def test_utils(): ...@@ -373,21 +383,21 @@ def test_utils():
with cache_random_params(transform): with cache_random_params(transform):
_ = transform({}) _ = transform({})
# Test cacheable_method: invalid function type # Test cache_randomness: invalid function type
with pytest.raises(TypeError): with pytest.raises(TypeError):
class DummyTransform(): class DummyTransform():
@cacheable_method @cache_randomness
@staticmethod @staticmethod
def func(): def func():
return np.random.rand() return np.random.rand()
# Test cacheable_method: invalid function argument list # Test cache_randomness: invalid function argument list
with pytest.raises(TypeError): with pytest.raises(TypeError):
class DummyTransform(): class DummyTransform():
@cacheable_method @cache_randomness
def func(cls): def func(cls):
return np.random.rand() return np.random.rand()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment