"vscode:/vscode.git/clone" did not exist on "3dcc5e9a5ac3e736da09f93cf64940599b1e1674"
Commit e2ca0733 authored by Yining Li, committed by zhouzaida

Refactor base transform (#1830)

* rename cacheable_method to cache_randomness

* refactor transform wrappers and update docs

* fix all_nonexist_keys

* fix lint

* rename transform wrappers
parent 0a5b4125
@@ -150,18 +150,18 @@ pipeline = [
Transform wrappers are a special kind of data transform class. They do not themselves operate on the images, labels, or other information in the data dictionary; instead, they enhance the behavior of the data transforms defined inside them.
### Field Mapping (Remap)
### Field Mapping (KeyMapper)
The field mapping wrapper (`Remap`) maps fields of the data dictionary. For example, image processing transforms usually read their input from the `"img"` field of the data dictionary. Sometimes, however, we want such transforms to process an image stored in another field, such as `"gt_img"`.
The field mapping wrapper (`KeyMapper`) maps fields of the data dictionary. For example, image processing transforms usually read their input from the `"img"` field of the data dictionary. Sometimes, however, we want such transforms to process an image stored in another field, such as `"gt_img"`.
When used together with the registry and config files, the field mapping wrapper can be used in the dataset `pipeline` of a config file as in the following example:
```python
pipeline = [
...
dict(type='Remap',
input_mapping={'img': 'gt_img'}, # map the "gt_img" field to the "img" field
inplace=True, # after the transforms, remap "img" back to the "gt_img" field
dict(type='KeyMapper',
mapping={'img': 'gt_img'}, # map the "gt_img" field to the "img" field
auto_remap=True, # after the transforms, remap "img" back to the "gt_img" field
transforms=[
# inside the `RandomFlip` transform, we only need to operate on the "img" field
dict(type='RandomFlip'),
@@ -182,7 +182,7 @@ pipeline = [
pipeline = [
...
dict(type='RandomChoice',
pipelines=[
transforms=[
[
dict(type='Posterize', bits=4),
dict(type='Rotate', angle=30.)
@@ -192,17 +192,17 @@ pipeline = [
dict(type='Rotate', angle=30)
], # the second combination of random transforms
],
pipeline_probs=[0.4, 0.6] # the probability of choosing each combination of random transforms
prob=[0.4, 0.6] # the probability of choosing each combination of random transforms
)
...
]
```
### Multi-target Extension (ApplyToMultiple)
### Multi-target Extension (TransformBroadcaster)
Usually, a data transform reads its target from a single fixed field. Although we can use `Remap` to change which field is read, a transform still cannot be applied to the data of multiple fields at once. To achieve this, we need the multi-target extension wrapper (`ApplyToMultiple`).
Usually, a data transform reads its target from a single fixed field. Although we can use `KeyMapper` to change which field is read, a transform still cannot be applied to the data of multiple fields at once. To achieve this, we need the multi-target extension wrapper (`TransformBroadcaster`).
The multi-target extension wrapper (`ApplyToMultiple`) has two usages: applying the wrapped transforms to several specified fields, or applying them to a group of targets stored under a single field.
The multi-target extension wrapper (`TransformBroadcaster`) has two usages: applying the wrapped transforms to several specified fields, or applying them to a group of targets stored under a single field.
1. Applying to multiple fields
@@ -210,11 +210,11 @@ pipeline = [
```python
pipeline = [
dict(type='ApplyToMultiple',
dict(type='TransformBroadcaster',
# apply separately to the "lq" and "gt" fields, mapping both of them to the "img" field
input_mapping={'img': ['lq', 'gt']},
mapping={'img': ['lq', 'gt']},
# after the transforms, remap the "img" field back to its original field
inplace=True,
auto_remap=True,
# whether to share random variables across the transforms of the different targets
# see the following section (sharing random variables) for more details
share_random_params=True,
@@ -231,11 +231,11 @@ pipeline = [
```python
pipeline = [
dict(type='ApplyToMultiple',
dict(type='TransformBroadcaster',
# 将 "images" 字段下的每张图片映射至 "img" 字段
input_mapping={'img': 'images'},
mapping={'img': 'images'},
# after the transforms, remap the images under the "img" field back into the list under the "images" field
inplace=True,
auto_remap=True,
# whether to share random variables across the transforms of the different targets
share_random_params=True,
transforms=[
@@ -245,12 +245,12 @@ pipeline = [
]
```
In `ApplyToMultiple`, we provide the `share_random_params` option to share the random state across the multiple applications of the wrapped transforms. For example, in super-resolution tasks we want random transforms to be applied **synchronously** to the low-resolution image and the original image. To use this feature in a custom transform class, we need to annotate which random variables in the class can be shared.
In `TransformBroadcaster`, we provide the `share_random_params` option to share the random state across the multiple applications of the wrapped transforms. For example, in super-resolution tasks we want random transforms to be applied **synchronously** to the low-resolution image and the original image. To use this feature in a custom transform class, we need to annotate which random variables in the class can be shared.
Taking the `MyFlip` transform above as an example, suppose we want to randomly perform the flip with a certain probability:
```python
from mmcv.transforms.utils import cacheable_method
from mmcv.transforms.utils import cache_randomness
@TRANSFORMS.register_module()
class MyRandomFlip(BaseTransform):
@@ -259,7 +259,7 @@ class MyRandomFlip(BaseTransform):
self.prob = prob
self.direction = direction
@cacheable_method # mark the output of this method as a shareable random variable
@cache_randomness # mark the output of this method as a shareable random variable
def do_flip(self):
flip = random.random() < self.prob
return flip
@@ -271,4 +271,4 @@ class MyRandomFlip(BaseTransform):
return results
```
With the `cacheable_method` decorator, the return value `flip` is marked as a shareable random variable. Consequently, when `ApplyToMultiple` applies the transforms to multiple targets, this variable keeps the same value for all of them.
With the `cache_randomness` decorator, the return value `flip` is marked as a shareable random variable. Consequently, when `TransformBroadcaster` applies the transforms to multiple targets, this variable keeps the same value for all of them.
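To make the synchronized flipping concrete, here is a minimal pipeline sketch in the style of the examples above; the `"lq"`/`"gt"` field names and the `prob`/`direction` values are illustrative assumptions, and `MyRandomFlip` refers to the custom transform defined above:

```python
pipeline = [
    ...,
    dict(
        type='TransformBroadcaster',
        # assumption: "lq" and "gt" both hold images that should be flipped together
        mapping={'img': ['lq', 'gt']},
        # remap "img" back to the original fields after the transforms
        auto_remap=True,
        # share the cached return value of do_flip() across both targets,
        # so either both images are flipped or neither is
        share_random_params=True,
        transforms=[
            dict(type='MyRandomFlip', prob=0.5, direction='horizontal'),
        ]),
    ...,
]
```

With `share_random_params=True`, the broadcaster relies on the `cache_random_params` context manager (see the `utils.py` changes below), so the value cached by the `cache_randomness`-decorated `do_flip` is reused for every target.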
@@ -4,24 +4,24 @@ from .loading import LoadAnnotation, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomFlip, RandomGrayscale, RandomMultiscaleResize,
RandomResize, Resize)
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap
from .wrappers import Compose, KeyMapper, RandomChoice, TransformBroadcaster
try:
import torch # noqa: F401
except ImportError:
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
else:
from .formatting import ImageToTensor, ToTensor, to_tensor
__all__ = [
'TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap',
'LoadImageFromFile', 'LoadAnnotation', 'Normalize', 'Resize', 'Pad',
'ToTensor', 'to_tensor', 'ImageToTensor', 'RandomFlip',
'RandomMultiscaleResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize'
'TRANSFORMS', 'TransformBroadcaster', 'Compose', 'RandomChoice',
'KeyMapper', 'LoadImageFromFile', 'LoadAnnotation', 'Normalize',
'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomMultiscaleResize', 'CenterCrop',
'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize'
]
@@ -11,17 +11,17 @@ from typing import Callable, Union
from .base import BaseTransform
class cacheable_method:
"""Decorator that marks a method of a transform class as a cacheable
method.
class cache_randomness:
"""Decorator that marks the method with random return value(s) in a
transform class.
This decorator is usually used together with the context-manager
:func:`cache_random_params`. In this context, a cacheable method will
:func:`cache_random_params`. In this context, a decorated method will
cache its return value(s) at the first time of being invoked, and always
return the cached values when being invoked again.
.. note::
Only a instance method can be decorated as a cacheable_method.
Only an instance method can be decorated by ``cache_randomness``.
"""
def __init__(self, func):
@@ -29,12 +29,12 @@ class cacheable_method:
# Check `func` is to be bound as an instance method
if not inspect.isfunction(func):
raise TypeError('Unsupported callable to decorate with '
'@cacheable_method.')
'@cache_randomness.')
func_args = inspect.getfullargspec(func).args
if len(func_args) == 0 or func_args[0] != 'self':
raise TypeError(
'@cacheable_method should only be used to decorate '
'instance methods (the first argument is `self`).')
'@cache_randomness should only be used to decorate '
'instance methods (the first argument is ``self``).')
functools.update_wrapper(self, func)
self.func = func
@@ -42,24 +42,24 @@ class cacheable_method:
def __set_name__(self, owner, name):
# Maintain a record of decorated methods in the class
if not hasattr(owner, '_cacheable_methods'):
setattr(owner, '_cacheable_methods', [])
owner._cacheable_methods.append(self.__name__)
if not hasattr(owner, '_methods_with_randomness'):
setattr(owner, '_methods_with_randomness', [])
owner._methods_with_randomness.append(self.__name__)
def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated
# by cacheable_method
# by cache_randomness
instance = self.instance_ref()
name = self.__name__
# Check the flag `self._cache_enabled`, which should be
# set by the contextmanagers like `cache_random_parameters`
# Check the flag ``self._cache_enabled``, which should be
# set by context managers like ``cache_random_params``
cache_enabled = getattr(instance, '_cache_enabled', False)
if cache_enabled:
# Initialize the cache of the transform instances. The flag
# `cache_enabled` is set by contextmanagers like
# `cache_random_params`.
# ``cache_enabled`` is set by context managers like
# ``cache_random_params``.
if not hasattr(instance, '_cache'):
setattr(instance, '_cache', {})
@@ -81,13 +81,13 @@ class cacheable_method:
@contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of cacheable methods in
transforms.
"""Context-manager that enables the cache of return values of methods
decorated by ``cache_randomness`` in transforms.
In this mode, cacheable methods will cache their return values on the
In this mode, decorated methods will cache their return values on the
first invocation, and always return the cached values afterward. This allows
applying random transforms in a deterministic way, e.g. applying the same
transforms on multiple examples. See `cacheable_method` for more
transforms on multiple examples. See ``cache_randomness`` for more
information.
Args:
@@ -99,12 +99,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# ones. These methods will be restored when exiting the context.
key2method = dict()
# key2counter stores the usage number of each cacheable_method. This is
# used to check that any cacheable_method is invoked once during processing
# key2counter stores the invocation count of each method decorated by
# ``cache_randomness``. It is used to check that each such method is invoked
# at most once when processing one data sample.
key2counter = defaultdict(int)
def _add_counter(obj, method_name):
def _add_invoke_counter(obj, method_name):
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method
@@ -116,15 +116,43 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return wrapped
def _add_invoke_checker(obj, method_name):
# check that each method in _methods_with_randomness has been
# invoked at most once
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method
@functools.wraps(method)
def wrapped(*args, **kwargs):
# clear counter
for name in obj._methods_with_randomness:
key = f'{id(obj)}.{name}'
key2counter[key] = 0
output = method(*args, **kwargs)
for name in obj._methods_with_randomness:
key = f'{id(obj)}.{name}'
if key2counter[key] > 1:
raise RuntimeError(
'The method decorated by ``cache_randomness`` should '
'be invoked at most once during processing one data '
f'sample. The method {name} of {obj} has been invoked'
f' {key2counter[key]} times.')
return output
return wrapped
def _start_cache(t: BaseTransform):
# Set cache enabled flag
setattr(t, '_cache_enabled', True)
# Store the original method and init the counter
if hasattr(t, '_cacheable_methods'):
setattr(t, 'transform', _add_counter(t, 'transform'))
for name in t._cacheable_methods:
setattr(t, name, _add_counter(t, name))
if hasattr(t, '_methods_with_randomness'):
setattr(t, 'transform', _add_invoke_checker(t, 'transform'))
for name in t._methods_with_randomness:
setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform):
# Remove cache enabled flag
@@ -133,17 +161,12 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
del t._cache
# Restore the original method
if hasattr(t, '_cacheable_methods'):
key_transform = f'{id(t)}.transform'
for name in t._cacheable_methods:
if hasattr(t, '_methods_with_randomness'):
for name in t._methods_with_randomness:
key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]:
raise RuntimeError(
'The cacheable method should be called once and only '
f'once during processing one data sample. {t} got '
f'unmatched number of {key2counter[key]} ({name}) vs '
f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key])
key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable],
@@ -151,7 +174,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice
if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'):
if hasattr(t, '_methods_with_randomness'):
func(t)
if isinstance(t, Iterable):
for _t in t:
@@ -6,9 +6,9 @@ import pytest
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import cache_random_params, cacheable_method
from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice,
Remap)
from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomChoice,
TransformBroadcaster)
@TRANSFORMS.register_module()
@@ -45,15 +45,18 @@ class AddToValue(BaseTransform):
class RandomAddToValue(AddToValue):
"""Dummy transform to add a random addend to results['value']"""
def __init__(self) -> None:
def __init__(self, repeat=1) -> None:
super().__init__(addend=None)
self.repeat = repeat
@cacheable_method
@cache_randomness
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
return self.add(results, addend=self.get_random_addend())
for _ in range(self.repeat):
results = self.add(results, addend=self.get_random_addend())
return results
@TRANSFORMS.register_module()
@@ -100,8 +103,8 @@ def test_cache_random_parameters():
transform = RandomAddToValue()
# Case 1: cache random parameters
assert hasattr(RandomAddToValue, '_cacheable_methods')
assert 'get_random_addend' in RandomAddToValue._cacheable_methods
assert hasattr(RandomAddToValue, '_methods_with_randomness')
assert 'get_random_addend' in RandomAddToValue._methods_with_randomness
with cache_random_params(transform):
results_1 = transform(dict(value=0))
@@ -114,12 +117,18 @@ def test_cache_random_parameters():
with pytest.raises(AssertionError):
np.testing.assert_equal(results_1['value'], results_2['value'])
# Case 3: invalid use of cacheable methods
# Case 3: allow invoking the random method 0 times
transform = RandomAddToValue(repeat=0)
with cache_random_params(transform):
_ = transform(dict(value=0))
# Case 4: do NOT allow invoking the random method more than once
transform = RandomAddToValue(repeat=2)
with pytest.raises(RuntimeError):
with cache_random_params(transform):
_ = transform.get_random_addend()
_ = transform(dict(value=0))
# Case 4: apply on nested transforms
# Case 5: apply on nested transforms
transform = Compose([RandomAddToValue()])
with cache_random_params(transform):
results_1 = transform(dict(value=0))
@@ -127,13 +136,13 @@ def test_cache_random_parameters():
np.testing.assert_equal(results_1['value'], results_2['value'])
def test_remap():
def test_apply_to_mapped():
# Case 1: simple remap
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
mapping=dict(value='v_in'),
remapping=dict(value='v_out'))
results = dict(value=0, v_in=1)
results = pipeline(results)
@@ -143,10 +152,10 @@ def test_remap():
np.testing.assert_equal(results['v_out'], 2)
# Case 2: collecting list
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
output_mapping=dict(value=['v_out_1', 'v_out_2']))
mapping=dict(value=['v_in_1', 'v_in_2']),
remapping=dict(value=['v_out_1', 'v_out_2']))
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
@@ -159,10 +168,10 @@ def test_remap():
np.testing.assert_equal(results['v_out_2'], 4)
# Case 3: collecting dict
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
remapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
@@ -174,11 +183,11 @@ def test_remap():
np.testing.assert_equal(results['v_out_1'], 3)
np.testing.assert_equal(results['v_out_2'], 4)
# Case 4: collecting list with inplace mode
pipeline = Remap(
# Case 4: collecting list with auto_remap mode
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True)
mapping=dict(value=['v_in_1', 'v_in_2']),
auto_remap=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a list'):
@@ -188,11 +197,11 @@ def test_remap():
np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4)
# Case 5: collecting dict with inplace mode
pipeline = Remap(
# Case 5: collecting dict with auto_remap mode
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
inplace=True)
mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
auto_remap=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
with pytest.warns(UserWarning, match='value is a dict'):
@@ -202,11 +211,11 @@ def test_remap():
np.testing.assert_equal(results['v_in_1'], 3)
np.testing.assert_equal(results['v_in_2'], 4)
# Case 6: nested collection with inplace mode
pipeline = Remap(
# Case 6: nested collection with auto_remap mode
pipeline = KeyMapper(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
inplace=True)
mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
auto_remap=True)
results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
with pytest.warns(UserWarning, match='value is a list'):
@@ -218,27 +227,20 @@ def test_remap():
np.testing.assert_equal(results['v22'], 5)
np.testing.assert_equal(results['v3'], 6)
# Case 7: `strict` must be True if `inplace` is set True
# Case 7: `remapping` must be None if `auto_remap` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True,
strict=False)
# Case 8: output_map must be None if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'),
inplace=True)
mapping=dict(value='v_in'),
remapping=dict(value='v_out'),
auto_remap=True)
# Case 9: non-strict input mapping
pipeline = Remap(
# Case 8: allow_nonexist_keys
pipeline = KeyMapper(
transforms=[SumTwoValues()],
input_mapping=dict(num_1='a', num_2='b'),
strict=False)
mapping=dict(num_1='a', num_2='b'),
auto_remap=False,
allow_nonexist_keys=True)
results = pipeline(dict(a=1, b=2))
np.testing.assert_equal(results['sum'], 3)
@@ -246,11 +248,17 @@ def test_remap():
results = pipeline(dict(a=1))
assert np.isnan(results['sum'])
# Case 9: use wrapper as a transform
transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
results = transform(dict(a=1))
# note that the original key 'a' will not be removed
assert results == dict(a=1, b=1)
# Test basic functions
pipeline = Remap(
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
mapping=dict(value='v_in'),
remapping=dict(value='v_out'))
# __iter__
for _ in pipeline:
@@ -263,10 +271,10 @@ def test_remap():
def test_apply_to_multiple():
# Case 1: apply to list in results
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='values'),
inplace=True)
mapping=dict(value='values'),
auto_remap=True)
results = dict(values=[1, 2])
results = pipeline(results)
@@ -274,10 +282,10 @@ def test_apply_to_multiple():
np.testing.assert_equal(results['values'], [2, 3])
# Case 2: apply to multiple keys
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[AddToValue(addend=1)],
input_mapping=dict(value=['v_1', 'v_2']),
inplace=True)
mapping=dict(value=['v_1', 'v_2']),
auto_remap=True)
results = dict(v_1=1, v_2=2)
results = pipeline(results)
@@ -286,10 +294,11 @@ def test_apply_to_multiple():
np.testing.assert_equal(results['v_2'], 3)
# Case 3: apply to multiple groups of keys
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
input_mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
output_mapping=dict(sum=['a', 'b']))
mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
remapping=dict(sum=['a', 'b']),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
@@ -299,18 +308,19 @@ def test_apply_to_multiple():
# Case 4: inconsistent sequence length
with pytest.raises(ValueError):
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
input_mapping=dict(num_1='list_1', num_2='list_2'))
mapping=dict(num_1='list_1', num_2='list_2'),
auto_remap=False)
results = dict(list_1=[1, 2], list_2=[1, 2, 3])
_ = pipeline(results)
# Case 5: share random parameter
pipeline = ApplyToMultiple(
pipeline = TransformBroadcaster(
transforms=[RandomAddToValue()],
input_mapping=dict(value='values'),
inplace=True,
mapping=dict(value='values'),
auto_remap=True,
share_random_params=True)
results = dict(values=[0, 0])
@@ -326,27 +336,27 @@ def test_randomchoice():
# Case 1: given probability
pipeline = RandomChoice(
pipelines=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
pipeline_probs=[1.0, 0.0])
transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
prob=[1.0, 0.0])
results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0)
# Case 2: default probability
pipeline = RandomChoice(pipelines=[[AddToValue(
pipeline = RandomChoice(transforms=[[AddToValue(
addend=1.0)], [AddToValue(addend=2.0)]])
_ = pipeline(dict(value=1))
# Case 3: nested RandomChoice in ApplyToMultiple
pipeline = ApplyToMultiple(
# Case 3: nested RandomChoice in TransformBroadcaster
pipeline = TransformBroadcaster(
transforms=[
RandomChoice(
pipelines=[[AddToValue(addend=1.0)],
transforms=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ),
],
input_mapping=dict(value='values'),
inplace=True,
mapping=dict(value='values'),
auto_remap=True,
share_random_params=True)
results = dict(values=[0 for _ in range(10)])
@@ -357,10 +367,10 @@ def test_randomchoice():
def test_utils():
# Test cacheable_method: normal case
# Test cache_randomness: normal case
class DummyTransform(BaseTransform):
@cacheable_method
@cache_randomness
def func(self):
return np.random.rand()
@@ -373,21 +383,21 @@ def test_utils():
with cache_random_params(transform):
_ = transform({})
# Test cacheable_method: invalid function type
# Test cache_randomness: invalid function type
with pytest.raises(TypeError):
class DummyTransform():
@cacheable_method
@cache_randomness
@staticmethod
def func():
return np.random.rand()
# Test cacheable_method: invalid function argument list
# Test cache_randomness: invalid function argument list
with pytest.raises(TypeError):
class DummyTransform():
@cacheable_method
@cache_randomness
def func(cls):
return np.random.rand()