"vscode:/vscode.git/clone" did not exist on "9ebf10af02aeb882bd8d7782149e21b48528d562"
test_transforms_wrapper.py 12.2 KB
Newer Older
1
2
3
4
5
6
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import pytest

7
8
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.builder import TRANSFORMS
Yining Li's avatar
Yining Li committed
9
10
11
from mmcv.transforms.utils import cache_random_params, cache_randomness
from mmcv.transforms.wrappers import (Compose, KeyMapper, RandomChoice,
                                      TransformBroadcaster)
12
13
14
15


@TRANSFORMS.register_module()
class AddToValue(BaseTransform):
    """Dummy transform that adds a fixed addend to ``results['value']``.

    Args:
        addend: Number added to every scalar found in ``results['value']``.
            Defaults to 0.
    """

    def __init__(self, addend=0) -> None:
        super().__init__()
        self.addend = addend

    def add(self, results, addend):
        """Add ``addend`` to ``results['value']`` and return ``results``.

        ``results`` is modified in place. Lists and dicts are walked
        recursively: every leaf becomes ``leaf + addend`` while the container
        shape is kept (list subclasses come back as plain lists). A
        ``UserWarning`` is emitted when the top-level value is a list
        ('value is a list') or a dict ('value is a dict'), so callers can
        observe what a wrapper actually passed in.
        """

        def _shift(node):
            # Rebuild containers so that every nested leaf is shifted.
            if isinstance(node, dict):
                return {key: _shift(item) for key, item in node.items()}
            if isinstance(node, list):
                return [_shift(item) for item in node]
            return node + addend

        current = results['value']
        if isinstance(current, list):
            warnings.warn('value is a list', UserWarning)
        elif isinstance(current, dict):
            warnings.warn('value is a dict', UserWarning)

        results['value'] = _shift(current)
        return results

    def transform(self, results):
        """Apply :meth:`add` with the addend given at construction."""
        return self.add(results, self.addend)


@TRANSFORMS.register_module()
class RandomAddToValue(AddToValue):
    """Dummy transform that adds random addends to ``results['value']``.

    Args:
        repeat: How many random addends are drawn and added, one after
            another. Defaults to 1. The tests use 0 and 2 to check how
            ``cache_random_params`` handles zero or repeated calls to the
            random method.
    """

    def __init__(self, repeat=1) -> None:
        super().__init__(addend=None)
        self.repeat = repeat

    @cache_randomness
    def get_random_addend(self):
        """Draw one addend with ``np.random.rand()`` (uniform in [0, 1))."""
        return np.random.rand()

    def transform(self, results):
        """Add ``self.repeat`` freshly drawn random addends in sequence."""
        for _round in range(self.repeat):
            drawn = self.get_random_addend()
            results = self.add(results, addend=drawn)
        return results


@TRANSFORMS.register_module()
class SumTwoValues(BaseTransform):
    """Dummy transform to test transform wrappers.

    Writes ``results['num_1'] + results['num_2']`` into ``results['sum']``,
    or ``np.nan`` when either input key is missing.
    """

    def transform(self, results):
        has_both = all(key in results for key in ('num_1', 'num_2'))
        results['sum'] = (
            results['num_1'] + results['num_2'] if has_both else np.nan)
        return results


def test_compose():
    """Test building ``Compose`` and running the resulting pipeline."""

    # Case 1: build from cfg
    pipeline = Compose([dict(type='AddToValue', addend=1)])
    _ = str(pipeline)
    # the built pipeline must actually apply its transform, not just build
    results = pipeline(dict(value=0))
    np.testing.assert_equal(results['value'], 1)

    # Case 2: build from transform list; every listed transform is applied
    pipeline = Compose([AddToValue(addend=1), AddToValue(addend=2)])
    results = pipeline(dict(value=0))
    np.testing.assert_equal(results['value'], 3)

    # Case 3: invalid build arguments
    pipeline = [[dict(type='AddToValue')]]
    with pytest.raises(TypeError):
        pipeline = Compose(pipeline)

    # Case 4: a transform returning None makes the whole pipeline return None
    class DummyTransform(BaseTransform):

        def transform(self, results):
            return None

    pipeline = Compose([DummyTransform()])
    results = pipeline({})
    assert results is None


def test_cache_random_parameters():
    """Check ``cache_randomness`` registration and ``cache_random_params``."""

    def run_twice(target):
        # Feed two identical inputs and return both resulting values.
        first = target(dict(value=0))
        second = target(dict(value=0))
        return first['value'], second['value']

    random_add = RandomAddToValue()

    # Case 1: cache random parameters. The decorated method is recorded on
    # the class, and inside the context both calls see the same addend.
    assert hasattr(RandomAddToValue, '_methods_with_randomness')
    assert 'get_random_addend' in RandomAddToValue._methods_with_randomness

    with cache_random_params(random_add):
        value_a, value_b = run_twice(random_add)
        np.testing.assert_equal(value_a, value_b)

    # Case 2: do not cache random parameters -- outside the context the two
    # results are expected to differ.
    value_a, value_b = run_twice(random_add)
    with pytest.raises(AssertionError):
        np.testing.assert_equal(value_a, value_b)

    # Case 3: allow to invoke random method 0 times
    never_random = RandomAddToValue(repeat=0)
    with cache_random_params(never_random):
        never_random(dict(value=0))

    # Case 4: NOT allow to invoke random method >1 times
    twice_random = RandomAddToValue(repeat=2)
    with pytest.raises(RuntimeError), cache_random_params(twice_random):
        twice_random(dict(value=0))

    # Case 5: apply on nested transforms
    composed = Compose([RandomAddToValue()])
    with cache_random_params(composed):
        value_a, value_b = run_twice(composed)
        np.testing.assert_equal(value_a, value_b)


def test_apply_to_mapped():
    """Test ``KeyMapper``.

    Covers plain key mapping/remapping, list and dict collections, the
    ``auto_remap`` mode, ``allow_nonexist_keys``, and using the wrapper as a
    bare transform.
    """

    # Case 1: simple remap
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    results = dict(value=0, v_in=1)
    results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in'], 1)
    np.testing.assert_equal(results['v_out'], 2)

    # Case 2: collecting list
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        remapping=dict(value=['v_out_1', 'v_out_2']))
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 3: collecting dict
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        remapping=dict(value=dict(v1='v_out_1', v2='v_out_2')))
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)  # should be unchanged
    np.testing.assert_equal(results['v_in_1'], 1)
    np.testing.assert_equal(results['v_in_2'], 2)
    np.testing.assert_equal(results['v_out_1'], 3)
    np.testing.assert_equal(results['v_out_2'], 4)

    # Case 4: collecting list with auto_remap mode (outputs are written back
    # to the mapped source keys)
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v_in_1', 'v_in_2']),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 5: collecting dict with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=dict(v1='v_in_1', v2='v_in_2')),
        auto_remap=True)
    results = dict(value=0, v_in_1=1, v_in_2=2)

    with pytest.warns(UserWarning, match='value is a dict'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v_in_1'], 3)
    np.testing.assert_equal(results['v_in_2'], 4)

    # Case 6: nested collection with auto_remap mode
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=2)],
        mapping=dict(value=['v1', dict(v2=['v21', 'v22'], v3='v3')]),
        auto_remap=True)
    results = dict(value=0, v1=1, v21=2, v22=3, v3=4)

    with pytest.warns(UserWarning, match='value is a list'):
        results = pipeline(results)

    np.testing.assert_equal(results['value'], 0)
    np.testing.assert_equal(results['v1'], 3)
    np.testing.assert_equal(results['v21'], 4)
    np.testing.assert_equal(results['v22'], 5)
    np.testing.assert_equal(results['v3'], 6)

    # Case 7: `remapping` must be None if `auto_remap` is set True
    with pytest.raises(ValueError):
        pipeline = KeyMapper(
            transforms=[AddToValue(addend=1)],
            mapping=dict(value='v_in'),
            remapping=dict(value='v_out'),
            auto_remap=True)

    # Case 8: allow_nonexist_keys -- a missing source key ('b' below) is
    # tolerated, so SumTwoValues sees no 'num_2' and writes NaN
    pipeline = KeyMapper(
        transforms=[SumTwoValues()],
        mapping=dict(num_1='a', num_2='b'),
        auto_remap=False,
        allow_nonexist_keys=True)

    results = pipeline(dict(a=1, b=2))
    np.testing.assert_equal(results['sum'], 3)

    results = pipeline(dict(a=1))
    assert np.isnan(results['sum'])

    # Case 9: use wrapper as a transform
    transform = KeyMapper(mapping=dict(b='a'), auto_remap=False)
    results = transform(dict(a=1))
    # note that the original key 'a' will not be removed
    assert results == dict(a=1, b=1)

    # Test basic functions
    pipeline = KeyMapper(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='v_in'),
        remapping=dict(value='v_out'))

    # __iter__
    for _ in pipeline:
        pass

    # __repr__
    _ = str(pipeline)


def test_apply_to_multiple():
    """Exercise ``TransformBroadcaster`` on lists, key groups and shared
    randomness."""

    # Case 1: broadcast over the items of a list stored under one key
    broadcaster = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value='values'),
        auto_remap=True)
    outputs = broadcaster(dict(values=[1, 2]))
    np.testing.assert_equal(outputs['values'], [2, 3])

    # Case 2: broadcast over several keys
    broadcaster = TransformBroadcaster(
        transforms=[AddToValue(addend=1)],
        mapping=dict(value=['v_1', 'v_2']),
        auto_remap=True)
    outputs = broadcaster(dict(v_1=1, v_2=2))
    for key, expected in (('v_1', 2), ('v_2', 3)):
        np.testing.assert_equal(outputs[key], expected)

    # Case 3: broadcast over groups of keys; the i-th group pairs
    # num_1[i] with num_2[i] and writes its sum to sum[i]
    broadcaster = TransformBroadcaster(
        transforms=[SumTwoValues()],
        mapping=dict(num_1=['a_1', 'b_1'], num_2=['a_2', 'b_2']),
        remapping=dict(sum=['a', 'b']),
        auto_remap=False)
    outputs = broadcaster(dict(a_1=1, a_2=2, b_1=3, b_2=4))
    np.testing.assert_equal(outputs['a'], 3)
    np.testing.assert_equal(outputs['b'], 7)

    # Case 4: mapped sequences of inconsistent length raise ValueError
    with pytest.raises(ValueError):
        broadcaster = TransformBroadcaster(
            transforms=[SumTwoValues()],
            mapping=dict(num_1='list_1', num_2='list_2'),
            auto_remap=False)
        broadcaster(dict(list_1=[1, 2], list_2=[1, 2, 3]))

    # Case 5: with share_random_params=True every item receives the same
    # random addend
    broadcaster = TransformBroadcaster(
        transforms=[RandomAddToValue()],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)
    outputs = broadcaster(dict(values=[0, 0]))
    np.testing.assert_equal(outputs['values'][0], outputs['values'][1])

    # Test repr
    str(broadcaster)


def test_randomchoice():
    """Test ``RandomChoice`` standalone and nested in ``TransformBroadcaster``."""

    # Case 1: given probability -- the first branch is always chosen
    pipeline = RandomChoice(
        transforms=[[AddToValue(addend=1.0)], [AddToValue(addend=2.0)]],
        prob=[1.0, 0.0])

    results = pipeline(dict(value=1))
    np.testing.assert_equal(results['value'], 2.0)

    # Case 2: default probability -- either branch may be chosen, but the
    # result must come from one of them
    pipeline = RandomChoice(transforms=[[AddToValue(
        addend=1.0)], [AddToValue(addend=2.0)]])

    results = pipeline(dict(value=1))
    assert results['value'] in (2.0, 3.0)

    # Case 3: nested RandomChoice in TransformBroadcaster
    pipeline = TransformBroadcaster(
        transforms=[
            RandomChoice(
                transforms=[[AddToValue(addend=1.0)],
                            [AddToValue(addend=2.0)]], ),
        ],
        mapping=dict(value='values'),
        auto_remap=True,
        share_random_params=True)

    results = dict(values=[0] * 10)
    results = pipeline(results)
    # check share_random_params=True works so that all values are same
    values = results['values']
    assert all(value == values[0] for value in values)
    # ...and that the shared choice is one of the two branches
    assert values[0] in (1.0, 2.0)


def test_utils():
    """Check which methods ``cache_randomness`` accepts."""

    # Test cache_randomness: normal case -- an instance method can be
    # decorated and called both outside and inside cache_random_params.
    class _RandomFuncTransform(BaseTransform):

        @cache_randomness
        def func(self):
            return np.random.rand()

        def transform(self, results):
            self.func()
            return results

    random_transform = _RandomFuncTransform()
    random_transform({})
    with cache_random_params(random_transform):
        random_transform({})

    # Test cache_randomness: invalid function type -- decorating a
    # staticmethod object raises TypeError while the class body executes.
    with pytest.raises(TypeError):

        class _StaticFuncHolder:

            @cache_randomness
            @staticmethod
            def func():
                return np.random.rand()

    # Test cache_randomness: invalid function argument list -- a method whose
    # first argument is named ``cls`` is rejected with TypeError.
    with pytest.raises(TypeError):

        class _ClsArgHolder:

            @cache_randomness
            def func(cls):
                return np.random.rand()