# test_registry.py: unit tests for mmcv.Registry and mmcv.build_from_cfg.
import pytest

import mmcv


def test_registry():
    """Exercise registration, lookup, overriding and repr of mmcv.Registry.

    Covers the decorator and keyword-call forms of ``register_module``,
    duplicate-name handling with and without ``force``, registration under
    custom (and multiple) names, argument validation, and the deprecated
    registration APIs, each of which is expected to emit a ``UserWarning``.
    """
    CATS = mmcv.Registry('cat')
    assert CATS.name == 'cat'
    assert CATS.module_dict == {}
    assert len(CATS) == 0

    # decorator form: the class is registered under its own __name__
    @CATS.register_module()
    class BritishShorthair:
        pass

    assert len(CATS) == 1
    assert CATS.get('BritishShorthair') is BritishShorthair

    # keyword-call form; passing the class positionally is the deprecated
    # API, which is exercised separately at the end of this test
    class Munchkin:
        pass

    CATS.register_module(module=Munchkin)
    assert len(CATS) == 2
    assert CATS.get('Munchkin') is Munchkin
    assert 'Munchkin' in CATS

    # re-registering an existing name fails unless force=True
    with pytest.raises(KeyError):
        CATS.register_module(module=Munchkin)

    CATS.register_module(module=Munchkin, force=True)
    assert len(CATS) == 2

    # force=False (the default) also applies to the decorator form
    with pytest.raises(KeyError):

        @CATS.register_module()
        class BritishShorthair:
            pass

    @CATS.register_module(force=True)
    class BritishShorthair:
        pass

    assert len(CATS) == 2

    assert CATS.get('PersianCat') is None
    assert 'PersianCat' not in CATS

    # register under a custom name
    @CATS.register_module(name='Siamese')
    class SiameseCat:
        pass

    assert CATS.get('Siamese') is SiameseCat

    class SphynxCat:
        pass

    CATS.register_module(name='Sphynx', module=SphynxCat)
    assert CATS.get('Sphynx') is SphynxCat

    # one class may be registered under several names at once
    CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
    assert CATS.get('Sphynx1') is SphynxCat
    assert CATS.get('Sphynx2') is SphynxCat

    # Classes defined in this function repr as
    # "<class '<module>.test_registry.<locals>.<Name>'>". The module part is
    # taken from __name__ so the expectation does not depend on how pytest
    # imported this file (e.g. 'test_registry' vs 'tests.test_registry').
    cls_prefix = "<class '" + __name__ + ".test_registry.<locals>."
    expected_items = [
        ('BritishShorthair', 'BritishShorthair'),
        ('Munchkin', 'Munchkin'),
        ('Siamese', 'SiameseCat'),
        ('Sphynx', 'SphynxCat'),
        ('Sphynx1', 'SphynxCat'),
        ('Sphynx2', 'SphynxCat'),
    ]
    repr_str = 'Registry(name=cat, items={' + ', '.join(
        "'{}': {}{}'>".format(key, cls_prefix, cls_name)
        for key, cls_name in expected_items) + '})'
    assert repr(CATS) == repr_str

    # name type: must be a str (or a sequence of str, as used above)
    with pytest.raises(AssertionError):
        CATS.register_module(name=7474741, module=SphynxCat)

    # the registered module should be a class
    with pytest.raises(TypeError):
        CATS.register_module(0)

    # can only decorate a class
    with pytest.raises(TypeError):

        @CATS.register_module()
        def some_method():
            pass

    # begin: test old APIs (each registration is expected to warn)
    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat)
    assert CATS.get('SphynxCat') is SphynxCat

    with pytest.warns(UserWarning):
        CATS.register_module(SphynxCat, force=True)
    assert CATS.get('SphynxCat') is SphynxCat

    with pytest.warns(UserWarning):

        @CATS.register_module
        class NewCat:
            pass

    # compare by __name__ so the check does not rely on what the deprecated
    # decorator returns
    assert CATS.get('NewCat').__name__ == 'NewCat'

    with pytest.warns(UserWarning):
        CATS.deprecated_register_module(SphynxCat, force=True)
    assert CATS.get('SphynxCat') is SphynxCat

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module
        class CuteCat:
            pass

    assert CATS.get('CuteCat').__name__ == 'CuteCat'

    with pytest.warns(UserWarning):

        @CATS.deprecated_register_module(force=True)
        class NewCat2:
            pass

    assert CATS.get('NewCat2').__name__ == 'NewCat2'

    # end: test old APIs


def test_build_from_cfg():
    """Check that mmcv.build_from_cfg instantiates registered classes.

    Covers ``type`` given as a registered name or as a class, ``type`` and
    other arguments supplied through ``default_args``, and the errors raised
    for an invalid ``registry``, ``default_args`` or ``type`` value, an
    unregistered type name, and a missing ``type`` key.
    """
    BACKBONES = mmcv.Registry('backbone')

    @BACKBONES.register_module()
    class ResNet:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    @BACKBONES.register_module()
    class ResNeXt:

        def __init__(self, depth, stages=4):
            self.depth = depth
            self.stages = stages

    # type given as a registered name
    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # default_args supply arguments that cfg does not set
    cfg = dict(type='ResNet', depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 3

    cfg = dict(type='ResNeXt', depth=50, stages=3)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNeXt)
    assert model.depth == 50 and model.stages == 3

    # type given directly as a class
    cfg = dict(type=ResNet, depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES)
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # type defined using default_args
    cfg = dict(depth=50)
    model = mmcv.build_from_cfg(
        cfg, BACKBONES, default_args=dict(type='ResNet'))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    cfg = dict(depth=50)
    model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
    assert isinstance(model, ResNet)
    assert model.depth == 50 and model.stages == 4

    # Error cases: each cfg is built before entering pytest.raises so that
    # only the build_from_cfg call can satisfy the expected exception.

    # registry must be a Registry instance, not e.g. its name as a str
    for cfg in (dict(type='VGG'), dict(type='ResNet', depth=50)):
        with pytest.raises(TypeError):
            mmcv.build_from_cfg(cfg, 'BACKBONES')

    # non-registered class
    cfg = dict(type='VGG')
    with pytest.raises(KeyError):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # default_args must be a dict or None
    cfg = dict(type='ResNet', depth=50)
    for bad_default_args in (0, 1):
        with pytest.raises(TypeError):
            mmcv.build_from_cfg(cfg, BACKBONES, default_args=bad_default_args)

    # cfg['type'] should be a str or class
    cfg = dict(type=1000)
    with pytest.raises(TypeError):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg should contain the key "type"
    cfg = dict(depth=50, stages=4)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        mmcv.build_from_cfg(cfg, BACKBONES)

    # cfg or default_args should contain the key "type"
    cfg = dict(depth=50)
    with pytest.raises(KeyError, match='must contain the key "type"'):
        mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(stages=4))