# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
from mmcv.transforms.utils import (avoid_cache_randomness, cache_random_params,
                                   cache_randomness)
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomApply,
                                      RandomChoice, TransformBroadcaster)


@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform to add a given addend to results['value'].

    ``results['value']`` may be a scalar or a (possibly nested) list/dict of
    scalars; the addend is added to every leaf.

    Args:
        addend: Value added to ``results['value']``. Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Add ``addend`` to every leaf of ``results['value']``.

        A ``UserWarning`` is emitted when the top-level value is a list or a
        dict, so tests can check which container a wrapper handed over.

        Args:
            results (dict): Result dict holding a ``'value'`` key.
            addend: Value added to each leaf.

        Returns:
            dict: ``results`` with ``'value'`` replaced by the summed value.
        """
        augend = results['value']

        # The two container types are mutually exclusive; the warnings are
        # what the KeyMapper tests match on.
        if isinstance(augend, list):
            warnings.warn('value is a list', UserWarning)
        elif isinstance(augend, dict):
            warnings.warn('value is a dict', UserWarning)

        def _add_to_value(augend, addend):
            # Recurse through lists/dicts, adding only at the leaves.
            if isinstance(augend, list):
                return [_add_to_value(v, addend) for v in augend]
            if isinstance(augend, dict):
                return {k: _add_to_value(v, addend) for k, v in augend.items()}
            return augend + addend

        results['value'] = _add_to_value(augend, addend)
        return results

    def transform(self, results):
        """Add the configured ``self.addend`` to ``results['value']``."""
        return self.add(results, self.addend)


@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform to add a random addend to results['value'].

    Args:
        repeat (int): How many times a random addend is drawn and added per
            call. ``0`` leaves the value untouched; values ``> 1`` call the
            ``@cache_randomness`` method more than once per call, which the
            tests expect ``cache_random_params`` to reject with a
            ``RuntimeError``. Defaults to 1.
    """

    def __init__(self, repeat=1) -> None:
        # The fixed ``addend`` of the parent is unused: every call draws
        # its own addend from ``get_random_addend``.
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        """Return ``np.random.rand()``; frozen inside ``cache_random_params``."""
        return np.random.rand()

    def transform(self, results):
        """Add ``repeat`` random addends to ``results['value']``."""
        for _ in range(self.repeat):
            results = self.add(results, addend=self.get_random_addend())
        return results

@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform to test transform wrappers.

    Stores ``results['num_1'] + results['num_2']`` under ``results['sum']``,
    or ``np.nan`` when either operand key is missing.
    """

    def transform(self, results):
        has_both_operands = 'num_1' in results and 'num_2' in results
        results['sum'] = (
            results['num_1'] + results['num_2']
            if has_both_operands else np.nan)
        return results


def test_compose():
    """Compose builds from configs or instances and runs its transforms."""

    # Case 1: build from cfg
    pipeline = Compose([dict(type='AddToValue')])
    _ = str(pipeline)
    # the default addend is 0, so the value passes through unchanged
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1)

    # Case 2: build from transform list
    pipeline = Compose([AddToValue(addend=1)])
    # a non-zero addend proves the transform was actually applied
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2)

    # Case 3: invalid build arguments
    with pytest.raises(TypeError):
        Compose([[dict(type='AddToValue')]])

    # Case 4: contain transform with None output
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    pipeline = Compose([DummyTransform()])
    results = pipeline({})
    assert results is None


def test_cache_random_parameters():
    """cache_random_params freezes @cache_randomness methods in its scope."""

    transform = RandomAddToValue()

    # Case 1: cache random parameters
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 2: do not cache random parameters
    # Outside the context every call draws a fresh np.random.rand(), so two
    # identical results are practically impossible.
    results_1 = transform(dict(value=0))
    results_2 = transform(dict(value=0))
    with pytest.raises(AssertionError):
        np.testing.assert_equal(results_1['value'], results_2['value'])

    # Case 3: allow to invoke random method 0 times
    transform = RandomAddToValue(repeat=0)
    with cache_random_params(transform):
        _ = transform(dict(value=0))

    # Case 4: NOT allow to invoke random method >1 times
    transform = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            _ = transform(dict(value=0))

    # Case 5: apply on nested transforms
    transform = Compose([RandomAddToValue()])
    with cache_random_params(transform):
        results_1 = transform(dict(value=0))
        results_2 = transform(dict(value=0))
        np.testing.assert_equal(results_1['value'], results_2['value'])


def test_key_mapper():
    """KeyMapper renames result keys around the transforms it wraps."""

    # Case 0: only remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})

    results = dict(value=0)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_out'], 1)

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping={'value': 'v_in'},
        remapping={'value': 'v_out'})

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': ['v_in_1', 'v_in_2']},
        remapping={'value': ['v_out_1', 'v_out_2']})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping={'value': {
            'v1': 'v_in_1',
            'v2': 'v_in_2'
        }},
        remapping={'value': {
            'v1': 'v_out_1',
            'v2': 'v_out_2'
        }})
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    # a missing source key is tolerated; SumTwoValues then yields NaN
    results = pipeline(dict(a=1))
    assert np.isnan(results['sum'])

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    _ = str(pipeline)


def test_transform_broadcaster():
    """TransformBroadcaster applies its transforms to every mapped item."""

    # Case 1: apply to list in results
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    results = dict(values=[1, 2])

    results = pipeline(results)

    np.testing.assert_equal(results['values'], [2, 3])

    # Case 2: apply to multiple keys
    pipeline = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    results = dict(v_1=1, v_2=2)

    results = pipeline(results)

    np.testing.assert_equal(results['v_1'], 2)
    np.testing.assert_equal(results['v_2'], 3)

    # Case 3: apply to multiple groups of keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)

    results = dict(a_1=1, a_2=2, b_1=3, b_2=4)
    results = pipeline(results)

    np.testing.assert_equal(results['a'], 3)
    np.testing.assert_equal(results['b'], 7)

    # Case 4: apply to all keys
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()], mapping=None, remapping=None)
    results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])

    results = pipeline(results)

    np.testing.assert_equal(results['sum'], [5, 7, 9])

    # Case 5: inconsistent sequence length
    # Build the wrapper outside ``pytest.raises`` so the expected ValueError
    # must come from calling it on mismatched sequences, not from __init__.
    pipeline = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='list_1', num_2='list_2'),
        auto_remap=False)
    results = dict(list_1=[1, 2], list_2=[1, 2, 3])
    with pytest.raises(ValueError):
        _ = pipeline(results)

    # Case 6: share random parameter
    pipeline = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0, 0])
    results = pipeline(results)

    np.testing.assert_equal(results['values'][0], results['values'][1])

    # Test repr
    _ = str(pipeline)


def test_random_choice():
    """RandomChoice applies one randomly selected transform sequence."""

    # Case 1: given probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]])

    results = pipeline(dict(value=1))
    # exactly one of the two branches must have been applied
    assert results['value'] in (2.0, 3.0)

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]], ),
        ],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(v == values[0] for v in values)
    assert values[0] in (1.0, 2.0)


def test_random_apply():
    """RandomApply applies its transforms with probability ``prob``."""

    # Case 1: simple use
    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=1.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    pipeline = RandomApply(transforms=[AddToValue(addend=1.0)], prob=0.0)
    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 1.0)

    # Case 2: nested RandomApply in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[RandomApply(transforms=[AddToValue(addend=1)], prob=0.5)],
        mapping={'value': 'values'},
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(v == values[0] for v in values)
    # either every item was incremented or none was
    assert values[0] in (0, 1)

    # __iter__
    for _ in pipeline:
        pass


def test_utils():
    """Check the cache_randomness / avoid_cache_randomness helpers."""

    # Test cache_randomness: normal case
    class DummyTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            _ = self.func()
            return results

    transform = DummyTransform()
    _ = transform({})
    with cache_random_params(transform):
        _ = transform({})

    # Test cache_randomness: invalid function type
    # (the error is expected while the class body is being built, so the
    # class definition itself stays inside ``pytest.raises``)
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

            def transform(self, results):
                return results

    # Test cache_randomness: invalid function argument list
    with pytest.raises(TypeError):

        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(cls):
                return np.random.rand()

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: invalid mixture with cache_randomness
    with pytest.raises(RuntimeError):

        @avoid_cache_randomness
        class DummyTransform(BaseTransform):

            @cache_randomness
            def func(self):
                pass

            def transform(self, results):
                return results

    # Test avoid_cache_randomness: raise error in cache_random_params
    # Decorating and instantiating a plain class does not raise (see the
    # non-inheritable case below), so only the context manager is guarded.
    @avoid_cache_randomness
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return results

    transform = DummyTransform()
    with pytest.raises(RuntimeError):
        with cache_random_params(transform):
            pass

    # Test avoid_cache_randomness: non-inheritable
    # The restriction applies to the decorated class only; a subclass can
    # still be used with cache_random_params.
    @avoid_cache_randomness
    class DummyBaseTransform(BaseTransform):

        def transform(self, results):
            return results

    class DummyTransform(DummyBaseTransform):
        pass

    transform = DummyTransform()
    with cache_random_params(transform):
        pass