Commit 6534efd6 authored by Yining Li's avatar Yining Li Committed by zhouzaida
Browse files

[Fix] Add @cacheable_method in transform wrapper `RandomChoice` (#1807)

* add @cacheable_methd in RandomChoice

* RandomChoice add __iter__() and fix unittest
parent 2f85d781
......@@ -139,15 +139,17 @@ def cache_random_params(transforms: Union[BaseTransform, Iterable]):
key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]:
raise RuntimeError(
'The cacheable method should be called once and only'
f'once during processing one data sample. {t} got'
f'unmatched number of {key2counter[key]} ({name}) vs'
'The cacheable method should be called once and only '
f'once during processing one data sample. {t} got '
f'unmatched number of {key2counter[key]} ({name}) vs '
f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key])
setattr(t, 'transform', key2method[key_transform])
def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]):
# Note that BaseTransform and Iterable are not mutually exclusive,
# e.g. Compose, RandomChoice
if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'):
func(t)
......
......@@ -8,7 +8,7 @@ import numpy as np
import mmcv
from .base import BaseTransform
from .builder import TRANSFORMS
from .utils import cache_random_params
from .utils import cache_random_params, cacheable_method
# Indicator for required but missing keys in results
NotInResults = object()
......@@ -452,6 +452,14 @@ class RandomChoice(BaseTransform):
self.pipeline_probs = pipeline_probs
self.pipelines = [Compose(transforms) for transforms in pipelines]
def __iter__(self):
return iter(self.pipelines)
@cacheable_method
def random_pipeline_index(self):
indices = np.arange(len(self.pipelines))
return np.random.choice(indices, p=self.pipeline_probs)
def transform(self, results):
pipeline = np.random.choice(self.pipelines, p=self.pipeline_probs)
return pipeline(results)
idx = self.random_pipeline_index()
return self.pipelines[idx](results)
......@@ -13,18 +13,13 @@ from mmcv.transforms.wrappers import (ApplyToMultiple, Compose, RandomChoice,
@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
"""Dummy transform to test transform wrappers."""
"""Dummy transform to add a given addend to results['value']"""
def __init__(self, constant_addend=0, use_random_addend=False) -> None:
def __init__(self, addend=0) -> None:
super().__init__()
self.constant_addend = constant_addend
self.use_random_addend = use_random_addend
self.addend = addend
@cacheable_method
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
def add(self, results, addend):
augend = results['value']
if isinstance(augend, list):
......@@ -39,14 +34,27 @@ class AddToValue(BaseTransform):
return {k: _add_to_value(v, addend) for k, v in augend.items()}
return augend + addend
if self.use_random_addend:
addend = self.get_random_addend()
else:
addend = self.constant_addend
results['value'] = _add_to_value(results['value'], addend)
return results
def transform(self, results):
return self.add(results, self.addend)
@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
"""Dummy transform to add a random addend to results['value']"""
def __init__(self) -> None:
super().__init__(addend=None)
@cacheable_method
def get_random_addend(self):
return np.random.rand()
def transform(self, results):
return self.add(results, addend=self.get_random_addend())
@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
......@@ -89,11 +97,11 @@ def test_compose():
def test_cache_random_parameters():
transform = AddToValue(use_random_addend=True)
transform = RandomAddToValue()
# Case 1: cache random parameters
assert hasattr(AddToValue, '_cacheable_methods')
assert 'get_random_addend' in AddToValue._cacheable_methods
assert hasattr(RandomAddToValue, '_cacheable_methods')
assert 'get_random_addend' in RandomAddToValue._cacheable_methods
with cache_random_params(transform):
results_1 = transform(dict(value=0))
......@@ -112,7 +120,7 @@ def test_cache_random_parameters():
_ = transform.get_random_addend()
# Case 4: apply on nested transforms
transform = Compose([AddToValue(use_random_addend=True)])
transform = Compose([RandomAddToValue()])
with cache_random_params(transform):
results_1 = transform(dict(value=0))
results_2 = transform(dict(value=0))
......@@ -123,7 +131,7 @@ def test_remap():
# Case 1: simple remap
pipeline = Remap(
transforms=[AddToValue(constant_addend=1)],
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
......@@ -136,7 +144,7 @@ def test_remap():
# Case 2: collecting list
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
output_mapping=dict(value=['v_out_1', 'v_out_2']))
results = dict(value=0, v_in_1=1, v_in_2=2)
......@@ -152,7 +160,7 @@ def test_remap():
# Case 3: collecting dict
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
output_mapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
results = dict(value=0, v_in_1=1, v_in_2=2)
......@@ -168,7 +176,7 @@ def test_remap():
# Case 4: collecting list with inplace mode
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
......@@ -182,7 +190,7 @@ def test_remap():
# Case 5: collecting dict with inplace mode
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
inplace=True)
results = dict(value=0, v_in_1=1, v_in_2=2)
......@@ -196,7 +204,7 @@ def test_remap():
# Case 6: nested collection with inplace mode
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
inplace=True)
results = dict(value=0, v1=1, v21=2, v22=3, v3=4)
......@@ -213,7 +221,7 @@ def test_remap():
# Case 7: `strict` must be True if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(constant_addend=2)],
transforms=[AddToValue(addend=2)],
input_mapping=dict(value=['v_in_1', 'v_in_2']),
inplace=True,
strict=False)
......@@ -221,7 +229,7 @@ def test_remap():
# Case 8: output_map must be None if `inplace` is set True
with pytest.raises(ValueError):
pipeline = Remap(
transforms=[AddToValue(constant_addend=1)],
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'),
inplace=True)
......@@ -240,7 +248,7 @@ def test_remap():
# Test basic functions
pipeline = Remap(
transforms=[AddToValue(constant_addend=1)],
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='v_in'),
output_mapping=dict(value='v_out'))
......@@ -256,7 +264,7 @@ def test_apply_to_multiple():
# Case 1: apply to list in results
pipeline = ApplyToMultiple(
transforms=[AddToValue(constant_addend=1)],
transforms=[AddToValue(addend=1)],
input_mapping=dict(value='values'),
inplace=True)
results = dict(values=[1, 2])
......@@ -267,7 +275,7 @@ def test_apply_to_multiple():
# Case 2: apply to multiple keys
pipeline = ApplyToMultiple(
transforms=[AddToValue(constant_addend=1)],
transforms=[AddToValue(addend=1)],
input_mapping=dict(value=['v_1', 'v_2']),
inplace=True)
results = dict(v_1=1, v_2=2)
......@@ -300,11 +308,10 @@ def test_apply_to_multiple():
# Case 5: share random parameter
pipeline = ApplyToMultiple(
transforms=[AddToValue(use_random_addend=True)],
transforms=[RandomAddToValue()],
input_mapping=dict(value='values'),
inplace=True,
share_random_params=True,
)
share_random_params=True)
results = dict(values=[0, 0])
results = pipeline(results)
......@@ -319,19 +326,35 @@ def test_randomchoice():
# Case 1: given probability
pipeline = RandomChoice(
pipelines=[[AddToValue(constant_addend=1.0)],
[AddToValue(constant_addend=2.0)]],
pipelines=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
pipeline_probs=[1.0, 0.0])
results = pipeline(dict(value=1))
np.testing.assert_equal(results['value'], 2.0)
# Case 1: default probability
# Case 2: default probability
pipeline = RandomChoice(pipelines=[[AddToValue(
constant_addend=1.0)], [AddToValue(constant_addend=2.0)]])
addend=1.0)], [AddToValue(addend=2.0)]])
_ = pipeline(dict(value=1))
# Case 3: nested RandomChoice in ApplyToMultiple
pipeline = ApplyToMultiple(
transforms=[
RandomChoice(
pipelines=[[AddToValue(addend=1.0)],
[AddToValue(addend=2.0)]], ),
],
input_mapping=dict(value='values'),
inplace=True,
share_random_params=True)
results = dict(values=[0 for _ in range(10)])
results = pipeline(results)
# check share_random_params=True works so that all values are same
values = results['values']
assert all(map(lambda x: x == values[0], values))
def test_utils():
# Test cacheable_method: normal case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment