test_transforms_wrapper.py 16.3 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

7
8
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
9
10
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
                                   cache_randomness)
Yining Li's avatar
Yining Li committed
11
12
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
                                      RandomChoice, TransformBroadcaster)
13
14
15
16


@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform to add a given addend to ``results['value']``.

    ``results['value']`` may be a scalar or a (possibly nested) list/dict of
    scalars; the addend is added to every leaf while the container structure
    is preserved.

    Args:
        addend: The value added to every leaf of ``results['value']``.
            Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Add ``addend`` to every leaf of ``results['value']``.

        A ``UserWarning`` is emitted when the top-level value is a list or a
        dict, so tests can detect that a wrapper has collected several keys
        into one container before calling this transform.

        Args:
            results (dict): Result dict containing the key ``'value'``.
            addend: The value to add to each leaf.

        Returns:
            dict: The same ``results`` dict, with ``'value'`` replaced by the
            incremented value.
        """
        augend = results['value']

        # These messages are matched by ``pytest.warns(..., match=...)`` in
        # the wrapper tests; keep them stable. The two cases are exclusive.
        if isinstance(augend, list):
            warnings.warn('value is a list', UserWarning)
        elif isinstance(augend, dict):
            warnings.warn('value is a dict', UserWarning)

        def _add_to_value(augend, addend):
            # Recurse into containers so nested structures keep their shape.
            if isinstance(augend, list):
                return [_add_to_value(v, addend) for v in augend]
            if isinstance(augend, dict):
                return {k: _add_to_value(v, addend) for k, v in augend.items()}
            return augend + addend

        results['value'] = _add_to_value(augend, addend)
        return results

    def transform(self, results):
        """Add the configured ``self.addend`` to ``results['value']``."""
        return self.add(results, self.addend)


@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform to add a random addend to ``results['value']``.

    Args:
        repeat (int): How many times a freshly drawn random addend is added
            per call. ``0`` leaves the value unchanged. Values greater than 1
            invoke the random method more than once per call, which
            ``cache_random_params`` rejects with a ``RuntimeError``.
            Defaults to 1.
    """

    def __init__(self, repeat=1) -> None:
        # The fixed addend is unused; ``transform`` draws one at random.
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        """Draw a random addend from ``np.random.rand()``.

        Decorated with ``cache_randomness`` so that, inside a
        ``cache_random_params`` context, repeated calls of the transform see
        the same addend.
        """
        return np.random.rand()

    def transform(self, results):
        """Add ``self.repeat`` independently drawn addends to the value."""
        for _ in range(self.repeat):
            results = self.add(results, addend=self.get_random_addend())
        return results


@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform to test transform wrappers.

    Writes ``results['sum']`` as ``results['num_1'] + results['num_2']``.
    If only one of the two keys is present, its value is used as the sum.
    If neither is present, the sum is ``np.nan``.
    """

    def transform(self, results):
        has_num_1 = 'num_1' in results
        has_num_2 = 'num_2' in results

        if has_num_1 and has_num_2:
            results['sum'] = results['num_1'] + results['num_2']
        elif has_num_1:
            results['sum'] = results['num_1']
        elif has_num_2:
            results['sum'] = results['num_2']
        else:
            # Neither input is available (e.g. both keys missing/ignored).
            results['sum'] = np.nan
        return results


def test_compose():
    """Check ``Compose`` construction paths and ``None`` propagation."""

    # Case 1: build from cfg
    composed = Compose([dict(type='AddToValue')])
    _ = str(composed)

    # Case 2: build from transform list
    composed = Compose([AddToValue()])

    # Case 3: invalid build arguments (a nested list is not a valid item)
    nested_cfg = [[dict(type='AddToValue')]]
    with pytest.raises(TypeError):
        composed = Compose(nested_cfg)

    # Case 4: contain transform with None output
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    # A transform returning None makes the whole pipeline return None.
    composed = Compose([DummyTransform()])
    assert composed({}) is None


def test_cache_random_parameters():
    """Check ``cache_randomness`` together with ``cache_random_params``."""

    transform = RandomAddToValue()

    # Case 1: cache random parameters
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 2: do not cache random parameters
    # (two independent ``np.random.rand()`` draws are equal only with
    # negligible probability)
    results_1 = transform(dict(value=0))
    results_2 = transform(dict(value=0))
    with pytest.raises(AssertionError):
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 3: allow to invoke random method 0 times
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        results = transform(dict(value=0))
    # With no draws the value must be left untouched.
    np.testing.assert_equal(results['value'], 0)

    # Case 4: NOT allow to invoke random method >1 times
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            _ = transform(dict(value=0))

    # Case 5: apply on nested transforms
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])


def test_key_mapper():
    """Check ``KeyMapper`` input mapping / output remapping in all modes."""

    # Case 0: only remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})

    results = dict(value=0)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_out'], 1)

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        pipeline = KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    results = pipeline(dict(a=1))
    np.testing.assert_equal(results['sum'], 1)

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Case 10: manually set keys ignored
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2=...),  # num_2 (b) will be ignored
        auto_remap=False,
        # allow_nonexist_keys will not affect manually ignored keys
        allow_nonexist_keys=False)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    _ = str(pipeline)


def test_transform_broadcaster():
    """Check ``TransformBroadcaster`` over lists, keys and key groups."""

    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])

    results = pipeline(results)

    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)

    results = pipeline(results)

    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: apply to all keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()], mapping=None, remapping=None)
    results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])

    results = pipeline(results)

    np.testing.assert_equal(results['sum'], [5, 7, 9])

    # Case 5: inconsistent sequence length
    # Build the pipeline outside ``pytest.raises`` so that only the call on
    # mismatched sequences is expected to raise.
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='list_1', num_2='list_2'),
        auto_remap=False)
    results = dict(list_1=[1, 2], list_2=[1, 2, 3])
    with pytest.raises(ValueError):
        _ = pipeline(results)

    # Case 6: share random parameter
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0, 0])
    results = pipeline(results)

    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Case 7: partial broadcasting
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', ...]),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 3)

    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', ...]),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    assert 'b' not in results

    # Test repr
    _ = str(pipeline)


def test_random_choice():
    """Check ``RandomChoice`` with given/default probabilities."""

    # Case 1: given probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]])

    results = pipeline(dict(value=1))
    # Exactly one of the two candidate transforms must have been applied.
    assert results['value'] in (2.0, 3.0)

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]]),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(v == values[0] for v in values)
    assert values[0] in (1.0, 2.0)


def test_random_apply():
    """Check ``RandomApply`` probabilities and shared randomness."""

    # Case 1: simple use
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1.0)

    # Case 2: nested RandomApply in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(v == values[0] for v in values)
    # Either every element was incremented or none was.
    assert values[0] in (0, 1)

    # __iter__
    for _ in pipeline:
        pass


def test_utils():
    """Check the ``cache_randomness`` / ``avoid_cache_randomness`` utils."""

    # Test cache_randomness: normal case
    class DummyTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            _ = self.func()
            return results

    transform = DummyTransform()
    _ = transform({})
    with cache_random_params(transform):
        _ = transform({})

    # Test cache_randomness: invalid function type
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

            def transform(self, results):
                return results

    # Test cache_randomness: invalid function argument list
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(cls):
                return np.random.rand()

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: invalid mixture with cache_randomness
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(self):
                pass

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: raise error in cache_random_params
    # Decorating a class without randomness methods is valid (see the
    # non-inheritable case below), so only entering the context is expected
    # to raise.
    @avoid_cache_randomness
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return results

    transform = DummyTransform()
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            pass

    # Test avoid_cache_randomness: non-inheritable
    @avoid_cache_randomness
    class DummyBaseTransform(BaseTransform):

        def transform(self, results):
            return results

    class DummyTransform(DummyBaseTransform):
        pass

    transform = DummyTransform()
    with cache_random_params(transform):
        pass