test_transforms_wrapper.py 13.2 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

7
8
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
Yining Li's avatar
Yining Li committed
9
from mmcv.transforms.utils import cache_random_params, cache_randomness
Yining Li's avatar
Yining Li committed
10
11
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
                                      RandomChoice, TransformBroadcaster)
12
13
14
15


@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform that adds a given addend to ``results['value']``.

    Args:
        addend: Quantity added to ``results['value']`` by :meth:`transform`.
            Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Add ``addend`` to ``results['value']`` and return ``results``.

        If the value is a list or a dict, a ``UserWarning`` is emitted and
        the addend is applied recursively to every leaf element.

        Args:
            results (dict): Mapping that contains the key ``'value'``.
            addend: Quantity added to each leaf of ``results['value']``.

        Returns:
            dict: The same ``results`` object with ``'value'`` replaced by
            the summed value.
        """
        augend = results['value']

        # The warnings make the container type of the incoming value
        # observable from the outside (the tests match on these messages).
        if isinstance(augend, list):
            warnings.warn('value is a list', UserWarning)
        if isinstance(augend, dict):
            warnings.warn('value is a dict', UserWarning)

        def _add_to_value(augend, addend):
            # Recurse through nested lists/dicts; add at the leaves only.
            if isinstance(augend, list):
                return [_add_to_value(v, addend) for v in augend]
            if isinstance(augend, dict):
                return {k: _add_to_value(v, addend) for k, v in augend.items()}
            return augend + addend

        results['value'] = _add_to_value(augend, addend)
        return results

    def transform(self, results):
        """Add the configured ``self.addend`` to ``results['value']``."""
        return self.add(results, self.addend)


@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform that adds a random addend to ``results['value']``.

    Args:
        repeat (int): How many times a random addend is drawn and added
            during a single :meth:`transform` call. Defaults to 1.
    """

    def __init__(self, repeat=1) -> None:
        # The fixed addend of the parent class is unused here; every addend
        # comes from ``get_random_addend``.
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        """Return a random float drawn with ``np.random.rand()``."""
        return np.random.rand()

    def transform(self, results):
        """Add ``self.repeat`` freshly requested random addends in turn."""
        for _ in range(self.repeat):
            results = self.add(results, addend=self.get_random_addend())
        return results
60

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform used to exercise the transform wrappers.

    Writes ``results['num_1'] + results['num_2']`` to ``results['sum']``, or
    ``np.nan`` when either of the two input keys is absent.
    """

    def transform(self, results):
        """Store the sum (or NaN if an operand is missing) under ``'sum'``."""
        operand_missing = 'num_1' not in results or 'num_2' not in results
        results['sum'] = (
            np.nan if operand_missing else results['num_1'] + results['num_2'])
        return results


def test_compose():
    """Exercise ``Compose`` construction paths and a ``None`` output."""

    # Case 1: build from cfg
    composed = Compose([dict(type='AddToValue')])
    _ = str(composed)

    # Case 2: build from transform list
    composed = Compose([AddToValue()])

    # Case 3: invalid build arguments
    nested_cfg = [[dict(type='AddToValue')]]
    with pytest.raises(TypeError):
        composed = Compose(nested_cfg)

    # Case 4: contain transform with None output
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    composed = Compose([DummyTransform()])
    output = composed({})
    assert output is None


def test_cache_random_parameters():
    """Check ``cache_random_params`` together with ``cache_randomness``."""

    transform = RandomAddToValue()

    # Case 1: cache random parameters
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 2: do not cache random parameters
    results_1 = transform(dict(value=0))
    results_2 = transform(dict(value=0))
    # Outside the context the two draws are expected to differ, so the
    # equality assertion itself must fail.
    with pytest.raises(AssertionError):
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 3: allow to invoke random method 0 times
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        _ = transform(dict(value=0))

    # Case 4: NOT allow to invoke random method >1 times
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            _ = transform(dict(value=0))

    # Case 5: apply on nested transforms
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])


Yining Li's avatar
Yining Li committed
139
def test_key_mapper():
    """Check ``KeyMapper`` key mapping/remapping in its supported modes."""

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    # With auto_remap the outputs are written back to the input keys.
    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        pipeline = KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    # 'b' is absent, so SumTwoValues falls back to NaN.
    results = pipeline(dict(a=1))
    assert np.isnan(results['sum'])

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    _ = str(pipeline)


def test_apply_to_multiple():
    """Check ``TransformBroadcaster`` over lists, key groups and randomness."""

    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])

    results = pipeline(results)

    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)

    results = pipeline(results)

    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: inconsistent sequence length
    with pytest.raises(ValueError):
        pipeline = TransformBroadcaster(
            transforms=[SumTwoValues()],
            mapping=dict(num_1='list_1', num_2='list_2'),
            auto_remap=False)

        results = dict(list_1=[1, 2], list_2=[1, 2, 3])
        _ = pipeline(results)

    # Case 5: share random parameter
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0, 0])
    results = pipeline(results)

    # Both entries started at 0, so equal outputs mean the same random
    # addend was applied to each of them.
    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Test repr
    _ = str(pipeline)


Yining Li's avatar
Yining Li committed
341
def test_random_choice():
    """Check ``RandomChoice`` alone and nested in ``TransformBroadcaster``."""

    # Case 1: given probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    # prob=[1.0, 0.0] makes the first branch (addend=1.0) the only option.
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability
    pipeline = RandomChoice(transforms=[[AddToValue(
        addend=1.0)], [AddToValue(addend=2.0)]])

    _ = pipeline(dict(value=1))

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]], ),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(value == values[0] for value in values)


def test_random_apply():
    """Check ``RandomApply`` alone and nested in ``TransformBroadcaster``."""

    # Case 1: simple use
    # prob=1.0: the wrapped transform is always applied.
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # prob=0.0: the wrapped transform is never applied.
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1.0)

    # Case 2: nested RandomApply in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(value == values[0] for value in values)

    # __iter__
    for _ in pipeline:
        pass

403
404

def test_utils():
    """Check ``cache_randomness`` on valid and invalid method definitions."""

    # Test cache_randomness: normal case
    class DummyTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            _ = self.func()
            return results

    transform = DummyTransform()
    _ = transform({})
    with cache_random_params(transform):
        _ = transform({})

    # Test cache_randomness: invalid function type
    # The TypeError is expected while the class body is being evaluated,
    # i.e. when the decorator is applied to the staticmethod object.
    with pytest.raises(TypeError):

        class DummyTransform():

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

    # Test cache_randomness: invalid function argument list
    with pytest.raises(TypeError):

        class DummyTransform():

            @cache_randomness
            def func(cls):
                return np.random.rand()