# Tests for mmcv.Registry and mmcv.build_from_cfg.
import pytest

import mmcv


def test_registry():
    """Cover registration, lookup, forcing, aliases, repr and old APIs."""
    CATS = mmcv.Registry('cat')
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    class Munchkin:
        pass

    CATS.register_module(Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # registering the same name twice without force is rejected
    with pytest.raises(KeyError):
        CATS.register_module(Munchkin)

    CATS.register_module(Munchkin, force=True)
    assert len(CATS) == 2

    # force=False
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    # force=True replaces the existing entry instead of adding a new one
    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(CATS) == 2

    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # one class registered under several names
    @CATS.register_module(name=['Siamese', 'Siamese2'])
    class SiameseCat:
        pass

    assert CATS.get('Siamese').__name__ == 'SiameseCat'
    assert CATS.get('Siamese2').__name__ == 'SiameseCat'

    class SphynxCat:
        pass

    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx2') is SphynxCat

    # Expected contents, in insertion order (a forced re-registration keeps
    # the original slot). The repr text is derived from the live class
    # objects so the check does not hard-code this module's import path,
    # which is embedded in every class repr.
    expected_items = {
        'BritishShorthair': BritishShorthair,
        'Munchkin': Munchkin,
        'Siamese': SiameseCat,
        'Siamese2': SiameseCat,
        'Sphynx': SphynxCat,
        'Sphynx1': SphynxCat,
        'Sphynx2': SphynxCat,
    }
    assert CATS.module_dict == expected_items
    repr_str = 'Registry(name=cat, items={})'.format(expected_items)
    assert repr(CATS) == repr_str

    # name type
    with pytest.raises(TypeError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # the registered module should be a class
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # can only decorate a class
    with pytest.raises(TypeError):

        @CATS.register_module()
        def some_method():
            pass

    # begin: test old APIs
    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):

        @CATS.register_module
        class NewCat:
            pass

        assert CATS.get('NewCat').__name__ == 'NewCat'

    with pytest.warns(UserWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
        assert CATS.get('SphynxCat').__name__ == 'SphynxCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

        assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

        assert CATS.get('NewCat2').__name__ == 'NewCat2'

    # end: test old APIs


def test_multi_scope_registry():
    """Exercise scoped lookups through a parent/child registry hierarchy.

    Layout built by this test::

        DOGS                   (no explicit scope; expected 'test_registry')
        └── HOUNDS             (scope 'hound')
            ├── LITTLE_HOUNDS  (scope 'little_hound')
            └── MID_HOUNDS     (scope 'mid_hound')

    Each lookup table below lists ``(registry, key, expected)`` triples;
    keys may carry dotted scope prefixes to reach modules registered in
    other registries of the same hierarchy.
    """
    # Root registry.
    DOGS = mmcv.Registry('dogs')
    assert DOGS.name == 'dogs'
    assert DOGS.scope == 'test_registry'
    assert DOGS.module_dict == {}
    assert len(DOGS) == 0

    @DOGS.register_module()
    class GoldenRetriever:
        pass

    assert len(DOGS) == 1
    assert DOGS.get('GoldenRetriever') is GoldenRetriever

    # First-level child.
    HOUNDS = mmcv.Registry('dogs', parent=DOGS, scope='hound')

    @HOUNDS.register_module()
    class BloodHound:
        pass

    assert len(HOUNDS) == 1
    hound_lookups = (
        (HOUNDS, 'BloodHound', BloodHound),
        (DOGS, 'hound.BloodHound', BloodHound),
        (HOUNDS, 'hound.BloodHound', BloodHound),
    )
    for registry, key, expected in hound_lookups:
        assert registry.get(key) is expected, key

    # Second-level child under HOUNDS.
    LITTLE_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='little_hound')

    @LITTLE_HOUNDS.register_module()
    class Dachshund:
        pass

    assert len(LITTLE_HOUNDS) == 1
    little_lookups = (
        (LITTLE_HOUNDS, 'Dachshund', Dachshund),
        (LITTLE_HOUNDS, 'hound.BloodHound', BloodHound),
        (HOUNDS, 'little_hound.Dachshund', Dachshund),
        (DOGS, 'hound.little_hound.Dachshund', Dachshund),
    )
    for registry, key, expected in little_lookups:
        assert registry.get(key) is expected, key

    # Sibling of LITTLE_HOUNDS under HOUNDS.
    MID_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='mid_hound')

    @MID_HOUNDS.register_module()
    class Beagle:
        pass

    mid_lookups = (
        (MID_HOUNDS, 'Beagle', Beagle),
        (HOUNDS, 'mid_hound.Beagle', Beagle),
        (DOGS, 'hound.mid_hound.Beagle', Beagle),
        (LITTLE_HOUNDS, 'hound.mid_hound.Beagle', Beagle),
        (MID_HOUNDS, 'hound.BloodHound', BloodHound),
        # Dachshund lives under 'hound.little_hound', so the shorter
        # 'hound.Dachshund' path must not resolve.
        (MID_HOUNDS, 'hound.Dachshund', None),
    )
    for registry, key, expected in mid_lookups:
        assert registry.get(key) is expected, key


def test_build_from_cfg():
    """Cover mmcv.build_from_cfg: type resolution, default_args and errors.

    In every ``pytest.raises`` block only the call under test sits inside the
    context manager, so an unexpected failure while preparing ``cfg`` cannot
    be mistaken for the expected exception.
    """
    BACKBONES = mmcv.Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # type given as a class instead of a registered name
    cfg = dict(type=ResNet, depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # type defined using default_args
    cfg = dict(depth=50)
    model = mmcv.build_from_cfg(
        cfg, BACKBONES, default_args=dict(type='ResNet'))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # not a registry
    cfg = dict(type='VGG')
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, 'BACKBONES')

    # non-registered class
    cfg = dict(type='VGG')
    with pytest.raises(KeyError):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # default_args must be a dict or None
    cfg = dict(type='ResNet', depth=50)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=1)

    # cfg['type'] should be a str or class
    cfg = dict(type=1000)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
    cfg = dict(depth=50, stages=4)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg or default_args should contain the key "type"
    cfg = dict(depth=50)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))

    # incorrect registry type
    cfg = dict(type='ResNet', depth=50)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, 'BACKBONES')

    # incorrect default_args type
    cfg = dict(type='ResNet', depth=50)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)

    # incorrect arguments
    cfg = dict(type='ResNet', non_existing_arg=50)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES)