# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
                                   cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
                                      RandomChoice, TransformBroadcaster)


@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform that adds a fixed ``addend`` to ``results['value']``.

    Args:
        addend: Amount added to every scalar leaf of ``results['value']``.
            Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Add ``addend`` to ``results['value']`` and return ``results``.

        ``results['value']`` may be a scalar or a (possibly nested) list/dict
        of scalars. Every leaf is incremented and the container structure is
        kept. ``results`` is updated in place and also returned.

        A ``UserWarning`` ('value is a list' / 'value is a dict') is emitted
        when the top-level value is a list or a dict. The wrapper tests use
        it to check which kind of container reached the transform.
        """
        value = results['value']

        if isinstance(value, list):
            warnings.warn('value is a list', UserWarning)
        elif isinstance(value, dict):
            warnings.warn('value is a dict', UserWarning)

        def _shift(node):
            # Recurse into containers; only the leaves receive the addend.
            if isinstance(node, list):
                return [_shift(item) for item in node]
            if isinstance(node, dict):
                return {key: _shift(item) for key, item in node.items()}
            return node + addend

        results['value'] = _shift(value)
        return results

    def transform(self, results):
        return self.add(results, self.addend)


@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform that adds a random addend to ``results['value']``.

    The addend comes from ``get_random_addend``, which is decorated with
    ``cache_randomness``. That decoration lets ``cache_random_params`` reuse
    one drawn value across calls.

    Args:
        repeat (int): How many times a fresh random addend is drawn and added
            per call. Defaults to 1.
    """

    def __init__(self, repeat=1) -> None:
        # No fixed addend: every addition uses a freshly drawn random value.
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        return np.random.rand()

    def transform(self, results):
        # Draw and add a new random addend ``self.repeat`` times.
        for _round in range(self.repeat):
            addend = self.get_random_addend()
            results = self.add(results, addend=addend)
        return results


@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform that stores ``num_1 + num_2`` in ``results['sum']``.

    If either ``num_1`` or ``num_2`` is missing from ``results``,
    ``results['sum']`` is set to ``np.nan`` instead.
    """

    def transform(self, results):
        has_both = all(key in results for key in ('num_1', 'num_2'))
        results['sum'] = (
            results['num_1'] + results['num_2'] if has_both else np.nan)
        return results


def test_compose():
    """Check building and running ``Compose`` pipelines."""

    # Case 1: build from cfg, and check the built transform actually runs
    pipeline = Compose([dict(type='AddToValue', addend=1)])
    _ = str(pipeline)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2)

    # Case 2: build from transform list, and check it actually runs
    pipeline = Compose([AddToValue(addend=1)])
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2)

    # Case 3: invalid build arguments (a nested list is neither a transform
    # nor a config dict)
    with pytest.raises(TypeError):
        Compose([[dict(type='AddToValue')]])

    # Case 4: contain transform with None output; Compose returns None
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    pipeline = Compose([DummyTransform()])
    results = pipeline({})
    assert results is None


def test_cache_random_parameters():
    """Check ``cache_random_params`` on a transform with cached randomness."""

    transform = RandomAddToValue()

    # Case 1: cache random parameters. The decorated method is recorded on
    # the class, and within the context both calls add the same addend.
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(transform):
        first = transform(dict(value=0))
        second = transform(dict(value=0))
        np.testing.assert_equal(first['value'], second['value'])

    # Case 2: do not cache random parameters. Outside the context each call
    # draws its own addend, so the two results are expected to differ.
    first = transform(dict(value=0))
    second = transform(dict(value=0))
    with pytest.raises(AssertionError):
        np.testing.assert_equal(first['value'], second['value'])

    # Case 3: invoking the random method zero times under caching is allowed
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        transform(dict(value=0))

    # Case 4: invoking the random method more than once under caching raises
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError), cache_random_params(transform):
        transform(dict(value=0))

    # Case 5: caching also reaches transforms nested inside a wrapper
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        first = transform(dict(value=0))
        second = transform(dict(value=0))
        np.testing.assert_equal(first['value'], second['value'])


def test_key_mapper():
    """Check input ``mapping`` / output ``remapping`` done by ``KeyMapper``."""

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        pipeline = KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    results = pipeline(dict(a=1))
    assert np.isnan(results['sum'])

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    _ = str(pipeline)


def test_transform_broadcaster():
    """Check that ``TransformBroadcaster`` applies transforms per element."""

    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])

    results = pipeline(results)

    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)

    results = pipeline(results)

    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: inconsistent sequence length. Lengths are only known when the
    # pipeline is called. The wrapper is built outside `pytest.raises`, so a
    # ValueError from construction would not be mistaken for the expected one.
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='list_1', num_2='list_2'),
        auto_remap=False)
    results = dict(list_1=[1, 2], list_2=[1, 2, 3])
    with pytest.raises(ValueError):
        _ = pipeline(results)

    # Case 5: share random parameter
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0, 0])
    results = pipeline(results)

    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Test repr
    _ = str(pipeline)


def test_random_choice():
    """Check ``RandomChoice`` with explicit and default probabilities."""

    # Case 1: given probability (the first branch is always chosen)
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability; the result must come from one of the
    # two branches
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]])

    results = pipeline(dict(value=1))
    assert results['value'] in (2.0, 3.0)

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]]),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(map(lambda x: x == values[0], values))
    # ... and that the shared value comes from one of the two branches
    assert values[0] in (1.0, 2.0)


def test_random_apply():
    """Check ``RandomApply`` alone and nested in ``TransformBroadcaster``."""

    # Case 1: simple use
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1.0)

    # Case 2: nested RandomApply in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0 for _ in range(10)])
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(map(lambda x: x == values[0], values))
    # ... and that the shared value is either untouched (0) or incremented (1)
    assert values[0] in (0, 1)

    # __iter__
    for _ in pipeline:
        pass


def test_utils():
    """Exercise the ``cache_randomness`` / ``avoid_cache_randomness`` helpers."""

    # cache_randomness, valid use: a random instance method works both
    # normally and inside ``cache_random_params``
    class RandomFuncTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            self.func()
            return results

    random_transform = RandomFuncTransform()
    random_transform({})
    with cache_random_params(random_transform):
        random_transform({})

    # cache_randomness applied to a staticmethod object raises TypeError
    # (invalid function type)
    with pytest.raises(TypeError):

        class StaticFuncTransform(BaseTransform):

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

            def transform(self, results):
                return results

    # cache_randomness applied to a function with an invalid argument list
    # (here the first argument is ``cls``) raises TypeError
    with pytest.raises(TypeError):

        class ClsArgTransform(BaseTransform):

            @cache_randomness
            def func(cls):
                return np.random.rand()

            def transform(self, results):
                return results

    # avoid_cache_randomness on a class that also has a cache_randomness
    # method is an invalid mixture and raises RuntimeError
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class MixedTransform(BaseTransform):

            @cache_randomness
            def func(self):
                pass

            def transform(self, results):
                return results

    # cache_random_params raises RuntimeError for a transform class
    # decorated with avoid_cache_randomness
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class AvoidCacheTransform(BaseTransform):

            def transform(self, results):
                return results

        with cache_random_params(AvoidCacheTransform()):
            pass

    # the avoid_cache_randomness mark is not inherited: a plain subclass can
    # still be used with cache_random_params
    @avoid_cache_randomness
    class AvoidCacheBase(BaseTransform):

        def transform(self, results):
            return results

    class PlainSubclass(AvoidCacheBase):
        pass

    with cache_random_params(PlainSubclass()):
        pass