Unverified Commit 19a02415 authored by Mashiro, committed by GitHub

[Refactor] Use MODELS registry in mmengine and delete basemodule (#2172)

* change MODELS to mmengine, delete basemodule

* fix unit test

* remove build from cfg

* add comment and rename TARGET_MODELS to registry

* refine cnn docs

* remove unnecessary check

* refine as comment

* refine build_xxx_conv error message

* fix lint

* fix import registry from mmcv

* remove unused file
parent f6fd6c21
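In short: mmcv's per-type registries (CONV_LAYERS, NORM_LAYERS, ACTIVATION_LAYERS, PADDING_LAYERS, PLUGIN_LAYERS) are folded into the single MODELS registry from mmengine, which is why the updated tests below iterate MODELS.module_dict and filter entries by name. A minimal sketch of the registration flow after this change (MyConv is a hypothetical example class, not part of the commit):

import torch.nn as nn
from mmengine.registry import MODELS

@MODELS.register_module()  # previously: @CONV_LAYERS.register_module()
class MyConv(nn.Conv2d):  # hypothetical example class
    pass

# every brick now lives in the shared MODELS registry
conv = MODELS.build(
    dict(type='MyConv', in_channels=3, out_channels=8, kernel_size=3))
assert isinstance(conv, MyConv)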
# Copyright (c) OpenMMLab. All rights reserved.
from importlib import import_module
import numpy as np
import pytest
import torch
import torch.nn as nn
+from mmengine.registry import MODELS
-from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
-                             PADDING_LAYERS, PLUGIN_LAYERS,
-                             build_activation_layer, build_conv_layer,
+from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer,
                              build_norm_layer, build_padding_layer,
                              build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
@@ -63,18 +64,19 @@ def test_build_conv_layer():
# sparse convs cannot support the case when groups>1
kwargs.pop('groups')
-    for type_name, module in CONV_LAYERS.module_dict.items():
+    for type_name, module in MODELS.module_dict.items():
cfg = dict(type=type_name)
# SparseInverseConv2d and SparseInverseConv3d do not have the argument
# 'dilation'
if type_name == 'SparseInverseConv2d' or type_name == \
'SparseInverseConv3d':
kwargs.pop('dilation')
-        layer = build_conv_layer(cfg, **kwargs)
-        assert isinstance(layer, module)
-        assert layer.in_channels == kwargs['in_channels']
-        assert layer.out_channels == kwargs['out_channels']
-        kwargs['dilation'] = 2  # recover the key
+        if 'conv' in type_name.lower():
+            layer = build_conv_layer(cfg, **kwargs)
+            assert isinstance(layer, module)
+            assert layer.in_channels == kwargs['in_channels']
+            assert layer.out_channels == kwargs['out_channels']
+            kwargs['dilation'] = 2  # recover the key
def test_infer_norm_abbr():
@@ -154,7 +156,9 @@ def test_build_norm_layer():
'IN2d': 'in',
'IN3d': 'in',
}
-    for type_name, module in NORM_LAYERS.module_dict.items():
+    for type_name, module in MODELS.module_dict.items():
+        if type_name not in abbr_mapping:
+            continue
if type_name == 'MMSyncBN': # skip MMSyncBN
continue
for postfix in ['_test', 1]:
@@ -172,6 +176,17 @@ def test_build_norm_layer():
def test_build_activation_layer():
+    act_names = [
+        'ReLU', 'LeakyReLU', 'PReLU', 'RReLU', 'ReLU6', 'ELU', 'Sigmoid',
+        'Tanh'
+    ]
+    for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']:
+        act_module = import_module(f'mmcv.cnn.bricks.{module_name}')
+        for key, value in act_module.__dict__.items():
+            if isinstance(value, type) and issubclass(value, nn.Module):
+                act_names.append(key)
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'ReLU'
@@ -188,10 +203,11 @@ def test_build_activation_layer():
build_activation_layer(cfg)
# test each type of activation layer in activation_cfg
-    for type_name, module in ACTIVATION_LAYERS.module_dict.items():
-        cfg['type'] = type_name
-        layer = build_activation_layer(cfg)
-        assert isinstance(layer, module)
+    for type_name, module in MODELS.module_dict.items():
+        if type_name in act_names:
+            cfg['type'] = type_name
+            layer = build_activation_layer(cfg)
+            assert isinstance(layer, module)
# sanity check for Clamp
act = build_activation_layer(dict(type='Clamp'))
@@ -207,6 +223,13 @@ def test_build_activation_layer():
def test_build_padding_layer():
+    pad_names = ['zero', 'reflect', 'replicate']
+    for module_name in ['padding']:
+        pad_module = import_module(f'mmcv.cnn.bricks.{module_name}')
+        for key, value in pad_module.__dict__.items():
+            if isinstance(value, type) and issubclass(value, nn.Module):
+                pad_names.append(key)
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'reflect'
@@ -222,10 +245,11 @@ def test_build_padding_layer():
cfg = dict(type='FancyPad')
build_padding_layer(cfg)
-    for type_name, module in PADDING_LAYERS.module_dict.items():
-        cfg['type'] = type_name
-        layer = build_padding_layer(cfg, 2)
-        assert isinstance(layer, module)
+    for type_name, module in MODELS.module_dict.items():
+        if type_name in pad_names:
+            cfg['type'] = type_name
+            layer = build_padding_layer(cfg, 2)
+            assert isinstance(layer, module)
input_x = torch.randn(1, 2, 5, 5)
cfg = dict(type='reflect')
@@ -377,22 +401,21 @@ def test_build_plugin_layer():
name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
assert name == 'context_block' + str(postfix)
-        assert isinstance(layer, PLUGIN_LAYERS.module_dict['ContextBlock'])
+        assert isinstance(layer, MODELS.module_dict['ContextBlock'])
# test GeneralizedAttention
for postfix in ['', '_test', 1]:
cfg = dict(type='GeneralizedAttention')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
assert name == 'gen_attention_block' + str(postfix)
-        assert isinstance(layer,
-                          PLUGIN_LAYERS.module_dict['GeneralizedAttention'])
+        assert isinstance(layer, MODELS.module_dict['GeneralizedAttention'])
# test NonLocal2d
for postfix in ['', '_test', 1]:
cfg = dict(type='NonLocal2d')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
assert name == 'nonlocal_block' + str(postfix)
-        assert isinstance(layer, PLUGIN_LAYERS.module_dict['NonLocal2d'])
+        assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
# test ConvModule
for postfix in ['', '_test', 1]:
@@ -404,4 +427,4 @@ def test_build_plugin_layer():
out_channels=4,
kernel_size=3)
assert name == 'conv_block' + str(postfix)
-        assert isinstance(layer, PLUGIN_LAYERS.module_dict['ConvModule'])
+        assert isinstance(layer, MODELS.module_dict['ConvModule'])
@@ -5,12 +5,13 @@ from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
+from mmengine.registry import MODELS
-from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
+from mmcv.cnn.bricks import ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version
-@CONV_LAYERS.register_module()
+@MODELS.register_module()
class ExampleConv(nn.Module):
def __init__(self,
...
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import mmcv
from mmcv.cnn import MODELS, build_model_from_cfg
def test_build_model_from_cfg():
BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg)
@BACKBONES.register_module()
class ResNet(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
@BACKBONES.register_module()
class ResNeXt(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
cfg = dict(type='ResNet', depth=50)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = [
dict(type='ResNet', depth=50),
dict(type='ResNeXt', depth=50, stages=3)
]
model = BACKBONES.build(cfg)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], ResNet)
assert model[0].depth == 50 and model[0].stages == 4
assert isinstance(model[1], ResNeXt)
assert model[1].depth == 50 and model[1].stages == 3
# test inherit `build_func` from parent
NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new')
assert NEW_MODELS.build_func is build_model_from_cfg
# test specify `build_func`
def pseudo_build(cfg):
return cfg
NEW_MODELS = mmcv.Registry(
'models', parent=MODELS, build_func=pseudo_build)
assert NEW_MODELS.build_func is pseudo_build
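# Illustrative sketch (an assumption, not the actual mmcv code): the test
# above relies on `build_model_from_cfg` building an nn.Sequential from a
# list of configs, and on child registries inheriting `build_func` from
# their parent. The contract it exercises looks roughly like this:
def _sketch_build_model_from_cfg(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        # a list of configs yields an nn.Sequential of the built modules
        return nn.Sequential(
            *(_sketch_build_model_from_cfg(c, registry, default_args)
              for c in cfg))
    args = dict(cfg)
    if default_args is not None:
        args.update(default_args)
    # look up the registered class by name and instantiate it
    return registry.get(args.pop('type'))(**args)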
@@ -3,6 +3,7 @@ import copy
import pytest
import torch
+from mmengine.model import ModuleList
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
@@ -10,7 +11,6 @@ from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
MultiheadAttention, PatchEmbed,
PatchMerging,
TransformerLayerSequence)
-from mmcv.runner import ModuleList
def test_adaptive_padding():
...
# Copyright (c) OpenMMLab. All rights reserved.
import random
from tempfile import TemporaryDirectory
import numpy as np
import pytest
import torch
from scipy import stats
from torch import nn
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
PretrainedInit, TruncNormalInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init,
initialize, kaiming_init, normal_init, trunc_normal_init,
uniform_init, xavier_init)
if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)
def test_constant_init():
conv_module = nn.Conv2d(3, 16, 3)
constant_init(conv_module, 0.1)
assert conv_module.weight.allclose(
torch.full_like(conv_module.weight, 0.1))
assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
constant_init(conv_module_no_bias, 0.1)
assert conv_module.weight.allclose(
torch.full_like(conv_module.weight, 0.1))
def test_xavier_init():
conv_module = nn.Conv2d(3, 16, 3)
xavier_init(conv_module, bias=0.1)
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
xavier_init(conv_module, distribution='uniform')
# TODO: sanity check of weight distribution, e.g. mean, std
with pytest.raises(AssertionError):
xavier_init(conv_module, distribution='student-t')
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
xavier_init(conv_module_no_bias)
def test_normal_init():
conv_module = nn.Conv2d(3, 16, 3)
normal_init(conv_module, bias=0.1)
# TODO: sanity check of weight distribution, e.g. mean, std
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
normal_init(conv_module_no_bias)
# TODO: sanity check distribution, e.g. mean, std
def test_trunc_normal_init():
def _random_float(a, b):
return (b - a) * random.random() + a
def _is_trunc_normal(tensor, mean, std, a, b):
# scipy's trunc norm is suited for data drawn from N(0, 1),
# so we need to transform our data to test it using scipy.
z_samples = (tensor.view(-1) - mean) / std
z_samples = z_samples.tolist()
a0 = (a - mean) / std
b0 = (b - mean) / std
p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1]
return p_value > 0.0001
conv_module = nn.Conv2d(3, 16, 3)
mean = _random_float(-3, 3)
std = _random_float(.01, 1)
a = _random_float(mean - 2 * std, mean)
b = _random_float(mean, mean + 2 * std)
trunc_normal_init(conv_module, mean, std, a, b, bias=0.1)
assert _is_trunc_normal(conv_module.weight, mean, std, a, b)
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
trunc_normal_init(conv_module_no_bias)
# TODO: sanity check distribution, e.g. mean, std
def test_uniform_init():
conv_module = nn.Conv2d(3, 16, 3)
uniform_init(conv_module, bias=0.1)
# TODO: sanity check of weight distribution, e.g. mean, std
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
uniform_init(conv_module_no_bias)
def test_kaiming_init():
conv_module = nn.Conv2d(3, 16, 3)
kaiming_init(conv_module, bias=0.1)
# TODO: sanity check of weight distribution, e.g. mean, std
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
kaiming_init(conv_module, distribution='uniform')
with pytest.raises(AssertionError):
kaiming_init(conv_module, distribution='student-t')
conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
kaiming_init(conv_module_no_bias)
def test_caffe_xavier_init():
conv_module = nn.Conv2d(3, 16, 3)
caffe2_xavier_init(conv_module)
def test_bias_init_with_prob():
conv_module = nn.Conv2d(3, 16, 3)
prior_prob = 0.1
normal_init(conv_module, bias=bias_init_with_prob(0.1))
# TODO: sanity check of weight distribution, e.g. mean, std
bias = float(-np.log((1 - prior_prob) / prior_prob))
assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
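# Illustrative note: bias_init_with_prob(p) returns the logit of the prior
# probability, bias = -log((1 - p) / p) = log(p / (1 - p)), so a classifier
# head initialized this way outputs p after a sigmoid. A quick sketch check
# (hypothetical helper, not part of the original suite):
def _sketch_check_bias_init(p=0.1):
    bias = bias_init_with_prob(p)
    assert torch.sigmoid(torch.tensor(bias)).allclose(torch.tensor(p))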
def test_constantinit():
"""test ConstantInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = ConstantInit(val=1, bias=2, layer='Conv2d')
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 1.))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
func = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
func(model)
res = bias_init_with_prob(0.01)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
func(model)
assert torch.all(model[0].weight == 4.)
assert torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == 5.)
assert torch.all(model[2].bias == 5.)
# test bias input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias='1')
# test bias_prob type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias_prob='1')
# test layer input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, layer=1)
def test_xavierinit():
"""test XavierInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = XavierInit(bias=0.1, layer='Conv2d')
func(model)
assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
func = XavierInit(gain=100, bias_prob=0.01, layer=['Conv2d', 'Linear'])
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
res = bias_init_with_prob(0.01)
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = ConstantInit(val=4., bias=5., layer='_ConvNd')
func(model)
assert torch.all(model[0].weight == 4.)
assert torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == 5.)
assert torch.all(model[2].bias == 5.)
func = XavierInit(gain=100, bias_prob=0.01, layer='_ConvNd')
func(model)
assert not torch.all(model[0].weight == 4.)
assert not torch.all(model[2].weight == 4.)
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
# test bias input type
with pytest.raises(TypeError):
func = XavierInit(bias='0.1', layer='Conv2d')
# test layer input type
with pytest.raises(TypeError):
func = XavierInit(bias=0.1, layer=1)
def test_normalinit():
"""test Normalinit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = NormalInit(mean=100, std=1e-5, bias=200, layer=['Conv2d', 'Linear'])
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
assert model[0].bias.allclose(torch.tensor(200.))
assert model[2].bias.allclose(torch.tensor(200.))
func = NormalInit(
mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
res = bias_init_with_prob(0.01)
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = NormalInit(mean=300, std=1e-5, bias_prob=0.01, layer='_ConvNd')
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
def test_truncnormalinit():
"""test TruncNormalInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = TruncNormalInit(
mean=100, std=1e-5, bias=200, a=0, b=200, layer=['Conv2d', 'Linear'])
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
assert model[0].bias.allclose(torch.tensor(200.))
assert model[2].bias.allclose(torch.tensor(200.))
func = TruncNormalInit(
mean=300,
std=1e-5,
a=100,
b=400,
bias_prob=0.01,
layer=['Conv2d', 'Linear'])
res = bias_init_with_prob(0.01)
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = TruncNormalInit(
mean=300, std=1e-5, a=100, b=400, bias_prob=0.01, layer='_ConvNd')
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
def test_uniforminit():
""""test UniformInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = UniformInit(a=1, b=1, bias=2, layer=['Conv2d', 'Linear'])
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape,
100.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape,
100.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = UniformInit(a=100, b=100, bias_prob=0.01, layer='_ConvNd')
res = bias_init_with_prob(0.01)
func(model)
assert torch.all(model[0].weight == 100.)
assert torch.all(model[2].weight == 100.)
assert torch.all(model[0].bias == res)
assert torch.all(model[2].bias == res)
def test_kaiminginit():
"""test KaimingInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = KaimingInit(bias=0.1, layer='Conv2d')
func(model)
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
func = KaimingInit(a=100, bias=10, layer=['Conv2d', 'Linear'])
constant_func = ConstantInit(val=0, bias=0, layer=['Conv2d', 'Linear'])
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
# test layer key with base class name
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Conv1d(1, 2, 1))
func = KaimingInit(bias=0.1, layer='_ConvNd')
func(model)
assert torch.all(model[0].bias == 0.1)
assert torch.all(model[2].bias == 0.1)
func = KaimingInit(a=100, bias=10, layer='_ConvNd')
constant_func = ConstantInit(val=0, bias=0, layer='_ConvNd')
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
def test_caffe2xavierinit():
"""test Caffe2XavierInit."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = Caffe2XavierInit(bias=0.1, layer='Conv2d')
func(model)
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
class FooModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 2)
self.conv2d = nn.Conv2d(3, 1, 3)
self.conv2d_2 = nn.Conv2d(3, 2, 3)
def test_pretrainedinit():
"""test PretrainedInit class."""
modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
modelA.apply(constant_func)
modelB = FooModule()
funcB = PretrainedInit(checkpoint='modelA.pth')
modelC = nn.Linear(1, 2)
funcC = PretrainedInit(checkpoint='modelA.pth', prefix='linear.')
with TemporaryDirectory():
torch.save(modelA.state_dict(), 'modelA.pth')
funcB(modelB)
assert torch.equal(modelB.linear.weight,
torch.full(modelB.linear.weight.shape, 1.))
assert torch.equal(modelB.linear.bias,
torch.full(modelB.linear.bias.shape, 2.))
assert torch.equal(modelB.conv2d.weight,
torch.full(modelB.conv2d.weight.shape, 1.))
assert torch.equal(modelB.conv2d.bias,
torch.full(modelB.conv2d.bias.shape, 2.))
assert torch.equal(modelB.conv2d_2.weight,
torch.full(modelB.conv2d_2.weight.shape, 1.))
assert torch.equal(modelB.conv2d_2.bias,
torch.full(modelB.conv2d_2.bias.shape, 2.))
funcC(modelC)
assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
def test_initialize():
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
foonet = FooModule()
# test layer key
init_cfg = dict(type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
assert init_cfg == dict(
type='Constant', layer=['Conv2d', 'Linear'], val=1, bias=2)
# test init_cfg with list type
init_cfg = [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
assert init_cfg == [
dict(type='Constant', layer='Conv2d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
# test layer key and override key
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)
assert torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 1.))
assert torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 2.))
assert torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 1.))
assert torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 2.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
# test override key
init_cfg = dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
initialize(foonet, init_cfg)
assert not torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 5.))
assert not torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 6.))
assert not torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 5.))
assert not torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 6.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 5.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 6.))
assert init_cfg == dict(
type='Constant', val=5, bias=6, override=dict(name='conv2d_2'))
init_cfg = dict(
type='Pretrained',
checkpoint='modelA.pth',
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2, layer=['Conv2d', 'Linear'])
modelA.apply(constant_func)
with TemporaryDirectory():
torch.save(modelA.state_dict(), 'modelA.pth')
initialize(foonet, init_cfg)
assert torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 1.))
assert torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 2.))
assert torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 1.))
assert torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 2.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
assert init_cfg == dict(
type='Pretrained',
checkpoint='modelA.pth',
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
# test init_cfg type
with pytest.raises(TypeError):
init_cfg = 'init_cfg'
initialize(foonet, init_cfg)
# test override value type
with pytest.raises(TypeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override='conv')
initialize(foonet, init_cfg)
# test override name
with pytest.raises(RuntimeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
initialize(foonet, init_cfg)
# test list override name
with pytest.raises(RuntimeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=[
dict(type='Constant', name='conv2d', val=3, bias=4),
dict(type='Constant', name='conv2d_3', val=5, bias=6)
])
initialize(foonet, init_cfg)
# test override with args except type key
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)
# test override without name
with pytest.raises(ValueError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
override=dict(type='Constant', val=3, bias=4))
initialize(foonet, init_cfg)
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import mmengine
import pytest
import torch
from torch import nn
from mmcv.cnn.utils.weight_init import update_init_info
from mmcv.runner import BaseModule, ModuleDict, ModuleList, Sequential
from mmcv.utils import Registry, build_from_cfg
COMPONENTS = Registry('component')
FOOMODELS = Registry('model')
@COMPONENTS.register_module()
class FooConv1d(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1d = nn.Conv1d(4, 1, 4)
def forward(self, x):
return self.conv1d(x)
@COMPONENTS.register_module()
class FooConv2d(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv2d = nn.Conv2d(3, 1, 3)
def forward(self, x):
return self.conv2d(x)
@COMPONENTS.register_module()
class FooLinear(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.linear = nn.Linear(3, 4)
def forward(self, x):
return self.linear(x)
@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):
def __init__(self, linear=None, conv1d=None, init_cfg=None):
super().__init__(init_cfg)
if linear is not None:
self.linear = build_from_cfg(linear, COMPONENTS)
if conv1d is not None:
self.conv1d = build_from_cfg(conv1d, COMPONENTS)
def forward(self, x):
x = self.linear(x)
return self.conv1d(x)
@FOOMODELS.register_module()
class FooModel(BaseModule):
def __init__(self,
component1=None,
component2=None,
component3=None,
component4=None,
init_cfg=None) -> None:
super().__init__(init_cfg)
if component1 is not None:
self.component1 = build_from_cfg(component1, COMPONENTS)
if component2 is not None:
self.component2 = build_from_cfg(component2, COMPONENTS)
if component3 is not None:
self.component3 = build_from_cfg(component3, COMPONENTS)
if component4 is not None:
self.component4 = build_from_cfg(component4, COMPONENTS)
# its type is not BaseModule, it can be initialized
# with "override" key.
self.reg = nn.Linear(3, 4)
def test_initialization_info_logger():
# 'override' has higher priority
import os
import torch.nn as nn
from mmcv.utils.logging import get_logger
class OverloadInitConv(nn.Conv2d, BaseModule):
def init_weights(self):
for p in self.parameters():
with torch.no_grad():
p.fill_(1)
class CheckLoggerModel(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv2 = OverloadInitConv(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
dict(type='Constant', layer='Linear', val=0., bias=1.)
]
model = CheckLoggerModel(init_cfg=init_cfg)
train_log = '20210720_132454.log'
workdir = tempfile.mkdtemp()
log_file = os.path.join(workdir, train_log)
# create a logger
get_logger('init_logger', log_file=log_file)
assert not hasattr(model, '_params_init_info')
model.init_weights()
# assert `_params_init_info` would be deleted after `init_weights`
assert not hasattr(model, '_params_init_info')
# assert initialization information has been dumped
assert os.path.exists(log_file)
lines = mmengine.list_from_file(log_file)
# check initialization information is right
for i, line in enumerate(lines):
if 'conv1.weight' in line:
assert 'NormalInit' in lines[i + 1]
if 'conv2.weight' in line:
assert 'OverloadInitConv' in lines[i + 1]
if 'fc1.weight' in line:
assert 'ConstantInit' in lines[i + 1]
# test corner case
class OverloadInitConvFc(nn.Conv2d, BaseModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv1 = nn.Linear(1, 1)
def init_weights(self):
for p in self.parameters():
with torch.no_grad():
p.fill_(1)
class CheckLoggerModel(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv2 = OverloadInitConvFc(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
class TopLevelModule(BaseModule):
def __init__(self, init_cfg=None, checklog_init_cfg=None):
super().__init__(init_cfg)
self.module1 = CheckLoggerModel(checklog_init_cfg)
self.module2 = OverloadInitConvFc(1, 1, 1, 1)
checklog_init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv3', std=0.01, bias_prob=0.01)),
dict(type='Constant', layer='Linear', val=0., bias=1.)
]
top_level_init_cfg = [
dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='module2', std=0.01, bias_prob=0.01))
]
model = TopLevelModule(
init_cfg=top_level_init_cfg, checklog_init_cfg=checklog_init_cfg)
model.module1.init_weights()
model.module2.init_weights()
model.init_weights()
model.module1.init_weights()
model.module2.init_weights()
assert not hasattr(model, '_params_init_info')
model.init_weights()
# assert `_params_init_info` would be deleted after `init_weights`
assert not hasattr(model, '_params_init_info')
# assert initialization information has been dumped
assert os.path.exists(log_file)
lines = mmengine.list_from_file(log_file)
# check initialization information is right
for i, line in enumerate(lines):
if 'TopLevelModule' in line and 'init_cfg' not in line:
# have been set init_flag
assert 'the same' in line
def test_update_init_info():
class DummyModel(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1 = nn.Conv2d(1, 1, 1, 1)
self.conv3 = nn.Conv2d(1, 1, 1, 1)
self.fc1 = nn.Linear(1, 1)
model = DummyModel()
from collections import defaultdict
model._params_init_info = defaultdict(dict)
for name, param in model.named_parameters():
model._params_init_info[param]['init_info'] = 'init'
model._params_init_info[param]['tmp_mean_value'] = param.data.mean()
with torch.no_grad():
for p in model.parameters():
p.fill_(1)
update_init_info(model, init_info='fill_1')
for item in model._params_init_info.values():
assert item['init_info'] == 'fill_1'
assert item['tmp_mean_value'] == 1
# test assert for new parameters
model.conv1.bias = nn.Parameter(torch.ones_like(model.conv1.bias))
with pytest.raises(AssertionError):
update_init_info(model, init_info=' ')
def test_model_weight_init():
"""
Config
model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
Conv2d: weight=5, bias=6)
├──component1 (FooConv1d)
├──component2 (FooConv2d)
├──component3 (FooLinear)
├──component4 (FooLinearConv1d)
├──linear (FooLinear)
├──conv1d (FooConv1d)
├──reg (nn.Linear)
Parameters after initialization
model (FooModel)
├──component1 (FooConv1d, weight=3, bias=4)
├──component2 (FooConv2d, weight=5, bias=6)
├──component3 (FooLinear, weight=1, bias=2)
├──component4 (FooLinearConv1d)
├──linear (FooLinear, weight=1, bias=2)
├──conv1d (FooConv1d, weight=3, bias=4)
├──reg (nn.Linear, weight=1, bias=2)
"""
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(type='Constant', val=1, bias=2, layer='Linear'),
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
dict(type='Constant', val=5, bias=6, layer='Conv2d')
],
component1=dict(type='FooConv1d'),
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'),
component4=dict(
type='FooLinearConv1d',
linear=dict(type='FooLinear'),
conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0))
assert torch.equal(model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 4.0))
assert torch.equal(model.component2.conv2d.weight,
torch.full(model.component2.conv2d.weight.shape, 5.0))
assert torch.equal(model.component2.conv2d.bias,
torch.full(model.component2.conv2d.bias.shape, 6.0))
assert torch.equal(model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 1.0))
assert torch.equal(model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.linear.linear.weight,
torch.full(model.component4.linear.linear.weight.shape, 1.0))
assert torch.equal(
model.component4.linear.linear.bias,
torch.full(model.component4.linear.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.conv1d.conv1d.weight,
torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
assert torch.equal(
model.component4.conv1d.conv1d.bias,
torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
1.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))
def test_nest_components_weight_init():
"""
Config
model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
Conv2d: weight=5, bias=6)
├──component1 (FooConv1d, Conv1d: weight=7, bias=8)
├──component2 (FooConv2d, Conv2d: weight=9, bias=10)
├──component3 (FooLinear)
├──component4 (FooLinearConv1d, Linear: weight=11, bias=12)
├──linear (FooLinear, Linear: weight=11, bias=12)
├──conv1d (FooConv1d)
├──reg (nn.Linear, weight=13, bias=14)
Parameters after initialization
model (FooModel)
├──component1 (FooConv1d, weight=7, bias=8)
├──component2 (FooConv2d, weight=9, bias=10)
├──component3 (FooLinear, weight=1, bias=2)
├──component4 (FooLinearConv1d)
├──linear (FooLinear, weight=1, bias=2)
├──conv1d (FooConv1d, weight=3, bias=4)
├──reg (nn.Linear, weight=13, bias=14)
"""
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(
type='Constant',
val=1,
bias=2,
layer='Linear',
override=dict(type='Constant', name='reg', val=13, bias=14)),
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
dict(type='Constant', val=5, bias=6, layer='Conv2d'),
],
component1=dict(
type='FooConv1d',
init_cfg=dict(type='Constant', layer='Conv1d', val=7, bias=8)),
component2=dict(
type='FooConv2d',
init_cfg=dict(type='Constant', layer='Conv2d', val=9, bias=10)),
component3=dict(type='FooLinear'),
component4=dict(
type='FooLinearConv1d',
linear=dict(type='FooLinear'),
conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 7.0))
assert torch.equal(model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 8.0))
assert torch.equal(model.component2.conv2d.weight,
torch.full(model.component2.conv2d.weight.shape, 9.0))
assert torch.equal(model.component2.conv2d.bias,
torch.full(model.component2.conv2d.bias.shape, 10.0))
assert torch.equal(model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 1.0))
assert torch.equal(model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.linear.linear.weight,
torch.full(model.component4.linear.linear.weight.shape, 1.0))
assert torch.equal(
model.component4.linear.linear.bias,
torch.full(model.component4.linear.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.conv1d.conv1d.weight,
torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
assert torch.equal(
model.component4.conv1d.conv1d.bias,
torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 13.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))
def test_without_layer_weight_init():
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(type='Constant', val=1, bias=2, layer='Linear'),
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
dict(type='Constant', val=5, bias=6, layer='Conv2d')
],
component1=dict(
type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0))
assert torch.equal(model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 4.0))
# init_cfg in component1 does not have layer key, so it does nothing
assert torch.equal(model.component2.conv2d.weight,
torch.full(model.component2.conv2d.weight.shape, 5.0))
assert torch.equal(model.component2.conv2d.bias,
torch.full(model.component2.conv2d.bias.shape, 6.0))
assert torch.equal(model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 1.0))
assert torch.equal(model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 2.0))
assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
1.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))
def test_override_weight_init():
# only initialize 'override'
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(type='Constant', val=10, bias=20, override=dict(name='reg'))
],
component1=dict(type='FooConv1d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 10.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 20.0))
# do not initialize others
assert not torch.equal(
model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 10.0))
assert not torch.equal(
model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 20.0))
assert not torch.equal(
model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 10.0))
assert not torch.equal(
model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 20.0))
# 'override' has higher priority
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(
type='Constant',
val=1,
bias=2,
override=dict(name='reg', type='Constant', val=30, bias=40))
],
component1=dict(type='FooConv1d'),
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weights()
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 30.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 40.0))
def test_sequential_model_weight_init():
seq_model_cfg = [
dict(
type='FooConv1d',
init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
dict(
type='FooConv2d',
init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
seq_model = Sequential(*layers)
seq_model.init_weights()
assert torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.))
assert torch.equal(seq_model[0].conv1d.bias,
torch.full(seq_model[0].conv1d.bias.shape, 1.))
assert torch.equal(seq_model[1].conv2d.weight,
torch.full(seq_model[1].conv2d.weight.shape, 2.))
assert torch.equal(seq_model[1].conv2d.bias,
torch.full(seq_model[1].conv2d.bias.shape, 3.))
# inner init_cfg has higher priority
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
seq_model = Sequential(
*layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
seq_model.init_weights()
assert torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.))
assert torch.equal(seq_model[0].conv1d.bias,
torch.full(seq_model[0].conv1d.bias.shape, 1.))
assert torch.equal(seq_model[1].conv2d.weight,
torch.full(seq_model[1].conv2d.weight.shape, 2.))
assert torch.equal(seq_model[1].conv2d.bias,
torch.full(seq_model[1].conv2d.bias.shape, 3.))
def test_modulelist_weight_init():
models_cfg = [
dict(
type='FooConv1d',
init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
dict(
type='FooConv2d',
init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
modellist = ModuleList(layers)
modellist.init_weights()
assert torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.))
assert torch.equal(modellist[0].conv1d.bias,
torch.full(modellist[0].conv1d.bias.shape, 1.))
assert torch.equal(modellist[1].conv2d.weight,
torch.full(modellist[1].conv2d.weight.shape, 2.))
assert torch.equal(modellist[1].conv2d.bias,
torch.full(modellist[1].conv2d.bias.shape, 3.))
# inner init_cfg has higher priority
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
modellist = ModuleList(
layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
modellist.init_weights()
assert torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.))
assert torch.equal(modellist[0].conv1d.bias,
torch.full(modellist[0].conv1d.bias.shape, 1.))
assert torch.equal(modellist[1].conv2d.weight,
torch.full(modellist[1].conv2d.weight.shape, 2.))
assert torch.equal(modellist[1].conv2d.bias,
torch.full(modellist[1].conv2d.bias.shape, 3.))
def test_moduledict_weight_init():
models_cfg = dict(
foo_conv_1d=dict(
type='FooConv1d',
init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
foo_conv_2d=dict(
type='FooConv2d',
init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.)),
)
layers = {
name: build_from_cfg(cfg, COMPONENTS)
for name, cfg in models_cfg.items()
}
modeldict = ModuleDict(layers)
modeldict.init_weights()
assert torch.equal(
modeldict['foo_conv_1d'].conv1d.weight,
torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))
assert torch.equal(
modeldict['foo_conv_1d'].conv1d.bias,
torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))
assert torch.equal(
modeldict['foo_conv_2d'].conv2d.weight,
torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))
assert torch.equal(
modeldict['foo_conv_2d'].conv2d.bias,
torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))
# inner init_cfg has higher priority
layers = {
name: build_from_cfg(cfg, COMPONENTS)
for name, cfg in models_cfg.items()
}
modeldict = ModuleDict(
layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
modeldict.init_weights()
assert torch.equal(
modeldict['foo_conv_1d'].conv1d.weight,
torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.))
assert torch.equal(
modeldict['foo_conv_1d'].conv1d.bias,
torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.))
assert torch.equal(
modeldict['foo_conv_2d'].conv2d.weight,
torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.))
assert torch.equal(
modeldict['foo_conv_2d'].conv2d.bias,
torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.))