"examples/vscode:/vscode.git/clone" did not exist on "844899440a02815c02131febc9fdeb435185ec2a"
Commit 6534efd6 authored by Yining Li's avatar Yining Li Committed by zhouzaida
Browse files

[Fix] Add @cacheable_method in transform wrapper `RandomChoice` (#1807)

* add @cacheable_methd in RandomChoice

* RandomChoice add __iter__() and fix unittest
parent 2f85d781
...@@ -139,15 +139,17 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]): ...@@ -139,15 +139,17 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
key = f'{id(t)}.{name}' key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]: if key2counter[key] != key2counter[key_transform]:
raise RuntimeError( raise RuntimeError(
'The cacheable method should be called once and only' 'The cacheable method should be called once and only '
f'once during processing one data sample. {t} got' f'once during processing one data sample. {t} got '
f'unmatched number of {key2counter[key]} ({name}) vs' f'unmatched number of {key2counter[key]} ({name}) vs '
f'{key2counter[key_transform]} (data samples)') f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key]) setattr(t, name, key2method[key])
setattr(t, 'transform', key2method[key_transform]) setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable], def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]): func: Callable[[BaseTransform], None]):
# Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice
if isinstance(t, BaseTransform): if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'): if hasattr(t, '_cacheable_methods'):
func(t) func(t)
......
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
import mmcv import mmcv
from .base import BaseTransform from .base import BaseTransform
from .builder import TRANSFORMS from .builder import TRANSFORMS
from .utils import cache_random_params from .utils import cache_random_params, cacheable_method
# Indicator for required but missing keys in results # Indicator for required but missing keys in results
NotInResults = object() NotInResults = object()
...@@ -452,6 +452,14 @@ class RandomChoice(BaseTransform): ...@@ -452,6 +452,14 @@ class RandomChoice(BaseTransform):
self.pipeline_probs = pipeline_probs self.pipeline_probs = pipeline_probs
self.pipelines = [Compose(transforms) for transforms in pipelines] self.pipelines = [Compose(transforms) for transforms in pipelines]
def __iter__(self):
return iter(self.pipelines)
@cacheable_method
def random_pipeline_index(self):
indices = np.arange(len(self.pipelines))
return np.random.choice(indices, p=self.pipeline_probs)
def transform(self, results): def transform(self, results):
pipeline = np.random.choice(self.pipelines, p=self.pipeline_probs) idx = self.random_pipeline_index()
return pipeline(results) return self.pipelines[idx](results)
...@@ -13,18 +13,13 @@ from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice, ...@@ -13,18 +13,13 @@ from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice,
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class AddToValue(BaseTransform): class AddToValue(BaseTransform):
"""Dummy transform to test transform wrappers.""" """Dummy transform to add a given addend to results['value']"""
def __init__(self, constant_addend=0, use_random_addend=False) -> None: def __init__(self, addend=0) -> None:
super().__init__() super().__init__()
self.constant_addend = constant_addend self.addend = addend
self.use_random_addend = use_random_addend
@cacheable_method def add(self, results, addend):
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
augend = results['value'] augend = results['value']
if isinstance(augend, list): if isinstance(augend, list):
...@@ -39,14 +34,27 @@ class AddToValue(BaseTransform): ...@@ -39,14 +34,27 @@ class AddToValue(BaseTransform):
return {k: _add_to_value(v, addend) for k, v in augend.items()} return {k: _add_to_value(v, addend) for k, v in augend.items()}
return augend + addend return augend + addend
if self.use_random_addend:
addend = self.get_random_addend()
else:
addend = self.constant_addend
results['value'] = _add_to_value(results['value'], addend) results['value'] = _add_to_value(results['value'], addend)
return results return results
def transform(self, results):
return self.add(results, self.addend)
@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
"""Dummy transform to add a random addend to results['value']"""
def __init__(self) -> None:
super().__init__(addend=None)
@cacheable_method
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
return self.add(results, addend=self.get_random_addend())
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class SumTwoValues(BaseTransform): class SumTwoValues(BaseTransform):
...@@ -89,11 +97,11 @@ def test_compose(): ...@@ -89,11 +97,11 @@ def test_compose():
def test_cache_random_parameters(): def test_cache_random_parameters():
transform = AddToValue(use_random_addend=True) transform = RandomAddToValue()
# Case 1: cache random parameters # Case 1: cache random parameters
assert hasattr(AddToValue, '_cacheable_methods') assert hasattr(RandomAddToValue, '_cacheable_methods')
assert 'get_random_addend' in AddToValue._cacheable_methods assert 'get_random_addend' in RandomAddToValue._cacheable_methods
with cache_random_params(transform): with cache_random_params(transform):
results_1 = transform(dict(value=0)) results_1 = transform(dict(value=0))
...@@ -112,7 +120,7 @@ def test_cache_random_parameters(): ...@@ -112,7 +120,7 @@ def test_cache_random_parameters():
_ = transform.get_random_addend() _ = transform.get_random_addend()
# Case 4: apply on nested transforms # Case 4: apply on nested transforms
transform = Compose([AddToValue(use_random_addend=True)]) transform = Compose([RandomAddToValue()])
with cache_random_params(transform): with cache_random_params(transform):
results_1 = transform(dict(value=0)) results_1 = transform(dict(value=0))
results_2 = transform(dict(value=0)) results_2 = transform(dict(value=0))
...@@ -123,7 +131,7 @@ def test_remap(): ...@@ -123,7 +131,7 @@ def test_remap():
# Case 1: simple remap # Case 1: simple remap
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'), input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out')) output_mapping=dict(value='v_out'))
...@@ -136,7 +144,7 @@ def test_remap(): ...@@ -136,7 +144,7 @@ def test_remap():
# Case 2: collecting list # Case 2: collecting list
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']), input_mapping=dict(value=['v_in_1', 'v_in_2']),
output_mapping=dict(value=['v_out_1', 'v_out_2'])) output_mapping=dict(value=['v_out_1', 'v_out_2']))
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
...@@ -152,7 +160,7 @@ def test_remap(): ...@@ -152,7 +160,7 @@ def test_remap():
# Case 3: collecting dict # Case 3: collecting dict
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')), input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2'))) output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
...@@ -168,7 +176,7 @@ def test_remap(): ...@@ -168,7 +176,7 @@ def test_remap():
# Case 4: collecting list with inplace mode # Case 4: collecting list with inplace mode
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']), input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True) inplace=True)
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
...@@ -182,7 +190,7 @@ def test_remap(): ...@@ -182,7 +190,7 @@ def test_remap():
# Case 5: collecting dict with inplace mode # Case 5: collecting dict with inplace mode
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')), input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
inplace=True) inplace=True)
results = dict(value=0, v_in_1=1, v_in_2=2) results = dict(value=0, v_in_1=1, v_in_2=2)
...@@ -196,7 +204,7 @@ def test_remap(): ...@@ -196,7 +204,7 @@ def test_remap():
# Case 6: nested collection with inplace mode # Case 6: nested collection with inplace mode
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]), input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
inplace=True) inplace=True)
results = dict(value=0, v1=1, v21=2, v22=3, v3=4) results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
...@@ -213,7 +221,7 @@ def test_remap(): ...@@ -213,7 +221,7 @@ def test_remap():
# Case 7: `strict` must be True if `inplace` is set True # Case 7: `strict` must be True if `inplace` is set True
with pytest.raises(ValueError): with pytest.raises(ValueError):
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=2)], transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']), input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True, inplace=True,
strict=False) strict=False)
...@@ -221,7 +229,7 @@ def test_remap(): ...@@ -221,7 +229,7 @@ def test_remap():
# Case 8: output_map must be None if `inplace` is set True # Case 8: output_map must be None if `inplace` is set True
with pytest.raises(ValueError): with pytest.raises(ValueError):
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'), input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'), output_mapping=dict(value='v_out'),
inplace=True) inplace=True)
...@@ -240,7 +248,7 @@ def test_remap(): ...@@ -240,7 +248,7 @@ def test_remap():
# Test basic functions # Test basic functions
pipeline = Remap( pipeline = Remap(
transforms=[AddToValue(constant_addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'), input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out')) output_mapping=dict(value='v_out'))
...@@ -256,7 +264,7 @@ def test_apply_to_multiple(): ...@@ -256,7 +264,7 @@ def test_apply_to_multiple():
# Case 1: apply to list in results # Case 1: apply to list in results
pipeline = ApplyToMultiple( pipeline = ApplyToMultiple(
transforms=[AddToValue(constant_addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value='values'), input_mapping=dict(value='values'),
inplace=True) inplace=True)
results = dict(values=[1, 2]) results = dict(values=[1, 2])
...@@ -267,7 +275,7 @@ def test_apply_to_multiple(): ...@@ -267,7 +275,7 @@ def test_apply_to_multiple():
# Case 2: apply to multiple keys # Case 2: apply to multiple keys
pipeline = ApplyToMultiple( pipeline = ApplyToMultiple(
transforms=[AddToValue(constant_addend=1)], transforms=[AddToValue(addend=1)],
input_mapping=dict(value=['v_1', 'v_2']), input_mapping=dict(value=['v_1', 'v_2']),
inplace=True) inplace=True)
results = dict(v_1=1, v_2=2) results = dict(v_1=1, v_2=2)
...@@ -300,11 +308,10 @@ def test_apply_to_multiple(): ...@@ -300,11 +308,10 @@ def test_apply_to_multiple():
# Case 5: share random parameter # Case 5: share random parameter
pipeline = ApplyToMultiple( pipeline = ApplyToMultiple(
transforms=[AddToValue(use_random_addend=True)], transforms=[RandomAddToValue()],
input_mapping=dict(value='values'), input_mapping=dict(value='values'),
inplace=True, inplace=True,
share_random_params=True, share_random_params=True)
)
results = dict(values=[0, 0]) results = dict(values=[0, 0])
results = pipeline(results) results = pipeline(results)
...@@ -319,19 +326,35 @@ def test_randomchoice(): ...@@ -319,19 +326,35 @@ def test_randomchoice():
# Case 1: given probability # Case 1: given probability
pipeline = RandomChoice( pipeline = RandomChoice(
pipelines=[[AddToValue(constant_addend=1.0)], pipelines=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
[AddToValue(constant_addend=2.0)]],
pipeline_probs=[1.0, 0.0]) pipeline_probs=[1.0, 0.0])
results = pipeline(dict(value=1)) results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0) np.testing.assert_equal(results['value'], 2.0)
# Case 1: default probability # Case 2: default probability
pipeline = RandomChoice(pipelines=[[AddToValue( pipeline = RandomChoice(pipelines=[[AddToValue(
constant_addend=1.0)], [AddToValue(constant_addend=2.0)]]) addend=1.0)], [AddToValue(addend=2.0)]])
_ = pipeline(dict(value=1)) _ = pipeline(dict(value=1))
# Case 3: nested RandomChoice in ApplyToMultiple
pipeline = ApplyToMultiple(
transforms=[
RandomChoice(
pipelines=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ),
],
input_mapping=dict(value='values'),
inplace=True,
share_random_params=True)
results = dict(values=[0 for _ in range(10)])
results = pipeline(results)
# check share_random_params=True works so that all values are same
values = results['values']
assert all(map(lambda x: x == values[0], values))
def test_utils(): def test_utils():
# Test cacheable_method: normal case # Test cacheable_method: normal case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment