"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "a5dd01bb74d0d5775b6af177a8d077f7fb634947"
Commit 59eaefeb authored by liyining's avatar liyining Committed by zhouzaida
Browse files

[Feature] Support partial mapping by manually marking keys as ignored

parent 3b494a13
...@@ -160,7 +160,10 @@ pipeline = [ ...@@ -160,7 +160,10 @@ pipeline = [
pipeline = [ pipeline = [
... ...
dict(type='KeyMapper', dict(type='KeyMapper',
mapping={'img': 'gt_img'}, # 将 "gt_img" 字段映射至 "img" 字段 mapping={
'img': 'gt_img', # 将 "gt_img" 字段映射至 "img" 字段
'mask': ..., # 不使用原始数据中的 "mask" 字段。即对于被包装的数据变换,数据中不包含 "mask" 字段
},
auto_remap=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段 auto_remap=True, # 在完成变换后,将 "img" 重映射回 "gt_img" 字段
transforms=[ transforms=[
# 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可 # 在 `RandomFlip` 变换类中,我们只需要操作 "img" 字段即可
...@@ -237,6 +240,29 @@ pipeline = [ ...@@ -237,6 +240,29 @@ pipeline = [
] ]
``` ```
在多目标扩展的 `mapping` 设置中,我们同样可以使用 `...` 来忽略指定的原始字段。如以下例子中,被包裹的 `RandomCrop` 会对字段 `"img"` 中的图像进行裁剪,并且在字段 `"img_shape"` 存在时更新剪裁后的图像大小。如果我们希望同时对两个图像字段 `"lq"``"gt"` 进行相同的随机裁剪,但只更新一次 `"img_shape"` 字段,可以通过例子中的方式实现:
```python
pipeline = [
dict(type='TransformBroadcaster',
mapping={
'img': ['lq', 'gt'],
'img_shape': ['img_shape', ...],
},
# 在完成变换后,将 "img" 和 "img_shape" 字段重映射回原先的字段
auto_remap=True,
# 是否在对各目标的变换中共享随机变量
# 更多介绍参加后续章节(随机变量共享)
share_random_params=True,
transforms=[
# `RandomCrop` 类中会操作 "img" 和 "img_shape" 字段。若 "img_shape" 空缺,
# 则只操作 "img"
dict(type='RandomCrop'),
])
]
```
2. 应用于一个字段的一组目标 2. 应用于一个字段的一组目标
假设我们需要将数据变换应用于 `"images"` 字段,该字段为一个图像组成的 list。 假设我们需要将数据变换应用于 `"images"` 字段,该字段为一个图像组成的 list。
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Dict from typing import Dict, Optional
class BaseTransform(metaclass=ABCMeta): class BaseTransform(metaclass=ABCMeta):
def __call__(self, results: Dict) -> Dict: def __call__(self, results: Dict) -> Optional[Dict]:
return self.transform(results) return self.transform(results)
@abstractmethod @abstractmethod
def transform(self, results: Dict) -> Dict: def transform(self, results: Dict) -> Optional[Dict]:
"""The transform function. All subclass of BaseTransform should """The transform function. All subclass of BaseTransform should
override this method. override this method.
......
...@@ -151,7 +151,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -151,7 +151,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# key2counter stores the usage number of each cache_randomness. This is # key2counter stores the usage number of each cache_randomness. This is
# used to check that any cache_randomness is invoked once during processing # used to check that any cache_randomness is invoked once during processing
# on data sample. # on data sample.
key2counter = defaultdict(int) key2counter: dict = defaultdict(int)
def _add_invoke_counter(obj, method_name): def _add_invoke_counter(obj, method_name):
method = getattr(obj, method_name) method = getattr(obj, method_name)
...@@ -212,7 +212,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -212,7 +212,7 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
# Store the original method and init the counter # Store the original method and init the counter
if hasattr(t, '_methods_with_randomness'): if hasattr(t, '_methods_with_randomness'):
setattr(t, 'transform', _add_invoke_checker(t, 'transform')) setattr(t, 'transform', _add_invoke_checker(t, 'transform'))
for name in t._methods_with_randomness: for name in getattr(t, '_methods_with_randomness'):
setattr(t, name, _add_invoke_counter(t, name)) setattr(t, name, _add_invoke_counter(t, name))
def _end_cache(t: BaseTransform): def _end_cache(t: BaseTransform):
...@@ -221,20 +221,21 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -221,20 +221,21 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
return return
# Remove cache enabled flag # Remove cache enabled flag
del t._cache_enabled delattr(t, '_cache_enabled')
if hasattr(t, '_cache'): if hasattr(t, '_cache'):
del t._cache delattr(t, '_cache')
# Restore the original method # Restore the original method
if hasattr(t, '_methods_with_randomness'): if hasattr(t, '_methods_with_randomness'):
for name in t._methods_with_randomness: for name in getattr(t, '_methods_with_randomness'):
key = f'{id(t)}.{name}' key = f'{id(t)}.{name}'
setattr(t, name, key2method[key]) setattr(t, name, key2method[key])
key_transform = f'{id(t)}.transform' key_transform = f'{id(t)}.transform'
setattr(t, 'transform', key2method[key_transform]) setattr(t, 'transform', key2method[key_transform])
def _apply(t: BaseTransform, func: Callable[[BaseTransform], None]): def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]):
if isinstance(t, BaseTransform): if isinstance(t, BaseTransform):
func(t) func(t)
if isinstance(t, Iterable): if isinstance(t, Iterable):
......
...@@ -13,8 +13,14 @@ from .utils import cache_random_params, cache_randomness ...@@ -13,8 +13,14 @@ from .utils import cache_random_params, cache_randomness
# Define type of transform or transform config # Define type of transform or transform config
Transform = Union[Dict, Callable[[Dict], Dict]] Transform = Union[Dict, Callable[[Dict], Dict]]
# Indicator for required but missing keys in results # Indicator of keys marked by KeyMapper._map_input, which means ignoring the
NotInResults = object() # marked keys in KeyMapper._apply_transform so they will be invisible to
# wrapped transforms.
# This can be 2 possible case:
# 1. The key is required but missing in results
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
# the original value in results should be ignored
IgnoreKey = object()
# Import nullcontext if python>=3.7, otherwise use a simple alternative # Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation. # implementation.
...@@ -23,7 +29,7 @@ try: ...@@ -23,7 +29,7 @@ try:
except ImportError: except ImportError:
from contextlib import contextmanager from contextlib import contextmanager
@contextmanager @contextmanager # type: ignore
def nullcontext(resource=None): def nullcontext(resource=None):
try: try:
yield resource yield resource
...@@ -54,7 +60,7 @@ class Compose(BaseTransform): ...@@ -54,7 +60,7 @@ class Compose(BaseTransform):
if not isinstance(transforms, list): if not isinstance(transforms, list):
transforms = [transforms] transforms = [transforms]
self.transforms = [] self.transforms: List = []
for transform in transforms: for transform in transforms:
if isinstance(transform, dict): if isinstance(transform, dict):
transform = TRANSFORMS.build(transform) transform = TRANSFORMS.build(transform)
...@@ -137,6 +143,7 @@ class KeyMapper(BaseTransform): ...@@ -137,6 +143,7 @@ class KeyMapper(BaseTransform):
>>> dict(type='Normalize'), >>> dict(type='Normalize'),
>>> ]) >>> ])
>>> ] >>> ]
>>> # Example 2: Collect and structure multiple items >>> # Example 2: Collect and structure multiple items
>>> pipeline = [ >>> pipeline = [
>>> # The inner field 'imgs' will be a dict with keys 'img_src' >>> # The inner field 'imgs' will be a dict with keys 'img_src'
...@@ -151,6 +158,22 @@ class KeyMapper(BaseTransform): ...@@ -151,6 +158,22 @@ class KeyMapper(BaseTransform):
>>> img_tar='img2')), >>> img_tar='img2')),
>>> transforms=...) >>> transforms=...)
>>> ] >>> ]
>>> # Example 3: Manually set ignored keys by "..."
>>> pipeline = [
>>> ...
>>> dict(type='KeyMapper',
>>> mapping={
>>> # map outer key "gt_img" to inner key "img"
>>> 'img': 'gt_img',
>>> # ignore outer key "mask"
>>> 'mask': ...,
>>> },
>>> transforms=[
>>> dict(type='RandomFlip'),
>>> ])
>>> ...
>>> ]
""" """
def __init__(self, def __init__(self,
...@@ -185,20 +208,25 @@ class KeyMapper(BaseTransform): ...@@ -185,20 +208,25 @@ class KeyMapper(BaseTransform):
"""Allow easy iteration over the transform sequence.""" """Allow easy iteration over the transform sequence."""
return iter(self.transforms) return iter(self.transforms)
def map_input(self, data: Dict, mapping: Dict) -> Dict[str, Any]: def _map_input(self, data: Dict,
mapping: Optional[Dict]) -> Dict[str, Any]:
"""KeyMapper inputs for the wrapped transforms by gathering and """KeyMapper inputs for the wrapped transforms by gathering and
renaming data items according to the mapping. renaming data items according to the mapping.
Args: Args:
data (dict): The original input data data (dict): The original input data
mapping (dict): The input key mapping. See the document of mapping (dict, optional): The input key mapping. See the document
``mmcv.transforms.wrappers.KeyMapper`` for details. of ``mmcv.transforms.wrappers.KeyMapper`` for details. In
set None, return the input data directly.
Returns: Returns:
dict: The input data with remapped keys. This will be the actual dict: The input data with remapped keys. This will be the actual
input of the wrapped pipeline. input of the wrapped pipeline.
""" """
if mapping is None:
return data.copy()
def _map(data, m): def _map(data, m):
if isinstance(m, dict): if isinstance(m, dict):
# m is a dict {inner_key:outer_key, ...} # m is a dict {inner_key:outer_key, ...}
...@@ -210,17 +238,17 @@ class KeyMapper(BaseTransform): ...@@ -210,17 +238,17 @@ class KeyMapper(BaseTransform):
# transforms. # transforms.
return m.__class__(_map(data, e) for e in m) return m.__class__(_map(data, e) for e in m)
# allow manually mark a key to be ignored by ...
if m is ...:
return IgnoreKey
# m is an outer_key # m is an outer_key
if self.allow_nonexist_keys: if self.allow_nonexist_keys:
return data.get(m, NotInResults) return data.get(m, IgnoreKey)
else: else:
return data.get(m) return data.get(m)
collected = _map(data, mapping) collected = _map(data, mapping)
collected = {
k: v
for k, v in collected.items() if v is not NotInResults
}
# Retain unmapped items # Retain unmapped items
inputs = data.copy() inputs = data.copy()
...@@ -228,19 +256,26 @@ class KeyMapper(BaseTransform): ...@@ -228,19 +256,26 @@ class KeyMapper(BaseTransform):
return inputs return inputs
def map_output(self, data: Dict, remapping: Dict) -> Dict[str, Any]: def _map_output(self, data: Dict,
remapping: Optional[Dict]) -> Dict[str, Any]:
"""KeyMapper outputs from the wrapped transforms by gathering and """KeyMapper outputs from the wrapped transforms by gathering and
renaming data items according to the remapping. renaming data items according to the remapping.
Args: Args:
data (dict): The output of the wrapped pipeline. data (dict): The output of the wrapped pipeline.
remapping (dict): The output key mapping. See the document of remapping (dict, optional): The output key mapping. See the
``mmcv.transforms.wrappers.KeyMapper`` for details. document of ``mmcv.transforms.wrappers.KeyMapper`` for
details. If ``remapping is None``, no key mapping will be
applied but only remove the special token ``IgnoreKey``.
Returns: Returns:
dict: The output with remapped keys. dict: The output with remapped keys.
""" """
# Remove ``IgnoreKey``
if remapping is None:
return {k: v for k, v in data.items() if v is not IgnoreKey}
def _map(data, m): def _map(data, m):
if isinstance(m, dict): if isinstance(m, dict):
assert isinstance(data, dict) assert isinstance(data, dict)
...@@ -257,21 +292,44 @@ class KeyMapper(BaseTransform): ...@@ -257,21 +292,44 @@ class KeyMapper(BaseTransform):
results.update(_map(d_i, m_i)) results.update(_map(d_i, m_i))
return results return results
if m is IgnoreKey:
return {}
return {m: data} return {m: data}
# Note that unmapped items are not retained, which is different from # Note that unmapped items are not retained, which is different from
# the behavior in map_input. This is to avoid original data items # the behavior in _map_input. This is to avoid original data items
# being overwritten by intermediate namesakes # being overwritten by intermediate namesakes
return _map(data, remapping) return _map(data, remapping)
def transform(self, results: Dict) -> Dict: def _apply_transforms(self, inputs: Dict) -> Dict:
inputs = results """Apply ``self.transforms``.
if self.mapping:
inputs = self.map_input(inputs, self.mapping) Note that the special token ``IgnoreKey`` will be invisible to
``self.transforms``, but not removed in this method. It will be
eventually removed in :func:``self._map_output``.
"""
results = inputs.copy()
inputs = {k: v for k, v in inputs.items() if v is not IgnoreKey}
outputs = self.transforms(inputs) outputs = self.transforms(inputs)
if self.remapping: if outputs is None:
outputs = self.map_output(outputs, self.remapping) raise ValueError(
f'Transforms wrapped by {self.__class__.__name__} should '
'not return None.')
results.update(outputs) # type: ignore
return results
def transform(self, results: Dict) -> Dict:
"""Apply mapping, wrapped transforms and remapping."""
# Apply mapping
inputs = self._map_input(results, self.mapping)
# Apply wrapped transforms
outputs = self._apply_transforms(inputs)
# Apply remapping
outputs = self._map_output(outputs, self.remapping)
results.update(outputs) results.update(outputs)
return results return results
...@@ -314,7 +372,8 @@ class TransformBroadcaster(KeyMapper): ...@@ -314,7 +372,8 @@ class TransformBroadcaster(KeyMapper):
example. example.
Examples: Examples:
>>> # Example 1: >>> # Example 1: Broadcast to enumerated keys, each contains a single
>>> # data element
>>> pipeline = [ >>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img >>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img >>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
...@@ -333,7 +392,8 @@ class TransformBroadcaster(KeyMapper): ...@@ -333,7 +392,8 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'), >>> dict(type='Normalize'),
>>> ]) >>> ])
>>> ] >>> ]
>>> # Example 2:
>>> # Example 2: Broadcast to keys that contains data sequences
>>> pipeline = [ >>> pipeline = [
>>> dict(type='LoadImageFromFile', key='lq'), # low-quality img >>> dict(type='LoadImageFromFile', key='lq'), # low-quality img
>>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img >>> dict(type='LoadImageFromFile', key='gt'), # ground-truth img
...@@ -351,6 +411,24 @@ class TransformBroadcaster(KeyMapper): ...@@ -351,6 +411,24 @@ class TransformBroadcaster(KeyMapper):
>>> dict(type='Normalize'), >>> dict(type='Normalize'),
>>> ]) >>> ])
>>> ] >>> ]
>>> Example 3: Set ignored keys in broadcasting
>>> pipeline = [
>>> dict(type='TransformBroadcaster',
>>> # Broadcast the wrapped transforms to multiple images
>>> # 'lq' and 'gt, but only update 'img_shape' once
>>> mapping={
>>> 'img': ['lq', 'gt'],
>>> 'img_shape': ['img_shape', ...],
>>> },
>>> auto_remap=True,
>>> share_random_params=True,
>>> transforms=[
>>> # `RandomCrop` will modify the field "img",
>>> # and optionally update "img_shape" if it exists
>>> dict(type='RandomCrop'),
>>> ])
>>> ]
""" """
def __init__(self, def __init__(self,
...@@ -366,17 +444,23 @@ class TransformBroadcaster(KeyMapper): ...@@ -366,17 +444,23 @@ class TransformBroadcaster(KeyMapper):
self.share_random_params = share_random_params self.share_random_params = share_random_params
def scatter_sequence(self, data: Dict) -> List[Dict]: def scatter_sequence(self, data: Dict) -> List[Dict]:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
"""
# infer split number from input # infer split number from input
seq_len = None seq_len = 0
key_rep = None key_rep = None
if self.mapping: if self.mapping:
keys = self.mapping.keys() keys = self.mapping.keys()
else: else:
keys = data.keys() keys = data.keys()
for key in keys: for key in keys:
assert isinstance(data[key], Sequence) assert isinstance(data[key], Sequence)
if seq_len is not None: if seq_len:
if len(data[key]) != seq_len: if len(data[key]) != seq_len:
raise ValueError('Got inconsistent sequence length: ' raise ValueError('Got inconsistent sequence length: '
f'{seq_len} ({key_rep}) vs. ' f'{seq_len} ({key_rep}) vs. '
...@@ -385,6 +469,8 @@ class TransformBroadcaster(KeyMapper): ...@@ -385,6 +469,8 @@ class TransformBroadcaster(KeyMapper):
seq_len = len(data[key]) seq_len = len(data[key])
key_rep = key key_rep = key
assert seq_len > 0, 'Fail to get the number of broadcasting targets'
scatters = [] scatters = []
for i in range(seq_len): for i in range(seq_len):
scatter = data.copy() scatter = data.copy()
...@@ -394,13 +480,13 @@ class TransformBroadcaster(KeyMapper): ...@@ -394,13 +480,13 @@ class TransformBroadcaster(KeyMapper):
return scatters return scatters
def transform(self, results: Dict): def transform(self, results: Dict):
"""Broadcast wrapped transforms to multiple targets."""
# Apply input remapping # Apply input remapping
inputs = results inputs = self._map_input(results, self.mapping)
if self.mapping:
inputs = self.map_input(inputs, self.mapping)
# Scatter sequential inputs into a list # Scatter sequential inputs into a list
inputs = self.scatter_sequence(inputs) input_scatters = self.scatter_sequence(inputs)
# Control random parameter sharing with a context manager # Control random parameter sharing with a context manager
if self.share_random_params: if self.share_random_params:
...@@ -410,20 +496,21 @@ class TransformBroadcaster(KeyMapper): ...@@ -410,20 +496,21 @@ class TransformBroadcaster(KeyMapper):
# by all data items. # by all data items.
ctx = cache_random_params ctx = cache_random_params
else: else:
ctx = nullcontext ctx = nullcontext # type: ignore
with ctx(self.transforms): with ctx(self.transforms):
outputs = [self.transforms(_input) for _input in inputs] output_scatters = [
self._apply_transforms(_input) for _input in input_scatters
]
# Collate output scatters (list of dict to dict of list) # Collate output scatters (list of dict to dict of list)
outputs = { outputs = {
key: [_output[key] for _output in outputs] key: [_output[key] for _output in output_scatters]
for key in outputs[0] for key in output_scatters[0]
} }
# Apply output remapping # Apply remapping
if self.remapping: outputs = self._map_output(outputs, self.remapping)
outputs = self.map_output(outputs, self.remapping)
results.update(outputs) results.update(outputs)
return results return results
...@@ -473,11 +560,13 @@ class RandomChoice(BaseTransform): ...@@ -473,11 +560,13 @@ class RandomChoice(BaseTransform):
return iter(self.transforms) return iter(self.transforms)
@cache_randomness @cache_randomness
def random_pipeline_index(self): def random_pipeline_index(self) -> int:
"""Return a random transform index."""
indices = np.arange(len(self.transforms)) indices = np.arange(len(self.transforms))
return np.random.choice(indices, p=self.prob) return np.random.choice(indices, p=self.prob)
def transform(self, results): def transform(self, results: Dict) -> Optional[Dict]:
"""Randomly choose a transform to apply."""
idx = self.random_pipeline_index() idx = self.random_pipeline_index()
return self.transforms[idx](results) return self.transforms[idx](results)
...@@ -512,10 +601,14 @@ class RandomApply(BaseTransform): ...@@ -512,10 +601,14 @@ class RandomApply(BaseTransform):
return iter(self.transforms) return iter(self.transforms)
@cache_randomness @cache_randomness
def random_apply(self): def random_apply(self) -> bool:
"""Return a random bool value indicating whether apply the transform.
"""
return np.random.rand() < self.prob return np.random.rand() < self.prob
def transform(self, results: Dict) -> Dict: def transform(self, results: Dict) -> Optional[Dict]:
"""Randomly apply the transform."""
if self.random_apply(): if self.random_apply():
results = self.transforms(results) return self.transforms(results)
return results else:
return results
...@@ -67,6 +67,10 @@ class SumTwoValues(BaseTransform): ...@@ -67,6 +67,10 @@ class SumTwoValues(BaseTransform):
def transform(self, results): def transform(self, results):
if 'num_1' in results and 'num_2' in results: if 'num_1' in results and 'num_2' in results:
results['sum'] = results['num_1'] + results['num_2'] results['sum'] = results['num_1'] + results['num_2']
elif 'num_1' in results:
results['sum'] = results['num_1']
elif 'num_2' in results:
results['sum'] = results['num_2']
else: else:
results['sum'] = np.nan results['sum'] = np.nan
return results return results
...@@ -262,7 +266,7 @@ def test_key_mapper(): ...@@ -262,7 +266,7 @@ def test_key_mapper():
np.testing.assert_equal(results['sum'], 3) np.testing.assert_equal(results['sum'], 3)
results = pipeline(dict(a=1)) results = pipeline(dict(a=1))
assert np.isnan(results['sum']) np.testing.assert_equal(results['sum'], 1)
# Case 9: use wrapper as a transform # Case 9: use wrapper as a transform
transform = KeyMapper(mapping=dict(b='a'), auto_remap=False) transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
...@@ -270,6 +274,17 @@ def test_key_mapper(): ...@@ -270,6 +274,17 @@ def test_key_mapper():
# note that the original key 'a' will not be removed # note that the original key 'a' will not be removed
assert results == dict(a=1, b=1) assert results == dict(a=1, b=1)
# Case 10: manually set keys ignored
pipeline = KeyMapper(
transforms=[SumTwoValues()],
mapping=dict(num_1='a', num_2=...), # num_2 (b) will be ignored
auto_remap=False,
# allow_nonexist_keys will not affect manually ignored keys
allow_nonexist_keys=False)
results = pipeline(dict(a=1, b=2))
np.testing.assert_equal(results['sum'], 1)
# Test basic functions # Test basic functions
pipeline = KeyMapper( pipeline = KeyMapper(
transforms=[AddToValue(addend=1)], transforms=[AddToValue(addend=1)],
...@@ -353,6 +368,31 @@ def test_transform_broadcaster(): ...@@ -353,6 +368,31 @@ def test_transform_broadcaster():
np.testing.assert_equal(results['values'][0], results['values'][1]) np.testing.assert_equal(results['values'][0], results['values'][1])
# Case 6: partial broadcasting
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', ...]),
remapping=dict(sum=['a', 'b']),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
np.testing.assert_equal(results['a'], 3)
np.testing.assert_equal(results['b'], 3)
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()],
mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
remapping=dict(sum=['a', ...]),
auto_remap=False)
results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
results = pipeline(results)
np.testing.assert_equal(results['a'], 3)
assert 'b' not in results
# Test repr # Test repr
_ = str(pipeline) _ = str(pipeline)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment