# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
                                   cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
                                      RandomChoice, TransformBroadcaster)


@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform to add a given addend to results['value'].

    Args:
        addend: The value added to ``results['value']`` by
            :meth:`transform`. Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Add ``addend`` to ``results['value']`` and return ``results``.

        If the value is a list or a dict, the addend is applied recursively
        to every leaf element, and a ``UserWarning`` is emitted so that
        tests can detect which container type reached the transform.

        Args:
            results (dict): Must contain the key ``'value'``. It is updated
                in place.
            addend: The value to add to each leaf of ``results['value']``.

        Returns:
            dict: The same ``results`` object with ``'value'`` replaced.
        """
        augend = results['value']

        # The warnings act as probes: the wrapper tests match on these
        # messages with ``pytest.warns``.
        if isinstance(augend, list):
            warnings.warn('value is a list', UserWarning)
        if isinstance(augend, dict):
            warnings.warn('value is a dict', UserWarning)

        def _add_to_value(augend, addend):
            # Recurse through nested lists / dicts down to the leaves.
            if isinstance(augend, list):
                return [_add_to_value(v, addend) for v in augend]
            if isinstance(augend, dict):
                return {k: _add_to_value(v, addend) for k, v in augend.items()}
            return augend + addend

        results['value'] = _add_to_value(results['value'], addend)
        return results

    def transform(self, results):
        """Add ``self.addend`` to ``results['value']``."""
        return self.add(results, self.addend)

    def __repr__(self) -> str:
        # NOTE: there is deliberately no separator between the class name
        # and the attribute; the repr assertions in this file depend on the
        # exact string (e.g. ``'AddToValueaddend = 1'``).
        repr_str = self.__class__.__name__
        repr_str += f'addend = {self.addend}'
        return repr_str


@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform to add a random addend to results['value'].

    Args:
        repeat (int): How many times a random addend is drawn and added in
            one call of :meth:`transform`. Defaults to 1.
    """

    def __init__(self, repeat=1) -> None:
        # The fixed addend of the parent class is unused here.
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        """Draw the random addend (a sample of ``np.random.rand()``)."""
        return np.random.rand()

    def transform(self, results):
        """Add a freshly requested random addend ``self.repeat`` times."""
        for _ in range(self.repeat):
            results = self.add(results, addend=self.get_random_addend())
        return results

    def __repr__(self) -> str:
        # Same separator-less format as the parent class.
        repr_str = self.__class__.__name__
        repr_str += f'repeat = {self.repeat}'
        return repr_str


@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform to test transform wrappers.

    Writes ``results['sum']`` from the optional keys ``'num_1'`` and
    ``'num_2'``.
    """

    def transform(self, results):
        """Store the sum of the available operands in ``results['sum']``.

        If both keys exist the result is their sum; if only one exists its
        value is used as is; if neither exists the result is ``np.nan``.

        Args:
            results (dict): Input dict, updated in place.

        Returns:
            dict: The same ``results`` object with the key ``'sum'`` set.
        """
        if 'num_1' in results and 'num_2' in results:
            results['sum'] = results['num_1'] + results['num_2']
        elif 'num_1' in results:
            results['sum'] = results['num_1']
        elif 'num_2' in results:
            results['sum'] = results['num_2']
        else:
            results['sum'] = np.nan
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str


def test_compose():
    """Check how ``Compose`` is built and how it handles a None output."""

    # Case 1: build from cfg
    from_cfg = Compose([dict(type='AddToValue')])
    _ = str(from_cfg)

    # Case 2: build from transform list
    Compose([AddToValue()])

    # Case 3: invalid build arguments
    nested_cfg = [[dict(type='AddToValue')]]
    with pytest.raises(TypeError):
        Compose(nested_cfg)

    # Case 4: contain transform with None output
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    none_pipeline = Compose([DummyTransform()])
    assert none_pipeline({}) is None


def test_cache_random_parameters():
    """Check ``cache_random_params`` on a transform with random methods."""

    transform = RandomAddToValue()

    # Case 1: cache random parameters
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 2: do not cache random parameters
    # Two independent random draws are expected to differ, so the equality
    # check itself must fail.
    results_1 = transform(dict(value=0))
    results_2 = transform(dict(value=0))
    with pytest.raises(AssertionError):
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 3: allow to invoke random method 0 times
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        _ = transform(dict(value=0))

    # Case 4: NOT allow to invoke random method >1 times
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            _ = transform(dict(value=0))

    # Case 5: apply on nested transforms
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])


def test_key_mapper():
    """Check ``KeyMapper`` mapping, remapping, auto_remap and its repr."""
    # Case 0: only remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})

    results = dict(value=0)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_out'], 1)

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        pipeline = KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    results = pipeline(dict(a=1))
    np.testing.assert_equal(results['sum'], 1)

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Case 10: manually set keys ignored
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2=...),  # num_2 (b) will be ignored
        auto_remap=False,
        # allow_nonexist_keys will not affect manually ignored keys
        allow_nonexist_keys=False)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    assert repr(pipeline) == (
        'KeyMapper(transforms = Compose(\n    ' + 'AddToValueaddend = 1' +
        '\n), mapping = {\'value\': \'v_in\'}, ' +
        'remapping = {\'value\': \'v_out\'}, auto_remap = False, ' +
        'allow_nonexist_keys = False)')


def test_transform_broadcaster():
    """Check ``TransformBroadcaster`` broadcasting, sharing and its repr."""

    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])

    results = pipeline(results)

    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)

    results = pipeline(results)

    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: apply to all keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()], mapping=None, remapping=None)
    results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])

    results = pipeline(results)

    np.testing.assert_equal(results['sum'], [5, 7, 9])

    # Case 5: inconsistent sequence length
    with pytest.raises(ValueError):
        pipeline = TransformBroadcaster(
            transforms=[SumTwoValues()],
            mapping=dict(num_1='list_1', num_2='list_2'),
            auto_remap=False)

        results = dict(list_1=[1, 2], list_2=[1, 2, 3])
        _ = pipeline(results)

    # Case 6: share random parameter
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0, 0])
    results = pipeline(results)

    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Case 7: partial broadcasting
    # ``...`` in `mapping` leaves the corresponding input out of that branch.
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', ...]),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 3)

    # ``...`` in `remapping` drops the corresponding output.
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', ...]),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    assert 'b' not in results

    # Test repr
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' + '    SumTwoValues' +
        '\n), mapping = {\'num_1\': [\'a_1\', \'b_1\'], ' +
        '\'num_2\': [\'a_2\', \'b_2\']}, ' +
        'remapping = {\'sum\': [\'a\', Ellipsis]}, auto_remap = False, ' +
        'allow_nonexist_keys = False, share_random_params = False)')


def test_random_choice():
    """Check ``RandomChoice`` alone and nested in a broadcaster."""

    # Case 1: given probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability
    pipeline = RandomChoice(transforms=[[AddToValue(
        addend=1.0)], [AddToValue(addend=2.0)]])

    _ = pipeline(dict(value=1))

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]], ),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(v == values[0] for v in values)

    # repr
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' +
        '    RandomChoice(transforms = [Compose(\n' +
        '    AddToValueaddend = 1.0' + '\n), Compose(\n' +
        '    AddToValueaddend = 2.0' + '\n)]prob = None)' +
        '\n), mapping = {\'value\': \'values\'}, ' +
        'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
        'allow_nonexist_keys = False, share_random_params = True)')


def test_random_apply():
    """Check ``RandomApply`` alone and nested in a broadcaster."""

    # Case 1: simple use
    # prob=1.0 always applies the wrapped transform, prob=0.0 never does.
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1.0)

    # Case 2: nested RandomApply in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(v == values[0] for v in values)

    # __iter__
    for _ in pipeline:
        pass

    # repr
    assert repr(pipeline) == (
        'TransformBroadcaster(transforms = Compose(\n' +
        '    RandomApply(transforms = Compose(\n' +
        '    AddToValueaddend = 1' + '\n), prob = 0.5)' +
        '\n), mapping = {\'value\': \'values\'}, ' +
        'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
        'allow_nonexist_keys = False, share_random_params = True)')


def test_utils():
    """Check the ``cache_randomness`` / ``avoid_cache_randomness`` helpers.

    ``DummyTransform`` is intentionally redefined in each case; several of
    the class definitions are themselves expected to raise.
    """
    # Test cache_randomness: normal case
    class DummyTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            _ = self.func()
            return results

    transform = DummyTransform()
    _ = transform({})
    with cache_random_params(transform):
        _ = transform({})

    # Test cache_randomness: invalid function type
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

            def transform(self, results):
                return results

    # Test cache_randomness: invalid function argument list
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(cls):
                return np.random.rand()

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: invalid mixture with cache_randomness
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(self):
                pass

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: raise error in cache_random_params
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            def transform(self, results):
                return results

        transform = DummyTransform()
        with cache_random_params(transform):
            pass

    # Test avoid_cache_randomness: non-inheritable
    # A subclass of a decorated class must work with cache_random_params.
    @avoid_cache_randomness
    class DummyBaseTransform(BaseTransform):

        def transform(self, results):
            return results

    class DummyTransform(DummyBaseTransform):
        pass

    transform = DummyTransform()
    with cache_random_params(transform):
        pass