# test_basemodule.py: tests for the weight initialization of
# mmcv.runner.BaseModule and its container variants.
1
2
import tempfile

3
import pytest
4
5
6
import torch
from torch import nn

7
8
import mmcv
from mmcv.cnn.utils.weight_init import update_init_info
9
from mmcv.runner import BaseModule, ModuleDict, ModuleList, Sequential
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from mmcv.utils import Registry, build_from_cfg

# Registries for the toy modules defined below: ``COMPONENTS`` holds the
# building blocks, ``FOOMODELS`` holds the top-level models built from them.
COMPONENTS, FOOMODELS = Registry('component'), Registry('model')


@COMPONENTS.register_module()
class FooConv1d(BaseModule):
    """Toy component wrapping a single ``nn.Conv1d(4, 1, 4)`` layer."""

    def __init__(self, init_cfg=None):
        super(FooConv1d, self).__init__(init_cfg)
        self.conv1d = nn.Conv1d(in_channels=4, out_channels=1, kernel_size=4)

    def forward(self, x):
        out = self.conv1d(x)
        return out


@COMPONENTS.register_module()
class FooConv2d(BaseModule):
    """Toy component wrapping a single ``nn.Conv2d(3, 1, 3)`` layer."""

    def __init__(self, init_cfg=None):
        super(FooConv2d, self).__init__(init_cfg)
        self.conv2d = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3)

    def forward(self, x):
        out = self.conv2d(x)
        return out


@COMPONENTS.register_module()
class FooLinear(BaseModule):
    """Toy component wrapping a single ``nn.Linear(3, 4)`` layer."""

    def __init__(self, init_cfg=None):
        super(FooLinear, self).__init__(init_cfg)
        self.linear = nn.Linear(in_features=3, out_features=4)

    def forward(self, x):
        out = self.linear(x)
        return out


@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):
    """Toy component nesting two registry-built sub-components.

    ``linear`` and ``conv1d`` are config dicts built from ``COMPONENTS``.
    A sub-component whose config is ``None`` is simply not created.
    """

    def __init__(self, linear=None, conv1d=None, init_cfg=None):
        super(FooLinearConv1d, self).__init__(init_cfg)
        for attr_name, sub_cfg in (('linear', linear), ('conv1d', conv1d)):
            if sub_cfg is not None:
                setattr(self, attr_name, build_from_cfg(sub_cfg, COMPONENTS))

    def forward(self, x):
        # Requires both sub-components to have been configured.
        return self.conv1d(self.linear(x))


@FOOMODELS.register_module()
class FooModel(BaseModule):
    """Toy top-level model used by the weight-initialization tests.

    Up to four optional components (``component1`` .. ``component4``) are
    built from ``COMPONENTS`` when their config is not ``None``. A plain
    ``nn.Linear`` head named ``reg`` is always created.
    """

    def __init__(self,
                 component1=None,
                 component2=None,
                 component3=None,
                 component4=None,
                 init_cfg=None) -> None:
        super().__init__(init_cfg)
        component_cfgs = (component1, component2, component3, component4)
        for index, cfg in enumerate(component_cfgs, start=1):
            if cfg is None:
                continue
            component = build_from_cfg(cfg, COMPONENTS)
            setattr(self, f'component{index}', component)

        # ``reg`` is a plain nn.Linear rather than a BaseModule, so it has no
        # init_cfg of its own; it can still be initialized from the parent's
        # init_cfg (by ``layer`` or through the "override" key).
        self.reg = nn.Linear(3, 4)


88
89
90
def _check_init_info_dumped(log_file):
    """Initialize a small model and check the init info dumped to ``log_file``.

    In the dump, each parameter-name line is expected to be followed by a
    line naming the initializer that produced its value.
    """
    import os

    class OverloadInitConv(nn.Conv2d, BaseModule):

        def init_weights(self):
            for p in self.parameters():
                with torch.no_grad():
                    p.fill_(1)

    class CheckLoggerModel(BaseModule):

        def __init__(self, init_cfg=None):
            super(CheckLoggerModel, self).__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv2 = OverloadInitConv(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    # Conv2d layers get Normal init (conv3 via ``override``), Linear layers
    # get Constant init; conv2 has its own user-defined ``init_weights``.
    init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
        dict(type='Constant', layer='Linear', val=0., bias=1.)
    ]

    model = CheckLoggerModel(init_cfg=init_cfg)

    assert not hasattr(model, '_params_init_info')
    model.init_weights()
    # `_params_init_info` should be deleted after `init_weights`
    assert not hasattr(model, '_params_init_info')
    # initialization information should have been dumped
    assert os.path.exists(log_file)

    lines = mmcv.list_from_file(log_file)

    # the line following a parameter name names its initializer
    for i, line in enumerate(lines):
        if 'conv1.weight' in line:
            assert 'NormalInit' in lines[i + 1]
        if 'conv2.weight' in line:
            assert 'OverloadInitConv' in lines[i + 1]
        if 'fc1.weight' in line:
            assert 'ConstantInit' in lines[i + 1]


def _check_init_info_corner_case(log_file):
    """Initialize children both before and after their parent.

    Afterwards, check the ``TopLevelModule`` entries appended to
    ``log_file``.
    """
    import os

    class OverloadInitConvFc(nn.Conv2d, BaseModule):

        def __init__(self, *args, **kwargs):
            super(OverloadInitConvFc, self).__init__(*args, **kwargs)
            self.conv1 = nn.Linear(1, 1)

        def init_weights(self):
            for p in self.parameters():
                with torch.no_grad():
                    p.fill_(1)

    class CheckLoggerModel(BaseModule):

        def __init__(self, init_cfg=None):
            super(CheckLoggerModel, self).__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv2 = OverloadInitConvFc(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    class TopLevelModule(BaseModule):

        def __init__(self, init_cfg=None, checklog_init_cfg=None):
            super(TopLevelModule, self).__init__(init_cfg)
            self.module1 = CheckLoggerModel(checklog_init_cfg)
            self.module2 = OverloadInitConvFc(1, 1, 1, 1)

    checklog_init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
        dict(type='Constant', layer='Linear', val=0., bias=1.)
    ]

    top_level_init_cfg = [
        dict(
            type='Normal',
            layer='Conv2d',
            std=0.01,
            override=dict(
                type='Normal', name='module2', std=0.01, bias_prob=0.01))
    ]

    model = TopLevelModule(
        init_cfg=top_level_init_cfg, checklog_init_cfg=checklog_init_cfg)

    model.module1.init_weights()
    model.module2.init_weights()
    model.init_weights()
    model.module1.init_weights()
    model.module2.init_weights()

    assert not hasattr(model, '_params_init_info')
    model.init_weights()
    # `_params_init_info` should be deleted after `init_weights`
    assert not hasattr(model, '_params_init_info')
    # initialization information should have been dumped
    assert os.path.exists(log_file)

    lines = mmcv.list_from_file(log_file)
    for line in lines:
        if 'TopLevelModule' in line and 'init_cfg' not in line:
            # TopLevelModule was already initialized before the last
            # `init_weights` call, so its entries should report the
            # parameter values as unchanged ("the same").
            assert 'the same' in line


def test_initilization_info_logger():
    """Check that ``init_weights`` dumps per-parameter initialization info.

    A logger with a FileHandler is created on a temporary log file, and the
    dumped information is then read back and verified. The temporary
    directory and its FileHandler are released even if a check fails.
    """
    import logging
    import os
    import shutil

    from mmcv.utils.logging import get_logger

    train_log = '20210720_132454.log'
    workdir = tempfile.mkdtemp()
    log_file = os.path.join(workdir, train_log)
    # create a logger whose FileHandler writes to `log_file`
    logger = get_logger('init_logger', log_file=log_file)
    try:
        _check_init_info_dumped(log_file)
        _check_init_info_corner_case(log_file)
    finally:
        # close the log file before deleting the temporary directory so
        # later `init_weights` calls do not write to a closed stream
        for handler in list(logger.handlers):
            if isinstance(handler, logging.FileHandler):
                handler.close()
                logger.removeHandler(handler)
        shutil.rmtree(workdir, ignore_errors=True)


def test_update_init_info():
    """Check ``update_init_info`` on a hand-filled ``_params_init_info``.

    After every parameter is overwritten with ones, every entry should carry
    the new ``init_info`` and the new mean value. A parameter missing from
    ``_params_init_info`` should trigger an ``AssertionError``.
    """
    from collections import defaultdict

    class DummyModel(BaseModule):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            self.conv1 = nn.Conv2d(1, 1, 1, 1)
            self.conv3 = nn.Conv2d(1, 1, 1, 1)
            self.fc1 = nn.Linear(1, 1)

    model = DummyModel()
    model._params_init_info = defaultdict(dict)
    for param in model.parameters():
        model._params_init_info[param]['init_info'] = 'init'
        model._params_init_info[param]['tmp_mean_value'] = param.data.mean()

    with torch.no_grad():
        for p in model.parameters():
            p.fill_(1)

    update_init_info(model, init_info='fill_1')

    for item in model._params_init_info.values():
        assert item['init_info'] == 'fill_1'
        assert item['tmp_mean_value'] == 1

    # a parameter that is not tracked in `_params_init_info` must be rejected
    model.conv1.bias = nn.Parameter(torch.ones_like(model.conv1.bias))
    with pytest.raises(AssertionError):
        update_init_info(model, init_info=' ')

252

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
def test_model_weight_init():
    """Check that a parent ``init_cfg`` keyed by ``layer`` reaches all levels.

    Config
    model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                     Conv2d: weight=5, bias=6)
    ├──component1 (FooConv1d)
    ├──component2 (FooConv2d)
    ├──component3 (FooLinear)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear)
        ├──conv1d (FooConv1d)
    ├──reg (nn.Linear)

    Parameters after initialization
    model (FooModel)
    ├──component1 (FooConv1d, weight=3, bias=4)
    ├──component2 (FooConv2d, weight=5, bias=6)
    ├──component3 (FooLinear, weight=1, bias=2)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear, weight=1, bias=2)
        ├──conv1d (FooConv1d, weight=3, bias=4)
    ├──reg (nn.Linear, weight=1, bias=2)
    """
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=1, bias=2, layer='Linear'),
            dict(type='Constant', val=3, bias=4, layer='Conv1d'),
            dict(type='Constant', val=5, bias=6, layer='Conv2d')
        ],
        component1=dict(type='FooConv1d'),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'),
        component4=dict(
            type='FooLinearConv1d',
            linear=dict(type='FooLinear'),
            conv1d=dict(type='FooConv1d')))

    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    assert torch.equal(model.component1.conv1d.weight,
                       torch.full(model.component1.conv1d.weight.shape, 3.0))
    assert torch.equal(model.component1.conv1d.bias,
                       torch.full(model.component1.conv1d.bias.shape, 4.0))
    assert torch.equal(model.component2.conv2d.weight,
                       torch.full(model.component2.conv2d.weight.shape, 5.0))
    assert torch.equal(model.component2.conv2d.bias,
                       torch.full(model.component2.conv2d.bias.shape, 6.0))
    assert torch.equal(model.component3.linear.weight,
                       torch.full(model.component3.linear.weight.shape, 1.0))
    assert torch.equal(model.component3.linear.bias,
                       torch.full(model.component3.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.linear.linear.weight,
        torch.full(model.component4.linear.linear.weight.shape, 1.0))
    assert torch.equal(
        model.component4.linear.linear.bias,
        torch.full(model.component4.linear.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.weight,
        torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.bias,
        torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
    assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
                                                    1.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))


def test_nest_components_weight_init():
    """Check that a component's own ``init_cfg`` beats its parent's.

    Config
    model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
                     Conv2d: weight=5, bias=6)
    ├──component1 (FooConv1d, Conv1d: weight=7, bias=8)
    ├──component2 (FooConv2d, Conv2d: weight=9, bias=10)
    ├──component3 (FooLinear)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear)
        ├──conv1d (FooConv1d)
    ├──reg (nn.Linear, override: weight=13, bias=14)

    Parameters after initialization
    model (FooModel)
    ├──component1 (FooConv1d, weight=7, bias=8)
    ├──component2 (FooConv2d, weight=9, bias=10)
    ├──component3 (FooLinear, weight=1, bias=2)
    ├──component4 (FooLinearConv1d)
        ├──linear (FooLinear, weight=1, bias=2)
        ├──conv1d (FooConv1d, weight=3, bias=4)
    ├──reg (nn.Linear, weight=13, bias=14)
    """

    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(
                type='Constant',
                val=1,
                bias=2,
                layer='Linear',
                override=dict(type='Constant', name='reg', val=13, bias=14)),
            dict(type='Constant', val=3, bias=4, layer='Conv1d'),
            dict(type='Constant', val=5, bias=6, layer='Conv2d'),
        ],
        component1=dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=7, bias=8)),
        component2=dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=9, bias=10)),
        component3=dict(type='FooLinear'),
        component4=dict(
            type='FooLinearConv1d',
            linear=dict(type='FooLinear'),
            conv1d=dict(type='FooConv1d')))

    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    assert torch.equal(model.component1.conv1d.weight,
                       torch.full(model.component1.conv1d.weight.shape, 7.0))
    assert torch.equal(model.component1.conv1d.bias,
                       torch.full(model.component1.conv1d.bias.shape, 8.0))
    assert torch.equal(model.component2.conv2d.weight,
                       torch.full(model.component2.conv2d.weight.shape, 9.0))
    assert torch.equal(model.component2.conv2d.bias,
                       torch.full(model.component2.conv2d.bias.shape, 10.0))
    assert torch.equal(model.component3.linear.weight,
                       torch.full(model.component3.linear.weight.shape, 1.0))
    assert torch.equal(model.component3.linear.bias,
                       torch.full(model.component3.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.linear.linear.weight,
        torch.full(model.component4.linear.linear.weight.shape, 1.0))
    assert torch.equal(
        model.component4.linear.linear.bias,
        torch.full(model.component4.linear.linear.bias.shape, 2.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.weight,
        torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
    assert torch.equal(
        model.component4.conv1d.conv1d.bias,
        torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 13.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))


403
404
405
406
407
408
409
410
411
412
413
414
415
def test_without_layer_weight_init():
    """Check that an ``init_cfg`` without a ``layer`` key changes nothing.

    ``component1`` carries ``Constant(val=7, bias=8)`` but no ``layer``, so
    its conv keeps the values from the parent's Conv1d rule (3 and 4).
    """
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=1, bias=2, layer='Linear'),
            dict(type='Constant', val=3, bias=4, layer='Conv1d'),
            dict(type='Constant', val=5, bias=6, layer='Conv2d')
        ],
        component1=dict(
            type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'))
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    # init_cfg in component1 does not have layer key, so it does nothing
    assert torch.equal(model.component1.conv1d.weight,
                       torch.full(model.component1.conv1d.weight.shape, 3.0))
    assert torch.equal(model.component1.conv1d.bias,
                       torch.full(model.component1.conv1d.bias.shape, 4.0))

    assert torch.equal(model.component2.conv2d.weight,
                       torch.full(model.component2.conv2d.weight.shape, 5.0))
    assert torch.equal(model.component2.conv2d.bias,
                       torch.full(model.component2.conv2d.bias.shape, 6.0))
    assert torch.equal(model.component3.linear.weight,
                       torch.full(model.component3.linear.weight.shape, 1.0))
    assert torch.equal(model.component3.linear.bias,
                       torch.full(model.component3.linear.bias.shape, 2.0))

    assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
                                                    1.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))


def test_override_weight_init():
    """Check initialization through the ``override`` key of ``init_cfg``.

    1. An ``init_cfg`` with only ``override`` (no ``layer``) initializes
       only the named ``reg`` head.
    2. The values given inside ``override`` win over the outer ones.
    """

    # only initialize 'override'
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(type='Constant', val=10, bias=20, override=dict(name='reg'))
        ],
        component1=dict(type='FooConv1d'),
        component3=dict(type='FooLinear'))
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()
    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 10.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0))
    # do not initialize others
    assert not torch.equal(
        model.component1.conv1d.weight,
        torch.full(model.component1.conv1d.weight.shape, 10.0))
    assert not torch.equal(
        model.component1.conv1d.bias,
        torch.full(model.component1.conv1d.bias.shape, 20.0))
    assert not torch.equal(
        model.component3.linear.weight,
        torch.full(model.component3.linear.weight.shape, 10.0))
    assert not torch.equal(
        model.component3.linear.bias,
        torch.full(model.component3.linear.bias.shape, 20.0))

    # 'override' has higher priority
    model_cfg = dict(
        type='FooModel',
        init_cfg=[
            dict(
                type='Constant',
                val=1,
                bias=2,
                override=dict(name='reg', type='Constant', val=30, bias=40))
        ],
        component1=dict(type='FooConv1d'),
        component2=dict(type='FooConv2d'),
        component3=dict(type='FooLinear'))
    model = build_from_cfg(model_cfg, FOOMODELS)
    model.init_weights()

    assert torch.equal(model.reg.weight,
                       torch.full(model.reg.weight.shape, 30.0))
    assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 40.0))


488
489
490
def test_sequential_model_weight_init():
    """Check ``Sequential.init_weights`` with and without an outer init_cfg.

    In both cases each layer's own ``init_cfg`` must take effect.
    """
    seq_model_cfg = [
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    ]
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(*layers)
    seq_model.init_weights()
    assert torch.equal(seq_model[0].conv1d.weight,
                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
    assert torch.equal(seq_model[0].conv1d.bias,
                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
    assert torch.equal(seq_model[1].conv2d.weight,
                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
    assert torch.equal(seq_model[1].conv2d.bias,
                       torch.full(seq_model[1].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
    seq_model = Sequential(
        *layers,
        init_cfg=dict(
            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
    seq_model.init_weights()
    assert torch.equal(seq_model[0].conv1d.weight,
                       torch.full(seq_model[0].conv1d.weight.shape, 0.))
    assert torch.equal(seq_model[0].conv1d.bias,
                       torch.full(seq_model[0].conv1d.bias.shape, 1.))
    assert torch.equal(seq_model[1].conv2d.weight,
                       torch.full(seq_model[1].conv2d.weight.shape, 2.))
    assert torch.equal(seq_model[1].conv2d.bias,
                       torch.full(seq_model[1].conv2d.bias.shape, 3.))


def test_modulelist_weight_init():
    """Check ``ModuleList.init_weights`` with and without an outer init_cfg.

    In both cases each layer's own ``init_cfg`` must take effect.
    """
    models_cfg = [
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    ]
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(layers)
    modellist.init_weights()
    assert torch.equal(modellist[0].conv1d.weight,
                       torch.full(modellist[0].conv1d.weight.shape, 0.))
    assert torch.equal(modellist[0].conv1d.bias,
                       torch.full(modellist[0].conv1d.bias.shape, 1.))
    assert torch.equal(modellist[1].conv2d.weight,
                       torch.full(modellist[1].conv2d.weight.shape, 2.))
    assert torch.equal(modellist[1].conv2d.bias,
                       torch.full(modellist[1].conv2d.bias.shape, 3.))

    # inner init_cfg has higher priority
    layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
    modellist = ModuleList(
        layers,
        init_cfg=dict(
            type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
    modellist.init_weights()
    assert torch.equal(modellist[0].conv1d.weight,
                       torch.full(modellist[0].conv1d.weight.shape, 0.))
    assert torch.equal(modellist[0].conv1d.bias,
                       torch.full(modellist[0].conv1d.bias.shape, 1.))
    assert torch.equal(modellist[1].conv2d.weight,
                       torch.full(modellist[1].conv2d.weight.shape, 2.))
    assert torch.equal(modellist[1].conv2d.bias,
                       torch.full(modellist[1].conv2d.bias.shape, 3.))


def test_moduledict_weight_init():
    """Check ``ModuleDict.init_weights``.

    Each child carries its own ``init_cfg``. Those values must survive both a
    plain ``ModuleDict`` and one with an outer ``init_cfg`` that targets the
    same layer types.
    """
    models_cfg = dict(
        foo_conv_1d=dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        foo_conv_2d=dict(
            type='FooConv2d',
            init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
    )

    def build_children():
        # Fresh children for every container, built in config order.
        return {
            key: build_from_cfg(child_cfg, COMPONENTS)
            for key, child_cfg in models_cfg.items()
        }

    def assert_children_use_own_init(container):
        conv1d = container['foo_conv_1d'].conv1d
        conv2d = container['foo_conv_2d'].conv2d
        assert torch.equal(conv1d.weight, torch.full(conv1d.weight.shape, 0.))
        assert torch.equal(conv1d.bias, torch.full(conv1d.bias.shape, 1.))
        assert torch.equal(conv2d.weight, torch.full(conv2d.weight.shape, 2.))
        assert torch.equal(conv2d.bias, torch.full(conv2d.bias.shape, 3.))

    # without an outer init_cfg
    modeldict = ModuleDict(build_children())
    modeldict.init_weights()
    assert_children_use_own_init(modeldict)

    # inner init_cfg has higher priority than the outer one
    children = build_children()
    outer_init_cfg = dict(
        type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)
    modeldict = ModuleDict(children, init_cfg=outer_init_cfg)
    modeldict.init_weights()
    assert_children_use_own_init(modeldict)