Commit 1baf0566 authored by limm's avatar limm
Browse files

Add part of the tests

parent 495d9ed9
Pipeline #2800 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
import math
from copy import deepcopy
from itertools import chain
from unittest import TestCase
import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn
from mmpretrain.models.backbones import HorNet
def check_norm_state(modules, train_state):
    """Return True iff every batch-norm layer matches ``train_state``."""
    return all(mod.training == train_state for mod in modules
               if isinstance(mod, _BatchNorm))
@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.7.0'),
    reason='torch.fft is not available before 1.7.0')
class TestHorNet(TestCase):
    """Tests for the HorNet backbone."""

    def setUp(self):
        # Base config shared by all tests; ``gap_before_final_norm=False``
        # keeps spatial feature maps in the outputs.
        self.cfg = dict(
            arch='t', drop_path_rate=0.1, gap_before_final_norm=False)

    def test_arch(self):
        """Test arch name/dict validation and a fully custom arch dict."""
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            HorNet(**cfg)

        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'depths': [1, 1, 1, 1],
                'orders': [1, 1, 1, 1],
            }
            HorNet(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        base_dim = 64
        depths = [2, 3, 18, 2]
        embed_dims = [base_dim, base_dim * 2, base_dim * 4, base_dim * 8]
        cfg['arch'] = {
            'base_dim':
            base_dim,
            'depths':
            depths,
            'orders': [2, 3, 4, 5],
            'dw_cfg': [
                dict(type='DW', kernel_size=7),
                dict(type='DW', kernel_size=7),
                dict(type='GF', h=14, w=8),
                dict(type='GF', h=7, w=4)
            ],
        }
        model = HorNet(**cfg)
        for i in range(len(depths)):
            stage = model.stages[i]
            self.assertEqual(stage[-1].out_channels, embed_dims[i])
            self.assertEqual(len(stage), depths[i])

    def test_init_weights(self):
        """Test that ``init_cfg`` actually re-initializes the weights."""
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = HorNet(**cfg)
        ori_weight = model.downsample_layers[0][0].weight.clone().detach()
        model.init_weights()
        initialized_weight = model.downsample_layers[0][0].weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))

    def test_forward(self):
        """Test forward with default/multiple out indices and dynamic shapes."""
        imgs = torch.randn(3, 3, 224, 224)

        cfg = deepcopy(self.cfg)
        model = HorNet(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        self.assertEqual(feat.shape, (3, 512, 7, 7))

        # test multiple output indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = (0, 1, 2, 3)
        model = HorNet(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 4)
        for emb_size, stride, out in zip([64, 128, 256, 512], [1, 2, 4, 8],
                                         outs):
            self.assertEqual(out.shape,
                             (3, emb_size, 56 // stride, 56 // stride))

        # test with dynamic input shape
        imgs1 = torch.randn(3, 3, 224, 224)
        imgs2 = torch.randn(3, 3, 256, 256)
        imgs3 = torch.randn(3, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = HorNet(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            feat = outs[-1]
            # The backbone has an overall stride of 32.
            expect_feat_shape = (math.floor(imgs.shape[2] / 32),
                                 math.floor(imgs.shape[3] / 32))
            self.assertEqual(feat.shape, (3, 512, *expect_feat_shape))

    def test_structure(self):
        """Test stochastic-depth decay and stage freezing."""
        # test drop_path_rate decay
        cfg = deepcopy(self.cfg)
        cfg['drop_path_rate'] = 0.2
        model = HorNet(**cfg)
        depths = model.arch_settings['depths']
        stages = model.stages
        blocks = chain(*[stage for stage in stages])
        total_depth = sum(depths)
        # Expected drop-path probabilities increase linearly over the depth.
        dpr = [
            x.item()
            for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
        ]
        for i, (block, expect_prob) in enumerate(zip(blocks, dpr)):
            if expect_prob == 0:
                assert isinstance(block.drop_path, nn.Identity)
            else:
                self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)

        # test HorNet with first stage frozen.
        cfg = deepcopy(self.cfg)
        frozen_stages = 0
        cfg['frozen_stages'] = frozen_stages
        cfg['out_indices'] = (0, 1, 2, 3)
        model = HorNet(**cfg)
        model.init_weights()
        model.train()
        # the patch_embed and first stage should not require grad.
        for i in range(frozen_stages + 1):
            down = model.downsample_layers[i]
            for param in down.parameters():
                self.assertFalse(param.requires_grad)
            blocks = model.stages[i]
            for param in blocks.parameters():
                self.assertFalse(param.requires_grad)
        # the second stage should require grad.
        for i in range(frozen_stages + 1, 4):
            down = model.downsample_layers[i]
            for param in down.parameters():
                self.assertTrue(param.requires_grad)
            blocks = model.stages[i]
            for param in blocks.parameters():
                self.assertTrue(param.requires_grad)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import HRNet
def is_norm(modules):
    """Check whether ``modules`` is a GroupNorm or batch-norm layer."""
    return isinstance(modules, (GroupNorm, _BatchNorm))
def check_norm_state(modules, train_state):
    """Check that all batch-norm layers are in the expected train state."""
    for layer in modules:
        if not isinstance(layer, _BatchNorm):
            continue
        if layer.training != train_state:
            return False
    return True
@pytest.mark.parametrize('base_channels', [18, 30, 32, 40, 44, 48, 64])
def test_hrnet_arch_zoo(base_channels):
    """Smoke-test every predefined HRNet arch with a 224x224 input."""
    cfg_ori = dict(arch=f'w{base_channels}')

    # Test HRNet model with input size of 224
    model = HRNet(**cfg_ori)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), True)

    outs = model(torch.randn(3, 3, 224, 224))
    assert isinstance(outs, tuple)
    # Successive outputs double the channels and halve the resolution.
    channels, size = base_channels, 56
    for out in outs:
        assert out.shape == (3, channels, size, size)
        channels = channels * 2
        size = size // 2
def test_hrnet_custom_arch():
    """Build HRNet from an explicit ``extra`` stage config and check shapes."""
    cfg_ori = dict(
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BOTTLENECK',
                num_blocks=(4, 4, 2),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 3, 4, 4),
                num_channels=(32, 64, 152, 256)),
        ), )

    # Test HRNet model with input size of 224
    model = HRNet(**cfg_ori)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(3, 3, 224, 224)
    outs = model(imgs)
    # One output per branch; resolution halves from branch to branch.
    out_channels = (32, 64, 152, 256)
    out_size = 56
    assert isinstance(outs, tuple)
    for out, out_channel in zip(outs, out_channels):
        assert out.shape == (3, out_channel, out_size, out_size)
        out_size = out_size // 2
# Copyright (c) OpenMMLab. All rights reserved.
from types import MethodType
from unittest import TestCase
import torch
from mmpretrain.models import InceptionV3
from mmpretrain.models.backbones.inception_v3 import InceptionAux
class TestInceptionV3(TestCase):
    """Tests for the InceptionV3 backbone."""

    # Shared keyword arguments for constructing the model under test.
    DEFAULT_ARGS = dict(num_classes=10, aux_logits=False, dropout=0.)

    def test_structure(self):
        """The auxiliary branch should exist only when requested."""
        # Test without auxiliary branch.
        model = InceptionV3(**self.DEFAULT_ARGS)
        self.assertIsNone(model.AuxLogits)

        # Test with auxiliary branch.
        cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
        model = InceptionV3(**cfg)
        self.assertIsInstance(model.AuxLogits, InceptionAux)

    def test_init_weights(self):
        """Check the recorded init info of representative parameters."""
        cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
        model = InceptionV3(**cfg)

        init_info = {}

        # Capture init-info records instead of dumping them to the logger.
        def get_init_info(self, *args):
            for name, param in self.named_parameters():
                init_info[name] = ''.join(
                    self._params_init_info[param]['init_info'])

        model._dump_init_info = MethodType(get_init_info, model)
        model.init_weights()
        self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.1, bias=0',
                      init_info['Conv2d_1a_3x3.conv.weight'])
        self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.01, bias=0',
                      init_info['AuxLogits.conv0.conv.weight'])
        self.assertIn('TruncNormalInit: a=-2, b=2, mean=0, std=0.001, bias=0',
                      init_info['AuxLogits.fc.weight'])

    def test_forward(self):
        """Check forward outputs with and without the auxiliary head."""
        inputs = torch.rand(2, 3, 299, 299)

        model = InceptionV3(**self.DEFAULT_ARGS)
        aux_out, out = model(inputs)
        self.assertIsNone(aux_out)
        self.assertEqual(out.shape, (2, 10))

        cfg = {**self.DEFAULT_ARGS, 'aux_logits': True}
        model = InceptionV3(**cfg)
        aux_out, out = model(inputs)
        self.assertEqual(aux_out.shape, (2, 10))
        self.assertEqual(out.shape, (2, 10))
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import pytest
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from torch import nn
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import levit
from mmpretrain.models.backbones.levit import (Attention, AttentionSubsample,
LeViT)
def check_norm_state(modules, train_state):
    """Check that each batch-norm layer is in the expected train state."""
    return all(m.training == train_state for m in modules
               if isinstance(m, _BatchNorm))
def is_norm(modules):
    """Return whether ``modules`` is a normalization layer."""
    norm_types = (GroupNorm, _BatchNorm)
    return isinstance(modules, norm_types)
def is_levit_block(modules):
    """Return whether ``modules`` is a LeViT attention block."""
    return isinstance(modules, (AttentionSubsample, Attention))
def test_levit_attention():
    """Check Attention forward shape and stored hyper-parameters."""
    attn = Attention(128, 16, 4, 2, act_cfg=dict(type='HSwish'))
    attn.eval()
    inputs = torch.randn(1, 196, 128)
    out = attn(inputs)
    # Attention is shape-preserving on (B, N, C) inputs.
    assert out.shape == inputs.shape
    assert hasattr(attn, 'ab')
    assert attn.key_dim == 16
    assert attn.attn_ratio == 2
    assert attn.num_heads == 4
    assert attn.qkv.linear.in_features == 128
def test_levit():
    """Test LeViT arch validation, attribute wiring, deploy mode and forward."""
    with pytest.raises(TypeError):
        # arch must be str or dict
        LeViT(arch=[4, 6, 16, 1])

    with pytest.raises(AssertionError):
        # arch must in arch_settings
        LeViT(arch='512')

    with pytest.raises(AssertionError):
        # a custom arch dict missing required keys is rejected
        arch = dict(num_blocks=[2, 4, 14, 1])
        LeViT(arch=arch)

    # Test out_indices not type of int or Sequence
    with pytest.raises(TypeError):
        LeViT('128s', out_indices=dict())

    # Test max(out_indices) < len(arch['num_blocks'])
    with pytest.raises(AssertionError):
        LeViT('128s', out_indices=(3, ))

    # Negative out indices are normalized to positive stage indices.
    model = LeViT('128s', out_indices=(-1, ))
    assert model.out_indices == [2]

    model = LeViT(arch='256', drop_path_rate=0.1)
    model.eval()
    assert model.key_dims == [32, 32, 32]
    assert model.embed_dims == [256, 384, 512]
    assert model.num_heads == [4, 6, 8]
    assert model.depths == [4, 4, 4]
    assert model.drop_path_rate == 0.1
    assert isinstance(model.stages[0][0].block.qkv, levit.LinearBatchNorm)
    assert isinstance(model.patch_embed.patch_embed[0],
                      levit.ConvolutionBatchNorm)

    model = LeViT(
        arch='128s',
        hybrid_backbone=lambda embed_dims: nn.Conv2d(
            embed_dims, embed_dims, kernel_size=2))
    model.eval()
    assert isinstance(model.patch_embed, nn.Conv2d)

    # Test eval of "train" mode and "deploy" mode
    model = LeViT(arch='128s', deploy=True)
    model.eval()
    assert not isinstance(model.stages[0][0].block.qkv, levit.LinearBatchNorm)
    assert not isinstance(model.patch_embed.patch_embed[0],
                          levit.ConvolutionBatchNorm)
    assert isinstance(model.stages[0][0].block.qkv, nn.Linear)
    assert isinstance(model.patch_embed.patch_embed[0], nn.Conv2d)

    # Test LeViT forward with layer 2 forward
    model = LeViT('128s', out_indices=(2, ))
    model.init_weights()
    model.train()
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 384, 4, 4))

    # Test LeViT forward
    arch_settings = {
        '128s': dict(out_channels=[128, 256, 384]),
        '128': dict(out_channels=[128, 256, 384]),
        '192': dict(out_channels=[192, 288, 384]),
        '256': dict(out_channels=[256, 384, 512]),
        '384': dict(out_channels=[384, 512, 768])
    }
    choose_models = ['128s', '192', '256', '384']
    # Test LeViT model forward
    for model_name, model_arch in arch_settings.items():
        if model_name not in choose_models:
            continue
        model = LeViT(model_name, out_indices=(0, 1, 2))
        model.init_weights()
        # Test Norm
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        model.train()
        imgs = torch.randn(1, 3, 224, 224)
        feat = model(imgs)
        assert feat[0].shape == torch.Size(
            (1, model_arch['out_channels'][0], 14, 14))
        assert feat[1].shape == torch.Size(
            (1, model_arch['out_channels'][1], 7, 7))
        assert feat[2].shape == torch.Size(
            (1, model_arch['out_channels'][2], 4, 4))
def test_load_deploy_LeViT():
    """Check outputs survive a state-dict round-trip into a deploy model."""
    # Test output before and load from deploy checkpoint
    model = LeViT('128s', out_indices=(0, 1, 2))
    inputs = torch.randn((1, 3, 224, 224))
    model.switch_to_deploy()
    model.eval()
    outputs = model(inputs)
    model_deploy = LeViT('128s', out_indices=(0, 1, 2), deploy=True)
    # Use a private temporary directory instead of a fixed file name in the
    # shared system tmpdir: a fixed path collides between concurrent test
    # runs, and the old code leaked the file whenever an assertion failed
    # before the final os.remove. The context manager guarantees cleanup.
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
        save_checkpoint(model.state_dict(), ckpt_path)
        load_checkpoint(model_deploy, ckpt_path)
    outputs_load = model_deploy(inputs)
    for feat, feat_load in zip(outputs, outputs_load):
        assert torch.allclose(feat, feat_load)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmpretrain.models.backbones import MixMIMTransformer
class TestMixMIM(TestCase):
    """Tests for the MixMIM transformer backbone."""

    def setUp(self):
        # Base config shared by the tests below.
        self.cfg = dict(arch='b', drop_rate=0.0, drop_path_rate=0.1)

    def test_structure(self):
        """Check embed dims, depths and per-layer head/FFN settings."""
        # Test custom arch
        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(sum(model.depths), 24)
        self.assertIsNotNone(model.absolute_pos_embed)
        num_heads = [4, 8, 16, 32]
        for i, layer in enumerate(model.layers):
            self.assertEqual(layer.blocks[0].num_heads, num_heads[i])
            # FFN hidden width is 4x the per-stage embedding dims.
            self.assertEqual(layer.blocks[0].ffn.feedforward_channels,
                             128 * (2**i) * 4)

    def test_forward(self):
        """Check the forward output is a single averaged token."""
        imgs = torch.randn(1, 3, 224, 224)
        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        averaged_token = outs[-1]
        self.assertEqual(averaged_token.shape, (1, 1024))
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import MlpMixer
def is_norm(modules):
    """Check whether ``modules`` is a GroupNorm or batch-norm layer."""
    if isinstance(modules, GroupNorm):
        return True
    return isinstance(modules, _BatchNorm)
def check_norm_state(modules, train_state):
    """Return True when all batch-norm layers match ``train_state``."""
    mismatched = (m for m in modules
                  if isinstance(m, _BatchNorm) and m.training != train_state)
    return next(mismatched, None) is None
class TestMLPMixer(TestCase):
    """Tests for the MLP-Mixer backbone."""

    def setUp(self):
        # Base config shared by all tests.
        self.cfg = dict(
            arch='b',
            img_size=224,
            patch_size=16,
            drop_rate=0.1,
            init_cfg=[
                dict(
                    type='Kaiming',
                    layer='Conv2d',
                    mode='fan_in',
                    nonlinearity='linear')
            ])

    def test_arch(self):
        """Test arch name/dict validation and a custom arch dict."""
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            MlpMixer(**cfg)

        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'embed_dims': 24,
                'num_layers': 16,
                'tokens_mlp_dims': 4096
            }
            MlpMixer(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'embed_dims': 128,
            'num_layers': 6,
            'tokens_mlp_dims': 256,
            'channels_mlp_dims': 1024
        }
        model = MlpMixer(**cfg)
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(model.num_layers, 6)
        for layer in model.layers:
            self.assertEqual(layer.token_mix.feedforward_channels, 256)
            self.assertEqual(layer.channel_mix.feedforward_channels, 1024)

    def test_init_weights(self):
        """Test that ``init_cfg`` re-initializes the patch embedding."""
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = MlpMixer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))

    def test_forward(self):
        """Test forward with single/multi out indices and an invalid size."""
        imgs = torch.randn(1, 3, 224, 224)

        # test forward with single out indices
        cfg = deepcopy(self.cfg)
        model = MlpMixer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        self.assertEqual(feat.shape, (1, 768, 196))

        # test forward with multi out indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = [-3, -2, -1]
        model = MlpMixer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 3)
        for feat in outs:
            self.assertEqual(feat.shape, (1, 768, 196))

        # test with invalid input shape
        imgs2 = torch.randn(1, 3, 256, 256)
        cfg = deepcopy(self.cfg)
        model = MlpMixer(**cfg)
        with self.assertRaisesRegex(AssertionError, 'dynamic input shape.'):
            model(imgs2)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import MobileNetV2
from mmpretrain.models.backbones.mobilenet_v2 import InvertedResidual
def is_block(modules):
    """Check if is ResNet building block."""
    return isinstance(modules, (InvertedResidual, ))
def is_norm(modules):
    """Return whether ``modules`` is one of the norm layers."""
    return isinstance(modules, (GroupNorm, _BatchNorm))
def check_norm_state(modules, train_state):
    """Check that every batch-norm layer is in the given train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm) and mod.training != train_state:
            return False
    return True
def test_mobilenetv2_invertedresidual():
    """Test the MobileNetV2 InvertedResidual block variants."""
    with pytest.raises(AssertionError):
        # stride must be in [1, 2]
        InvertedResidual(16, 24, stride=3, expand_ratio=6)

    # Test InvertedResidual with checkpoint forward, stride=1
    block = InvertedResidual(16, 24, stride=1, expand_ratio=6)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 24, 56, 56))

    # Test InvertedResidual with expand_ratio=1
    block = InvertedResidual(16, 16, stride=1, expand_ratio=1)
    assert len(block.conv) == 2

    # Test InvertedResidual with use_res_connect
    block = InvertedResidual(16, 16, stride=1, expand_ratio=6)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert block.use_res_connect is True
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Test InvertedResidual with checkpoint forward, stride=2
    block = InvertedResidual(16, 24, stride=2, expand_ratio=6)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 24, 28, 28))

    # Test InvertedResidual with checkpoint forward
    block = InvertedResidual(16, 24, stride=1, expand_ratio=6, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 24, 56, 56))

    # Test InvertedResidual with act_cfg=dict(type='ReLU')
    block = InvertedResidual(
        16, 24, stride=1, expand_ratio=6, act_cfg=dict(type='ReLU'))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 24, 56, 56))
def test_mobilenetv2_backbone():
    """Test the MobileNetV2 backbone: validation, freezing, widths, norms."""
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = MobileNetV2()
        model.init_weights(pretrained=0)

    with pytest.raises(ValueError):
        # frozen_stages must in range(-1, 8)
        MobileNetV2(frozen_stages=8)

    with pytest.raises(ValueError):
        # out_indices in range(0, 8)
        MobileNetV2(out_indices=[8])

    # Test MobileNetV2 with first stage frozen
    frozen_stages = 1
    model = MobileNetV2(frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    for mod in model.conv1.modules():
        for param in mod.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # Test MobileNetV2 with norm_eval=True
    model = MobileNetV2(norm_eval=True)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), False)

    # Test MobileNetV2 forward with widen_factor=1.0
    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 8))
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), True)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 8
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))
    assert feat[7].shape == torch.Size((1, 1280, 7, 7))

    # Test MobileNetV2 forward with widen_factor=0.5
    model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 8, 112, 112))
    assert feat[1].shape == torch.Size((1, 16, 56, 56))
    assert feat[2].shape == torch.Size((1, 16, 28, 28))
    assert feat[3].shape == torch.Size((1, 32, 14, 14))
    assert feat[4].shape == torch.Size((1, 48, 14, 14))
    assert feat[5].shape == torch.Size((1, 80, 7, 7))
    assert feat[6].shape == torch.Size((1, 160, 7, 7))

    # Test MobileNetV2 forward with widen_factor=2.0
    model = MobileNetV2(widen_factor=2.0)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size((1, 2560, 7, 7))

    # Test MobileNetV2 forward with out_indices=None
    model = MobileNetV2(widen_factor=1.0)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size((1, 1280, 7, 7))

    # Test MobileNetV2 forward with dict(type='ReLU')
    model = MobileNetV2(
        widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))

    # Test MobileNetV2 with BatchNorm forward
    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))

    # Test MobileNetV2 with GroupNorm forward
    model = MobileNetV2(
        widen_factor=1.0,
        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
        out_indices=range(0, 7))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))

    # Test MobileNetV2 with layers 1, 3, 5 out forward
    model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 32, 28, 28))
    assert feat[2].shape == torch.Size((1, 96, 14, 14))

    # Test MobileNetV2 with checkpoint forward
    model = MobileNetV2(
        widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
    for m in model.modules():
        if is_block(m):
            assert m.with_cp
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 7
    assert feat[0].shape == torch.Size((1, 16, 112, 112))
    assert feat[1].shape == torch.Size((1, 24, 56, 56))
    assert feat[2].shape == torch.Size((1, 32, 28, 28))
    assert feat[3].shape == torch.Size((1, 64, 14, 14))
    assert feat[4].shape == torch.Size((1, 96, 14, 14))
    assert feat[5].shape == torch.Size((1, 160, 7, 7))
    assert feat[6].shape == torch.Size((1, 320, 7, 7))
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import MobileNetV3
from mmpretrain.models.utils import InvertedResidual
def is_norm(modules):
    """Check whether the given module is a normalization layer."""
    return isinstance(modules, (GroupNorm, _BatchNorm))
def check_norm_state(modules, train_state):
    """Check that all batch-norm layers share the expected train state."""
    return all(mod.training == train_state for mod in modules
               if isinstance(mod, _BatchNorm))
def test_mobilenetv3_backbone():
    """Test the MobileNetV3 backbone: validation, freezing, arch variants."""
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = MobileNetV3()
        model.init_weights(pretrained=0)

    with pytest.raises(AssertionError):
        # arch must in [small, large]
        MobileNetV3(arch='others')

    with pytest.raises(ValueError):
        # frozen_stages must less than 13 when arch is small
        MobileNetV3(arch='small', frozen_stages=13)

    with pytest.raises(ValueError):
        # frozen_stages must less than 17 when arch is large
        MobileNetV3(arch='large', frozen_stages=17)

    with pytest.raises(ValueError):
        # max out_indices must less than 13 when arch is small
        MobileNetV3(arch='small', out_indices=(13, ))

    with pytest.raises(ValueError):
        # max out_indices must less than 17 when arch is large
        MobileNetV3(arch='large', out_indices=(17, ))

    # Test MobileNetV3
    model = MobileNetV3()
    model.init_weights()
    model.train()

    # Test MobileNetV3 with first stage frozen
    frozen_stages = 1
    model = MobileNetV3(frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    for i in range(0, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # Test MobileNetV3 with norm eval
    model = MobileNetV3(norm_eval=True, out_indices=range(0, 12))
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), False)

    # Test MobileNetV3 forward with small arch
    model = MobileNetV3(out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 13
    assert feat[0].shape == torch.Size([1, 16, 112, 112])
    assert feat[1].shape == torch.Size([1, 16, 56, 56])
    assert feat[2].shape == torch.Size([1, 24, 28, 28])
    assert feat[3].shape == torch.Size([1, 24, 28, 28])
    assert feat[4].shape == torch.Size([1, 40, 14, 14])
    assert feat[5].shape == torch.Size([1, 40, 14, 14])
    assert feat[6].shape == torch.Size([1, 40, 14, 14])
    assert feat[7].shape == torch.Size([1, 48, 14, 14])
    assert feat[8].shape == torch.Size([1, 48, 14, 14])
    assert feat[9].shape == torch.Size([1, 96, 7, 7])
    assert feat[10].shape == torch.Size([1, 96, 7, 7])
    assert feat[11].shape == torch.Size([1, 96, 7, 7])
    assert feat[12].shape == torch.Size([1, 576, 7, 7])

    # Test MobileNetV3 forward with small arch and GroupNorm
    model = MobileNetV3(
        out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 13
    assert feat[0].shape == torch.Size([1, 16, 112, 112])
    assert feat[1].shape == torch.Size([1, 16, 56, 56])
    assert feat[2].shape == torch.Size([1, 24, 28, 28])
    assert feat[3].shape == torch.Size([1, 24, 28, 28])
    assert feat[4].shape == torch.Size([1, 40, 14, 14])
    assert feat[5].shape == torch.Size([1, 40, 14, 14])
    assert feat[6].shape == torch.Size([1, 40, 14, 14])
    assert feat[7].shape == torch.Size([1, 48, 14, 14])
    assert feat[8].shape == torch.Size([1, 48, 14, 14])
    assert feat[9].shape == torch.Size([1, 96, 7, 7])
    assert feat[10].shape == torch.Size([1, 96, 7, 7])
    assert feat[11].shape == torch.Size([1, 96, 7, 7])
    assert feat[12].shape == torch.Size([1, 576, 7, 7])

    # Test MobileNetV3 forward with large arch
    model = MobileNetV3(
        arch='large',
        out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 17
    assert feat[0].shape == torch.Size([1, 16, 112, 112])
    assert feat[1].shape == torch.Size([1, 16, 112, 112])
    assert feat[2].shape == torch.Size([1, 24, 56, 56])
    assert feat[3].shape == torch.Size([1, 24, 56, 56])
    assert feat[4].shape == torch.Size([1, 40, 28, 28])
    assert feat[5].shape == torch.Size([1, 40, 28, 28])
    assert feat[6].shape == torch.Size([1, 40, 28, 28])
    assert feat[7].shape == torch.Size([1, 80, 14, 14])
    assert feat[8].shape == torch.Size([1, 80, 14, 14])
    assert feat[9].shape == torch.Size([1, 80, 14, 14])
    assert feat[10].shape == torch.Size([1, 80, 14, 14])
    assert feat[11].shape == torch.Size([1, 112, 14, 14])
    assert feat[12].shape == torch.Size([1, 112, 14, 14])
    assert feat[13].shape == torch.Size([1, 160, 7, 7])
    assert feat[14].shape == torch.Size([1, 160, 7, 7])
    assert feat[15].shape == torch.Size([1, 160, 7, 7])
    assert feat[16].shape == torch.Size([1, 960, 7, 7])

    # Test MobileNetV3 forward with large arch
    model = MobileNetV3(arch='large', out_indices=(0, ))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([1, 16, 112, 112])

    # Test MobileNetV3 with checkpoint forward
    model = MobileNetV3(with_cp=True)
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            assert m.with_cp
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([1, 576, 7, 7])
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import pytest
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from torch import nn
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import MobileOne
from mmpretrain.models.backbones.mobileone import MobileOneBlock
from mmpretrain.models.utils import SELayer
def check_norm_state(modules, train_state):
    """Check each batch-norm layer against the expected train state."""
    for layer in modules:
        if not isinstance(layer, _BatchNorm):
            continue
        if layer.training != train_state:
            return False
    return True
def is_norm(modules):
    """Return whether ``modules`` is a GroupNorm or batch-norm layer."""
    norm_classes = (GroupNorm, _BatchNorm)
    return isinstance(modules, norm_classes)
def is_mobileone_block(modules):
    """Return whether ``modules`` is a ``MobileOneBlock``."""
    return isinstance(modules, MobileOneBlock)
def test_mobileoneblock():
    """Test MobileOneBlock in its multi-branch training structure and after
    ``switch_to_deploy`` re-parameterization.

    Every scenario runs a forward pass before and after the switch and checks
    the two outputs agree numerically.
    """
    # Test MobileOneBlock with kernel_size 3
    block = MobileOneBlock(5, 10, 3, 1, stride=1, groups=5)
    block.eval()
    x = torch.randn(1, 5, 16, 16)
    y = block(x)
    # in_channels != out_channels, so no identity/norm branch is built.
    assert block.branch_norm is None
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_scale')
    assert hasattr(block, 'branch_conv_list')
    assert hasattr(block, 'branch_norm')
    assert block.branch_conv_list[0].conv.kernel_size == (3, 3)
    assert block.branch_conv_list[0].conv.groups == 5
    assert block.se_cfg is None
    assert y.shape == torch.Size((1, 10, 16, 16))
    block.switch_to_deploy()
    # After switching, the branches collapse into a single conv.
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (3, 3)
    assert block.branch_reparam.groups == 5
    assert block.deploy is True
    y_deploy = block(x)
    assert y_deploy.shape == torch.Size((1, 10, 16, 16))
    assert torch.allclose(y, y_deploy, atol=1e-5, rtol=1e-4)
    # Test MobileOneBlock with num_conv_branches = 4
    block = MobileOneBlock(5, 10, 3, 4, stride=1, groups=5)
    block.eval()
    x = torch.randn(1, 5, 16, 16)
    y = block(x)
    assert block.branch_norm is None
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_scale')
    assert hasattr(block, 'branch_conv_list')
    assert hasattr(block, 'branch_norm')
    assert block.branch_conv_list[0].conv.kernel_size == (3, 3)
    assert block.branch_conv_list[0].conv.groups == 5
    assert len(block.branch_conv_list) == 4
    assert block.se_cfg is None
    assert y.shape == torch.Size((1, 10, 16, 16))
    block.switch_to_deploy()
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (3, 3)
    assert block.branch_reparam.groups == 5
    assert block.deploy is True
    y_deploy = block(x)
    assert y_deploy.shape == torch.Size((1, 10, 16, 16))
    assert torch.allclose(y, y_deploy, atol=1e-5, rtol=1e-4)
    # Test MobileOneBlock with kernel_size 1
    block = MobileOneBlock(5, 10, 1, 1, stride=1, padding=0)
    block.eval()
    x = torch.randn(1, 5, 16, 16)
    y = block(x)
    assert block.branch_norm is None
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_scale')
    assert hasattr(block, 'branch_conv_list')
    assert hasattr(block, 'branch_norm')
    assert block.branch_conv_list[0].conv.kernel_size == (1, 1)
    assert block.branch_conv_list[0].conv.groups == 1
    assert len(block.branch_conv_list) == 1
    assert block.se_cfg is None
    assert y.shape == torch.Size((1, 10, 16, 16))
    block.switch_to_deploy()
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (1, 1)
    assert block.branch_reparam.groups == 1
    assert block.deploy is True
    y_deploy = block(x)
    assert y_deploy.shape == torch.Size((1, 10, 16, 16))
    assert torch.allclose(y, y_deploy, atol=1e-5, rtol=1e-4)
    # Test MobileOneBlock with stride = 2
    block = MobileOneBlock(10, 10, 3, 4, stride=2, groups=10)
    x = torch.randn(1, 10, 16, 16)
    block.eval()
    y = block(x)
    assert block.branch_norm is None
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_scale')
    assert hasattr(block, 'branch_conv_list')
    assert hasattr(block, 'branch_norm')
    assert block.branch_conv_list[0].conv.kernel_size == (3, 3)
    assert block.branch_conv_list[0].conv.groups == 10
    assert len(block.branch_conv_list) == 4
    assert block.se_cfg is None
    # stride 2 halves the spatial resolution.
    assert y.shape == torch.Size((1, 10, 8, 8))
    block.switch_to_deploy()
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (3, 3)
    assert block.branch_reparam.groups == 10
    assert block.deploy is True
    y_deploy = block(x)
    assert y_deploy.shape == torch.Size((1, 10, 8, 8))
    assert torch.allclose(y, y_deploy, atol=1e-5, rtol=1e-4)
    # Test MobileOneBlock with padding == dilation == 2
    block = MobileOneBlock(
        10, 10, 3, 4, stride=1, groups=10, padding=2, dilation=2)
    x = torch.randn(1, 10, 16, 16)
    block.eval()
    y = block(x)
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_scale')
    assert hasattr(block, 'branch_conv_list')
    assert hasattr(block, 'branch_norm')
    assert block.branch_conv_list[0].conv.kernel_size == (3, 3)
    assert block.branch_conv_list[0].conv.groups == 10
    assert len(block.branch_conv_list) == 4
    assert block.se_cfg is None
    assert y.shape == torch.Size((1, 10, 16, 16))
    block.switch_to_deploy()
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (3, 3)
    assert block.branch_reparam.groups == 10
    assert block.deploy is True
    y_deploy = block(x)
    assert y_deploy.shape == torch.Size((1, 10, 16, 16))
    assert torch.allclose(y, y_deploy, atol=1e-5, rtol=1e-4)
    # Test MobileOneBlock with se
    se_cfg = dict(ratio=4, divisor=1)
    block = MobileOneBlock(32, 32, 3, 4, stride=1, se_cfg=se_cfg, groups=32)
    x = torch.randn(1, 32, 16, 16)
    block.eval()
    y = block(x)
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_scale')
    assert hasattr(block, 'branch_conv_list')
    assert hasattr(block, 'branch_norm')
    assert block.branch_conv_list[0].conv.kernel_size == (3, 3)
    assert block.branch_conv_list[0].conv.groups == 32
    assert len(block.branch_conv_list) == 4
    assert isinstance(block.se, SELayer)
    assert y.shape == torch.Size((1, 32, 16, 16))
    block.switch_to_deploy()
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (3, 3)
    assert block.branch_reparam.groups == 32
    assert block.deploy is True
    y_deploy = block(x)
    assert y_deploy.shape == torch.Size((1, 32, 16, 16))
    assert torch.allclose(y, y_deploy, atol=1e-5, rtol=1e-4)
    # Test MobileOneBlock with deploy == True
    se_cfg = dict(ratio=4, divisor=1)
    block = MobileOneBlock(
        32, 32, 3, 4, stride=1, se_cfg=se_cfg, groups=32, deploy=True)
    x = torch.randn(1, 32, 16, 16)
    block.eval()
    # deploy=True builds the re-parameterized conv directly.
    assert hasattr(block, 'branch_reparam')
    assert block.branch_reparam.kernel_size == (3, 3)
    assert block.branch_reparam.groups == 32
    assert isinstance(block.se, SELayer)
    assert block.deploy is True
    y = block(x)
    assert y.shape == torch.Size((1, 32, 16, 16))
def test_mobileone_backbone():
    """Test the MobileOne backbone.

    Covers arch validation, norm state, frozen stages, ``norm_eval``,
    forward output shapes, and equivalence of "train" and "deploy" modes.
    """
    with pytest.raises(TypeError):
        # arch must be str or dict
        MobileOne(arch=[4, 6, 16, 1])
    with pytest.raises(AssertionError):
        # arch must in arch_settings
        MobileOne(arch='S3')
    with pytest.raises(KeyError):
        # a custom arch dict must contain every required key
        arch = dict(num_blocks=[2, 4, 14, 1])
        MobileOne(arch=arch)
    # Test len(arch['num_blocks']) == len(arch['width_factor'])
    with pytest.raises(AssertionError):
        arch = dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[0.75, 0.75, 0.75],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 5, 1])
        MobileOne(arch=arch)
    # Test out_indices not type of int or Sequence
    # (these two comments were swapped with the next case before)
    with pytest.raises(AssertionError):
        MobileOne('s0', out_indices=dict())
    # Test max(out_indices) < len(arch['num_blocks'])
    with pytest.raises(AssertionError):
        MobileOne('s0', out_indices=(5, ))
    # Test MobileOne norm state
    model = MobileOne('s0')
    model.train()
    assert check_norm_state(model.modules(), True)
    # Test MobileOne with first stage frozen
    frozen_stages = 1
    model = MobileOne('s0', frozen_stages=frozen_stages)
    model.train()
    for param in model.stage0.parameters():
        assert param.requires_grad is False
    for i in range(0, frozen_stages):
        stage_name = model.stages[i]
        stage = getattr(model, stage_name)
        for mod in stage:
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in stage.parameters():
            assert param.requires_grad is False
    # Test MobileOne with norm_eval
    model = MobileOne('s0', norm_eval=True)
    model.train()
    assert check_norm_state(model.modules(), False)
    # Test MobileOne forward with layer 3 forward
    model = MobileOne('s0', out_indices=(3, ))
    model.init_weights()
    model.train()
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 1024, 7, 7))
    # Test MobileOne forward
    arch_settings = {
        's0': dict(out_channels=[48, 128, 256, 1024], ),
        's1': dict(out_channels=[96, 192, 512, 1280]),
        's2': dict(out_channels=[96, 256, 640, 2048]),
        's3': dict(out_channels=[128, 320, 768, 2048], ),
        's4': dict(out_channels=[192, 448, 896, 2048], )
    }
    choose_models = ['s0', 's1', 's4']
    # Test MobileOne model forward (only a subset to keep the test fast)
    for model_name, model_arch in arch_settings.items():
        if model_name not in choose_models:
            continue
        model = MobileOne(model_name, out_indices=(0, 1, 2, 3))
        model.init_weights()
        # Test Norm
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        model.train()
        imgs = torch.randn(1, 3, 224, 224)
        feat = model(imgs)
        assert feat[0].shape == torch.Size(
            (1, model_arch['out_channels'][0], 56, 56))
        assert feat[1].shape == torch.Size(
            (1, model_arch['out_channels'][1], 28, 28))
        assert feat[2].shape == torch.Size(
            (1, model_arch['out_channels'][2], 14, 14))
        assert feat[3].shape == torch.Size(
            (1, model_arch['out_channels'][3], 7, 7))
        # Test eval of "train" mode and "deploy" mode
        gap = nn.AdaptiveAvgPool2d(output_size=(1))
        fc = nn.Linear(model_arch['out_channels'][3], 10)
        model.eval()
        feat = model(imgs)
        pred = fc(gap(feat[3]).flatten(1))
        model.switch_to_deploy()
        for m in model.modules():
            if isinstance(m, MobileOneBlock):
                assert m.deploy is True
        feat_deploy = model(imgs)
        pred_deploy = fc(gap(feat_deploy[3]).flatten(1))
        # The re-parameterized model must reproduce the training-mode
        # outputs; the original calls lacked ``assert`` and were no-ops.
        for i in range(4):
            assert torch.allclose(
                feat[i], feat_deploy[i], atol=1e-5, rtol=1e-4)
        assert torch.allclose(pred, pred_deploy, atol=1e-5, rtol=1e-4)
def test_load_deploy_mobileone():
    """A checkpoint saved from a merged (deploy-mode) model must load into a
    model built with ``deploy=True`` and reproduce the same outputs."""
    # Test output before and load from deploy checkpoint
    model = MobileOne('s0', out_indices=(0, 1, 2, 3))
    inputs = torch.randn((1, 3, 224, 224))
    tmpdir = tempfile.gettempdir()
    ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
    model.switch_to_deploy()
    model.eval()
    outputs = model(inputs)
    # NOTE(review): model_deploy is never set to eval(); presumably
    # deploy-mode blocks contain no train-sensitive layers — confirm.
    model_deploy = MobileOne('s0', out_indices=(0, 1, 2, 3), deploy=True)
    save_checkpoint(model.state_dict(), ckpt_path)
    load_checkpoint(model_deploy, ckpt_path)
    outputs_load = model_deploy(inputs)
    for feat, feat_load in zip(outputs, outputs_load):
        assert torch.allclose(feat, feat_load)
    os.remove(ckpt_path)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmpretrain.models.backbones import MobileViT
def test_assertion():
    """MobileViT must reject unknown arch names and invalid out_indices."""
    with pytest.raises(AssertionError):
        MobileViT(arch='unknown')
    with pytest.raises(AssertionError):
        # MobileViT out_indices should be valid depth.
        MobileViT(out_indices=-100)
def test_mobilevit():
    """Test MobileViT forward shapes for the default arch, a custom arch,
    modified stem/head widths, multi-output mode and frozen stages."""
    # Test forward
    model = MobileViT(arch='small')
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([1, 640, 8, 8])
    # Test custom arch (each row is one stage's settings)
    model = MobileViT(arch=[
        ['mobilenetv2', 16, 1, 1, 2],
        ['mobilenetv2', 24, 2, 3, 2],
        ['mobilevit', 48, 2, 64, 128, 2, 2],
        ['mobilevit', 64, 2, 80, 160, 4, 2],
        ['mobilevit', 80, 2, 96, 192, 3, 2],
    ])
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([1, 320, 8, 8])
    # Test last_exp_factor (scales the final conv's channels)
    model = MobileViT(arch='small', last_exp_factor=8)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([1, 1280, 8, 8])
    # Test stem_channels
    model = MobileViT(arch='small', stem_channels=32)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([1, 640, 8, 8])
    # Test forward with multiple outputs
    model = MobileViT(arch='small', out_indices=range(5))
    imgs = torch.randn(1, 3, 256, 256)
    feat = model(imgs)
    assert len(feat) == 5
    assert feat[0].shape == torch.Size([1, 32, 128, 128])
    assert feat[1].shape == torch.Size([1, 64, 64, 64])
    assert feat[2].shape == torch.Size([1, 96, 32, 32])
    assert feat[3].shape == torch.Size([1, 128, 16, 16])
    assert feat[4].shape == torch.Size([1, 640, 8, 8])
    # Test frozen_stages: the first two stages must stay in eval mode
    model = MobileViT(arch='small', frozen_stages=2)
    model.init_weights()
    model.train()
    for i in range(2):
        assert not model.layers[i].training
    for i in range(2, 5):
        assert model.layers[i].training
# Copyright (c) OpenMMLab. All rights reserved.
import math
from copy import deepcopy
from unittest import TestCase
import torch
from mmpretrain.models import MViT
class TestMViT(TestCase):
    """Unit tests for the MViT backbone."""

    def setUp(self):
        # Default config: tiny arch with stochastic depth enabled.
        self.cfg = dict(arch='tiny', drop_path_rate=0.1)

    def test_structure(self):
        """Test arch validation, custom archs, ``out_scales`` validation and
        the resulting block layout (dims, heads, drop-path rates)."""
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            MViT(**cfg)
        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }
            MViT(**cfg)
        # Test custom arch
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'embed_dims': 96,
            'num_layers': 10,
            'num_heads': 1,
            'downscale_indices': [2, 5, 8]
        }
        stage_indices = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
        model = MViT(**cfg)
        self.assertEqual(model.embed_dims, 96)
        self.assertEqual(model.num_layers, 10)
        for i, block in enumerate(model.blocks):
            stage = stage_indices[i]
            # Embedding dim doubles at each downscale stage.
            self.assertEqual(block.out_dims, 96 * 2**(stage))
        # Test out_scales validation
        cfg = deepcopy(self.cfg)
        cfg['out_scales'] = {1: 1}
        with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
            MViT(**cfg)
        cfg['out_scales'] = [0, 13]
        with self.assertRaisesRegex(AssertionError, 'Invalid out_scales 13'):
            MViT(**cfg)
        # Test model structure
        cfg = deepcopy(self.cfg)
        model = MViT(**cfg)
        stage_indices = [0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3]
        self.assertEqual(len(model.blocks), 10)
        dpr_inc = 0.1 / (10 - 1)
        dpr = 0
        for i, block in enumerate(model.blocks):
            stage = stage_indices[i]
            # Removed a leftover debug ``print(i, stage)`` here.
            self.assertEqual(block.attn.num_heads, 2**stage)
            if dpr > 0:
                self.assertAlmostEqual(block.drop_path.drop_prob, dpr)
            dpr += dpr_inc

    def test_init_weights(self):
        """``init_weights`` must apply ``init_cfg`` and fill the absolute
        position embedding."""
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        cfg['use_abs_pos_embed'] = True
        model = MViT(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The pos_embed is all zero before initialize
        self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))
        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))

    def test_forward(self):
        """Test forward output shapes, multi-scale outputs and dynamic input
        sizes."""
        imgs = torch.randn(1, 3, 224, 224)
        cfg = deepcopy(self.cfg)
        model = MViT(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768, 7, 7))
        # Test forward with multi out scales
        cfg = deepcopy(self.cfg)
        cfg['out_scales'] = (0, 1, 2, 3)
        model = MViT(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 4)
        for stage, out in enumerate(outs):
            stride = 2**stage
            self.assertEqual(out.shape,
                             (1, 96 * stride, 56 // stride, 56 // stride))
        # Test forward with dynamic input size
        imgs1 = torch.randn(1, 3, 224, 224)
        imgs2 = torch.randn(1, 3, 256, 256)
        imgs3 = torch.randn(1, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = MViT(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            patch_token = outs[-1]
            # The backbone's total downsampling factor is 32.
            expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
                                 math.ceil(imgs.shape[3] / 32))
            self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmpretrain.models.backbones import PoolFormer
from mmpretrain.models.backbones.poolformer import PoolFormerBlock
class TestPoolFormer(TestCase):
    """Unit tests for the PoolFormer backbone."""

    def setUp(self):
        # Default config: s12 arch with stochastic depth enabled.
        arch = 's12'
        self.cfg = dict(arch=arch, drop_path_rate=0.1)
        self.arch = PoolFormer.arch_settings[arch]

    def test_arch(self):
        """Test arch-name validation and custom arch dicts."""
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'Unavailable arch'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            PoolFormer(**cfg)
        # Test invalid custom arch
        with self.assertRaisesRegex(AssertionError, 'must have "layers"'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'embed_dims': 96,
                'num_heads': [3, 6, 12, 16],
            }
            PoolFormer(**cfg)
        # Test custom arch
        cfg = deepcopy(self.cfg)
        layers = [2, 2, 4, 2]
        embed_dims = [6, 12, 6, 12]
        mlp_ratios = [2, 3, 4, 4]
        layer_scale_init_value = 1e-4
        cfg['arch'] = dict(
            layers=layers,
            embed_dims=embed_dims,
            mlp_ratios=mlp_ratios,
            layer_scale_init_value=layer_scale_init_value,
        )
        model = PoolFormer(**cfg)
        # model.network interleaves downsample layers with block stages;
        # only check the PoolFormerBlock stages.
        for i, stage in enumerate(model.network):
            if not isinstance(stage, PoolFormerBlock):
                continue
            self.assertEqual(len(stage), layers[i])
            self.assertEqual(stage[0].mlp.fc1.in_channels, embed_dims[i])
            self.assertEqual(stage[0].mlp.fc1.out_channels,
                             embed_dims[i] * mlp_ratios[i])
            self.assertTrue(
                torch.allclose(stage[0].layer_scale_1,
                               torch.tensor(layer_scale_init_value)))
            self.assertTrue(
                torch.allclose(stage[0].layer_scale_2,
                               torch.tensor(layer_scale_init_value)))

    def test_init_weights(self):
        """``init_weights`` must apply the configured ``init_cfg``."""
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = PoolFormer(**cfg)
        ori_weight = model.patch_embed.proj.weight.clone().detach()
        model.init_weights()
        initialized_weight = model.patch_embed.proj.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))

    def test_forward(self):
        """Test forward output shapes for single and multiple out_indices."""
        imgs = torch.randn(1, 3, 224, 224)
        cfg = deepcopy(self.cfg)
        model = PoolFormer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        feat = outs[-1]
        self.assertEqual(feat.shape, (1, 512, 7, 7))
        # test multiple output indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = (0, 2, 4, 6)
        model = PoolFormer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 4)
        for dim, stride, out in zip(self.arch['embed_dims'], [1, 2, 4, 8],
                                    outs):
            self.assertEqual(out.shape, (1, dim, 56 // stride, 56 // stride))

    def test_structure(self):
        """Test drop-path-rate decay across blocks and frozen stages."""
        # test drop_path_rate decay
        cfg = deepcopy(self.cfg)
        cfg['drop_path_rate'] = 0.2
        model = PoolFormer(**cfg)
        layers = self.arch['layers']
        for i, block in enumerate(model.network):
            # Drop-path probability increases linearly with block index.
            expect_prob = 0.2 / (sum(layers) - 1) * i
            if hasattr(block, 'drop_path'):
                if expect_prob == 0:
                    self.assertIsInstance(block.drop_path, torch.nn.Identity)
                else:
                    self.assertAlmostEqual(block.drop_path.drop_prob,
                                           expect_prob)
        # test with first stage frozen.
        cfg = deepcopy(self.cfg)
        frozen_stages = 1
        cfg['frozen_stages'] = frozen_stages
        cfg['out_indices'] = (0, 2, 4, 6)
        model = PoolFormer(**cfg)
        model.init_weights()
        model.train()
        # the patch_embed and first stage should not require grad.
        self.assertFalse(model.patch_embed.training)
        for param in model.patch_embed.parameters():
            self.assertFalse(param.requires_grad)
        for i in range(frozen_stages):
            module = model.network[i]
            for param in module.parameters():
                self.assertFalse(param.requires_grad)
        for param in model.norm0.parameters():
            self.assertFalse(param.requires_grad)
        # the second stage should require grad.
        for i in range(frozen_stages + 1, 7):
            module = model.network[i]
            for param in module.parameters():
                self.assertTrue(param.requires_grad)
            if hasattr(model, f'norm{i}'):
                norm = getattr(model, f'norm{i}')
                for param in norm.parameters():
                    self.assertTrue(param.requires_grad)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmpretrain.models.backbones import RegNet
# Each entry is (arch_name, arch_parameters, expected per-stage output
# channels); used to parametrize the RegNet tests below.
regnet_test_data = [
    ('regnetx_400mf',
     dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
          bot_mul=1.0), [32, 64, 160, 384]),
    ('regnetx_800mf',
     dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
          bot_mul=1.0), [64, 128, 288, 672]),
    ('regnetx_1.6gf',
     dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
          bot_mul=1.0), [72, 168, 408, 912]),
    ('regnetx_3.2gf',
     dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
          bot_mul=1.0), [96, 192, 432, 1008]),
    ('regnetx_4.0gf',
     dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
          bot_mul=1.0), [80, 240, 560, 1360]),
    ('regnetx_6.4gf',
     dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
          bot_mul=1.0), [168, 392, 784, 1624]),
    ('regnetx_8.0gf',
     dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
          bot_mul=1.0), [80, 240, 720, 1920]),
    ('regnetx_12gf',
     dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
          bot_mul=1.0), [224, 448, 896, 2240]),
]
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
def test_regnet_backbone(arch_name, arch, out_channels):
    """Test RegNet forward output shapes for each predefined arch name."""
    with pytest.raises(AssertionError):
        # an arch name with a bogus suffix is not a predefined RegNet arch
        RegNet(arch_name + '233')
    # output the last feature map
    model = RegNet(arch_name)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == (1, out_channels[-1], 7, 7)
    # output feature map of all stages
    model = RegNet(arch_name, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, out_channels[0], 56, 56)
    assert feat[1].shape == (1, out_channels[1], 28, 28)
    assert feat[2].shape == (1, out_channels[2], 14, 14)
    assert feat[3].shape == (1, out_channels[3], 7, 7)
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
def test_custom_arch(arch_name, arch, out_channels):
    """Test RegNet built from an arch parameter dict instead of a name."""
    # output the last feature map
    model = RegNet(arch)
    model.init_weights()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == (1, out_channels[-1], 7, 7)
    # output feature map of all stages
    model = RegNet(arch, out_indices=(0, 1, 2, 3))
    model.init_weights()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, out_channels[0], 56, 56)
    assert feat[1].shape == (1, out_channels[1], 28, 28)
    assert feat[2].shape == (1, out_channels[2], 14, 14)
    assert feat[3].shape == (1, out_channels[3], 7, 7)
def test_exception():
    """RegNet must reject an ``arch`` that is neither a str nor a dict."""
    with pytest.raises(TypeError):
        RegNet(50)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import pytest
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from torch import nn
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import RepLKNet
from mmpretrain.models.backbones.replknet import ReparamLargeKernelConv
def check_norm_state(modules, train_state):
    """Check that every batch-norm layer in ``modules`` is in the expected
    train/eval state; return False on the first mismatch."""
    for layer in modules:
        if not isinstance(layer, _BatchNorm):
            continue
        if layer.training != train_state:
            return False
    return True
def is_norm(modules):
    """Whether ``modules`` is a GroupNorm or a batch-norm variant."""
    norm_types = (GroupNorm, _BatchNorm)
    return isinstance(modules, norm_types)
def is_replk_block(modules):
    """Return True when ``modules`` is a ``ReparamLargeKernelConv``."""
    return isinstance(modules, ReparamLargeKernelConv)
def _check_replk_block_reparam(in_channels, out_channels, kernel_size, stride,
                               small_kernel, input_size):
    """Build a training-mode ReparamLargeKernelConv (groups=in_channels) and
    verify that merging the small kernel preserves the forward output."""
    block = ReparamLargeKernelConv(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        groups=in_channels,
        small_kernel=small_kernel)
    block.eval()
    x = torch.randn(1, in_channels, input_size, input_size)
    x_out_not_deploy = block(x)
    out_size = input_size // stride
    assert block.small_kernel <= block.kernel_size
    # Before merging there is no re-parameterized conv, only the two branches.
    assert not hasattr(block, 'lkb_reparam')
    assert hasattr(block, 'lkb_origin')
    assert hasattr(block, 'small_conv')
    assert x_out_not_deploy.shape == torch.Size(
        (1, out_channels, out_size, out_size))
    block.merge_kernel()
    assert block.small_kernel_merged is True
    x_out_deploy = block(x)
    assert x_out_deploy.shape == torch.Size(
        (1, out_channels, out_size, out_size))
    # Merged (deploy) forward must match the training-mode forward.
    assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)


def test_replknet_replkblock():
    """Test ReparamLargeKernelConv before/after kernel merging.

    The five scenarios previously duplicated inline are driven through one
    parameterized helper.
    """
    # in_channels != out_channels, kernel_size=31, stride=1, small_kernel=5
    _check_replk_block_reparam(5, 10, 31, 1, 5, 64)
    # in_channels == out_channels, kernel_size=31, stride=1, small_kernel=5
    _check_replk_block_reparam(12, 12, 31, 1, 5, 64)
    # stride=2 halves the spatial resolution
    _check_replk_block_reparam(16, 16, 31, 2, 5, 64)
    # smaller large kernel (27) with a matching input size
    _check_replk_block_reparam(12, 12, 27, 1, 5, 48)
    # larger small kernel (7)
    _check_replk_block_reparam(12, 12, 31, 1, 7, 64)
    # Test ReparamLargeKernelConv with small_kernel_merged == True:
    # the re-parameterized conv is built directly, no branch attributes.
    block = ReparamLargeKernelConv(
        8,
        8,
        kernel_size=31,
        stride=1,
        groups=8,
        small_kernel=5,
        small_kernel_merged=True)
    assert isinstance(block.lkb_reparam, nn.Conv2d)
    assert not hasattr(block, 'lkb_origin')
    assert not hasattr(block, 'small_conv')
    x = torch.randn(1, 8, 48, 48)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 8, 48, 48))
def test_replknet_backbone():
    """Test the RepLKNet backbone.

    Covers arch validation, norm state, frozen stages, ``norm_eval``,
    forward output shapes, and equivalence of "train" and "deploy" modes.
    """
    with pytest.raises(TypeError):
        # arch must be str or dict
        RepLKNet(arch=[4, 6, 16, 1])
    with pytest.raises(AssertionError):
        # arch must in arch_settings
        RepLKNet(arch='31C')
    with pytest.raises(KeyError):
        # a custom arch dict must contain every required key
        arch = dict(large_kernel_sizes=[31, 29, 27, 13])
        RepLKNet(arch=arch)
    with pytest.raises(KeyError):
        # still missing required keys
        arch = dict(large_kernel_sizes=[31, 29, 27, 13], layers=[2, 2, 18, 2])
        RepLKNet(arch=arch)
    with pytest.raises(KeyError):
        # still missing required keys
        arch = dict(
            large_kernel_sizes=[31, 29, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[128, 256, 512, 1024])
        RepLKNet(arch=arch)
    # len(arch['large_kernel_sizes']) == arch['layers'])
    # == len(arch['channels'])
    # == len(strides) == len(dilations)
    with pytest.raises(AssertionError):
        arch = dict(
            large_kernel_sizes=[31, 29, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[128, 256, 1024],
            small_kernel=5,
            dw_ratio=1)
        RepLKNet(arch=arch)
    # len(strides) must equal to 4
    with pytest.raises(AssertionError):
        RepLKNet('31B', strides=(2, 2, 2))
    # len(dilations) must equal to 4
    with pytest.raises(AssertionError):
        RepLKNet('31B', strides=(2, 2, 2, 2), dilations=(1, 1, 1))
    # max(out_indices) < len(arch['num_blocks'])
    with pytest.raises(AssertionError):
        RepLKNet('31B', out_indices=(5, ))
    # Test RepLKNet norm state
    model = RepLKNet('31B')
    model.train()
    assert check_norm_state(model.modules(), True)
    # Test RepLKNet with first stage frozen
    frozen_stages = 1
    model = RepLKNet('31B', frozen_stages=frozen_stages)
    model.train()
    for param in model.stem.parameters():
        assert param.requires_grad is False
    for i in range(0, frozen_stages):
        stage = model.stages[i]
        for mod in stage.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in stage.parameters():
            assert param.requires_grad is False
    # Test RepLKNet with norm_eval
    model = RepLKNet('31B', norm_eval=True)
    model.train()
    assert check_norm_state(model.modules(), False)
    # Test RepLKNet forward with layer 3 forward
    model = RepLKNet('31B', out_indices=(3, ))
    model.init_weights()
    model.train()
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 1024, 7, 7))
    # Test RepLKNet forward
    model_test_settings = [
        dict(model_name='31B', out_sizes=(128, 256, 512, 1024)),
        # dict(model_name='31L', out_sizes=(192, 384, 768, 1536)),
        # dict(model_name='XL', out_sizes=(256, 512, 1024, 2048))
    ]
    choose_models = ['31B']
    # Test RepLKNet model forward
    for model_test_setting in model_test_settings:
        if model_test_setting['model_name'] not in choose_models:
            continue
        model = RepLKNet(
            model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
        model.init_weights()
        # Test Norm
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        model.train()
        imgs = torch.randn(1, 3, 224, 224)
        feat = model(imgs)
        assert feat[0].shape == torch.Size(
            (1, model_test_setting['out_sizes'][0], 56, 56))
        assert feat[1].shape == torch.Size(
            (1, model_test_setting['out_sizes'][1], 28, 28))
        assert feat[2].shape == torch.Size(
            (1, model_test_setting['out_sizes'][2], 14, 14))
        assert feat[3].shape == torch.Size(
            (1, model_test_setting['out_sizes'][3], 7, 7))
        # Test eval of "train" mode and "deploy" mode
        gap = nn.AdaptiveAvgPool2d(output_size=(1))
        fc = nn.Linear(model_test_setting['out_sizes'][3], 10)
        model.eval()
        feat = model(imgs)
        pred = fc(gap(feat[3]).flatten(1))
        model.switch_to_deploy()
        for m in model.modules():
            if isinstance(m, ReparamLargeKernelConv):
                assert m.small_kernel_merged is True
        feat_deploy = model(imgs)
        pred_deploy = fc(gap(feat_deploy[3]).flatten(1))
        # The merged model must reproduce the training-mode outputs;
        # the original calls lacked ``assert`` and were no-ops.
        for i in range(4):
            assert torch.allclose(
                feat[i], feat_deploy[i], atol=1e-5, rtol=1e-4)
        assert torch.allclose(pred, pred_deploy, atol=1e-5, rtol=1e-4)
def test_replknet_load():
    """A checkpoint saved from a merged model must load (strictly) into a
    model built with ``small_kernel_merged=True`` and give identical
    outputs."""
    # Test output before and load from deploy checkpoint
    model = RepLKNet('31B', out_indices=(0, 1, 2, 3))
    inputs = torch.randn((1, 3, 224, 224))
    ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
    model.switch_to_deploy()
    model.eval()
    outputs = model(inputs)
    model_deploy = RepLKNet(
        '31B', out_indices=(0, 1, 2, 3), small_kernel_merged=True)
    model_deploy.eval()
    save_checkpoint(model.state_dict(), ckpt_path)
    load_checkpoint(model_deploy, ckpt_path, strict=True)
    outputs_load = model_deploy(inputs)
    for feat, feat_load in zip(outputs, outputs_load):
        assert torch.allclose(feat, feat_load)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from copy import deepcopy
from unittest import TestCase
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from mmpretrain.models.backbones import RepMLPNet
class TestRepMLP(TestCase):
    def setUp(self):
        """Prepare the default RepMLPNet config, forward settings and a
        temporary checkpoint path shared by the tests."""
        # default model setting
        self.cfg = dict(
            arch='b',
            img_size=224,
            out_indices=(3, ),
            reparam_conv_kernels=(1, 3),
            final_norm=True)
        # default model setting and output stage channels
        self.model_forward_settings = [
            dict(model_name='B', out_sizes=(96, 192, 384, 768)),
        ]
        # temp ckpt path
        self.ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
def test_arch(self):
# Test invalid arch data type
with self.assertRaisesRegex(AssertionError, 'arch needs a dict'):
cfg = deepcopy(self.cfg)
cfg['arch'] = [96, 192, 384, 768]
RepMLPNet(**cfg)
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'A'
RepMLPNet(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'channels': [96, 192, 384, 768],
'depths': [2, 2, 12, 2]
}
RepMLPNet(**cfg)
# test len(arch['depths']) equals to len(arch['channels'])
# equals to len(arch['sharesets_nums'])
with self.assertRaisesRegex(AssertionError, 'Length of setting'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'channels': [96, 192, 384, 768],
'depths': [2, 2, 12, 2],
'sharesets_nums': [1, 4, 32]
}
RepMLPNet(**cfg)
# Test custom arch
cfg = deepcopy(self.cfg)
channels = [96, 192, 384, 768]
depths = [2, 2, 12, 2]
sharesets_nums = [1, 4, 32, 128]
cfg['arch'] = {
'channels': channels,
'depths': depths,
'sharesets_nums': sharesets_nums
}
cfg['out_indices'] = (0, 1, 2, 3)
model = RepMLPNet(**cfg)
for i, stage in enumerate(model.stages):
self.assertEqual(len(stage), depths[i])
self.assertEqual(stage[0].repmlp_block.channels, channels[i])
self.assertEqual(stage[0].repmlp_block.deploy, False)
self.assertEqual(stage[0].repmlp_block.num_sharesets,
sharesets_nums[i])
def test_init(self):
# test weight init cfg
cfg = deepcopy(self.cfg)
cfg['init_cfg'] = [
dict(
type='Kaiming',
layer='Conv2d',
mode='fan_in',
nonlinearity='linear')
]
model = RepMLPNet(**cfg)
ori_weight = model.patch_embed.projection.weight.clone().detach()
model.init_weights()
initialized_weight = model.patch_embed.projection.weight
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
def test_forward(self):
imgs = torch.randn(1, 3, 224, 224)
cfg = deepcopy(self.cfg)
model = RepMLPNet(**cfg)
feat = model(imgs)
self.assertTrue(isinstance(feat, tuple))
self.assertEqual(len(feat), 1)
self.assertTrue(isinstance(feat[0], torch.Tensor))
self.assertEqual(feat[0].shape, torch.Size((1, 768, 7, 7)))
imgs = torch.randn(1, 3, 256, 256)
with self.assertRaisesRegex(AssertionError, "doesn't support dynamic"):
model(imgs)
# Test RepMLPNet model forward
for model_test_setting in self.model_forward_settings:
model = RepMLPNet(
model_test_setting['model_name'],
out_indices=(0, 1, 2, 3),
final_norm=False)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
self.assertEqual(
feat[0].shape,
torch.Size((1, model_test_setting['out_sizes'][1], 28, 28)))
self.assertEqual(
feat[1].shape,
torch.Size((1, model_test_setting['out_sizes'][2], 14, 14)))
self.assertEqual(
feat[2].shape,
torch.Size((1, model_test_setting['out_sizes'][3], 7, 7)))
self.assertEqual(
feat[3].shape,
torch.Size((1, model_test_setting['out_sizes'][3], 7, 7)))
def test_deploy_(self):
# Test output before and load from deploy checkpoint
imgs = torch.randn((1, 3, 224, 224))
cfg = dict(
arch='b', out_indices=(
1,
3,
), reparam_conv_kernels=(1, 3, 5))
model = RepMLPNet(**cfg)
model.eval()
feats = model(imgs)
model.switch_to_deploy()
for m in model.modules():
if hasattr(m, 'deploy'):
self.assertTrue(m.deploy)
model.eval()
feats_ = model(imgs)
assert len(feats) == len(feats_)
for i in range(len(feats)):
self.assertTrue(
torch.allclose(
feats[i].sum(), feats_[i].sum(), rtol=0.1, atol=0.1))
cfg['deploy'] = True
model_deploy = RepMLPNet(**cfg)
model_deploy.eval()
save_checkpoint(model.state_dict(), self.ckpt_path)
load_checkpoint(model_deploy, self.ckpt_path, strict=True)
feats__ = model_deploy(imgs)
assert len(feats_) == len(feats__)
for i in range(len(feats)):
self.assertTrue(
torch.allclose(feats__[i], feats_[i], rtol=0.01, atol=0.01))
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import pytest
import torch
from mmengine.runner import load_checkpoint, save_checkpoint
from torch import nn
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import RepVGG
from mmpretrain.models.backbones.repvgg import RepVGGBlock
from mmpretrain.models.utils import SELayer
def check_norm_state(modules, train_state):
    """Return True iff every batch-norm layer in ``modules`` has
    ``training == train_state`` (non-norm modules are ignored)."""
    return all(mod.training == train_state for mod in modules
               if isinstance(mod, _BatchNorm))
def is_norm(modules):
    """Check if is one of the norms."""
    return isinstance(modules, (GroupNorm, _BatchNorm))
def is_repvgg_block(modules):
    """Return True when the given module is a ``RepVGGBlock``."""
    return isinstance(modules, RepVGGBlock)
def test_repvgg_repvggblock():
    """RepVGGBlock unit test.

    For several configurations (channel mismatch, stride, dilation, groups,
    SE layer, gradient checkpointing, deploy construction) it checks the
    branch structure, the forward output shape, and that switching to deploy
    mode reproduces the multi-branch output within tolerance.
    """
    # Test RepVGGBlock with in_channels != out_channels, stride = 1
    block = RepVGGBlock(5, 10, stride=1)
    block.eval()
    x = torch.randn(1, 5, 16, 16)
    x_out_not_deploy = block(x)
    # the identity (norm) branch is absent when channels differ
    assert block.branch_norm is None
    assert not hasattr(block, 'branch_reparam')
    assert hasattr(block, 'branch_1x1')
    assert hasattr(block, 'branch_3x3')
    assert hasattr(block, 'branch_norm')
    assert block.se_cfg is None
    assert x_out_not_deploy.shape == torch.Size((1, 10, 16, 16))
    block.switch_to_deploy()
    assert block.deploy is True
    x_out_deploy = block(x)
    assert x_out_deploy.shape == torch.Size((1, 10, 16, 16))
    assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
    # Test RepVGGBlock with in_channels == out_channels, stride = 1
    block = RepVGGBlock(12, 12, stride=1)
    block.eval()
    x = torch.randn(1, 12, 8, 8)
    x_out_not_deploy = block(x)
    assert isinstance(block.branch_norm, nn.BatchNorm2d)
    assert not hasattr(block, 'branch_reparam')
    assert x_out_not_deploy.shape == torch.Size((1, 12, 8, 8))
    block.switch_to_deploy()
    assert block.deploy is True
    x_out_deploy = block(x)
    assert x_out_deploy.shape == torch.Size((1, 12, 8, 8))
    assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
    # Test RepVGGBlock with in_channels == out_channels, stride = 2
    block = RepVGGBlock(16, 16, stride=2)
    block.eval()
    x = torch.randn(1, 16, 8, 8)
    x_out_not_deploy = block(x)
    assert block.branch_norm is None
    assert x_out_not_deploy.shape == torch.Size((1, 16, 4, 4))
    block.switch_to_deploy()
    assert block.deploy is True
    x_out_deploy = block(x)
    assert x_out_deploy.shape == torch.Size((1, 16, 4, 4))
    assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
    # Test RepVGGBlock with padding == dilation == 2
    block = RepVGGBlock(14, 14, stride=1, padding=2, dilation=2)
    block.eval()
    x = torch.randn(1, 14, 16, 16)
    x_out_not_deploy = block(x)
    assert isinstance(block.branch_norm, nn.BatchNorm2d)
    assert x_out_not_deploy.shape == torch.Size((1, 14, 16, 16))
    block.switch_to_deploy()
    assert block.deploy is True
    x_out_deploy = block(x)
    assert x_out_deploy.shape == torch.Size((1, 14, 16, 16))
    assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
    # Test RepVGGBlock with groups = 2
    block = RepVGGBlock(4, 4, stride=1, groups=2)
    block.eval()
    x = torch.randn(1, 4, 5, 6)
    x_out_not_deploy = block(x)
    assert x_out_not_deploy.shape == torch.Size((1, 4, 5, 6))
    block.switch_to_deploy()
    assert block.deploy is True
    x_out_deploy = block(x)
    assert x_out_deploy.shape == torch.Size((1, 4, 5, 6))
    assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
    # Test RepVGGBlock with se
    se_cfg = dict(ratio=4, divisor=1)
    block = RepVGGBlock(18, 18, stride=1, se_cfg=se_cfg)
    block.train()
    x = torch.randn(1, 18, 5, 5)
    x_out_not_deploy = block(x)
    assert isinstance(block.se_layer, SELayer)
    assert x_out_not_deploy.shape == torch.Size((1, 18, 5, 5))
    # Test RepVGGBlock with checkpoint forward
    block = RepVGGBlock(24, 24, stride=1, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 24, 7, 7)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 24, 7, 7))
    # Test RepVGGBlock with deploy == True
    block = RepVGGBlock(8, 8, stride=1, deploy=True)
    assert isinstance(block.branch_reparam, nn.Conv2d)
    assert not hasattr(block, 'branch_3x3')
    assert not hasattr(block, 'branch_1x1')
    assert not hasattr(block, 'branch_norm')
    x = torch.randn(1, 8, 16, 16)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 8, 16, 16))
def test_repvgg_backbone():
    """RepVGG backbone unit test.

    Covers constructor validation, norm/frozen-stage state, forward shapes
    for default and custom archs, train/deploy output equivalence, and the
    ``add_ppf`` and missing-``stem_channels`` configurations.
    """
    with pytest.raises(TypeError):
        # arch must be str or dict
        RepVGG(arch=[4, 6, 16, 1])
    with pytest.raises(AssertionError):
        # arch must in arch_settings
        RepVGG(arch='A3')
    with pytest.raises(KeyError):
        # arch must have num_blocks and width_factor
        arch = dict(num_blocks=[2, 4, 14, 1])
        RepVGG(arch=arch)
    # len(arch['num_blocks']) == len(arch['width_factor'])
    # == len(strides) == len(dilations)
    with pytest.raises(AssertionError):
        arch = dict(num_blocks=[2, 4, 14, 1], width_factor=[0.75, 0.75, 0.75])
        RepVGG(arch=arch)
    # len(strides) must equal to 4
    with pytest.raises(AssertionError):
        RepVGG('A0', strides=(1, 1, 1))
    # len(dilations) must equal to 4
    with pytest.raises(AssertionError):
        RepVGG('A0', strides=(1, 1, 1, 1), dilations=(1, 1, 2))
    # max(out_indices) < len(arch['num_blocks'])
    with pytest.raises(AssertionError):
        RepVGG('A0', out_indices=(5, ))
    # max(arch['group_idx'].keys()) <= sum(arch['num_blocks'])
    with pytest.raises(AssertionError):
        arch = dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[0.75, 0.75, 0.75],
            group_idx={22: 2})
        RepVGG(arch=arch)
    # Test RepVGG norm state
    model = RepVGG('A0')
    model.train()
    assert check_norm_state(model.modules(), True)
    # Test RepVGG with first stage frozen
    frozen_stages = 1
    model = RepVGG('A0', frozen_stages=frozen_stages)
    model.train()
    for param in model.stem.parameters():
        assert param.requires_grad is False
    for i in range(0, frozen_stages):
        stage_name = model.stages[i]
        stage = model.__getattr__(stage_name)
        for mod in stage:
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in stage.parameters():
            assert param.requires_grad is False
    # Test RepVGG with norm_eval
    model = RepVGG('A0', norm_eval=True)
    model.train()
    assert check_norm_state(model.modules(), False)
    # Test RepVGG forward with layer 3 forward
    model = RepVGG('A0', out_indices=(3, ))
    model.init_weights()
    model.eval()
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    imgs = torch.randn(1, 3, 32, 32)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 1280, 1, 1))
    # Test with custom arch
    cfg = dict(
        num_blocks=[3, 5, 7, 3],
        width_factor=[1, 1, 1, 1],
        group_layer_map=None,
        se_cfg=None,
        stem_channels=16)
    model = RepVGG(arch=cfg, out_indices=(3, ))
    model.eval()
    assert model.stem.out_channels == min(16, 64 * 1)
    imgs = torch.randn(1, 3, 32, 32)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 512, 1, 1))
    # Test RepVGG forward
    model_test_settings = [
        dict(model_name='A0', out_sizes=(48, 96, 192, 1280)),
        dict(model_name='A1', out_sizes=(64, 128, 256, 1280)),
        dict(model_name='A2', out_sizes=(96, 192, 384, 1408)),
        dict(model_name='B0', out_sizes=(64, 128, 256, 1280)),
        dict(model_name='B1', out_sizes=(128, 256, 512, 2048)),
        dict(model_name='B1g2', out_sizes=(128, 256, 512, 2048)),
        dict(model_name='B1g4', out_sizes=(128, 256, 512, 2048)),
        dict(model_name='B2', out_sizes=(160, 320, 640, 2560)),
        dict(model_name='B2g2', out_sizes=(160, 320, 640, 2560)),
        dict(model_name='B2g4', out_sizes=(160, 320, 640, 2560)),
        dict(model_name='B3', out_sizes=(192, 384, 768, 2560)),
        dict(model_name='B3g2', out_sizes=(192, 384, 768, 2560)),
        dict(model_name='B3g4', out_sizes=(192, 384, 768, 2560)),
        dict(model_name='D2se', out_sizes=(160, 320, 640, 2560))
    ]
    # only a subset is exercised to keep the test fast
    choose_models = ['A0', 'B1', 'B1g2']
    # Test RepVGG model forward
    for model_test_setting in model_test_settings:
        if model_test_setting['model_name'] not in choose_models:
            continue
        model = RepVGG(
            model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
        model.init_weights()
        model.eval()
        # Test Norm
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)
        imgs = torch.randn(1, 3, 32, 32)
        feat = model(imgs)
        assert feat[0].shape == torch.Size(
            (1, model_test_setting['out_sizes'][0], 8, 8))
        assert feat[1].shape == torch.Size(
            (1, model_test_setting['out_sizes'][1], 4, 4))
        assert feat[2].shape == torch.Size(
            (1, model_test_setting['out_sizes'][2], 2, 2))
        assert feat[3].shape == torch.Size(
            (1, model_test_setting['out_sizes'][3], 1, 1))
        # Test eval of "train" mode and "deploy" mode
        gap = nn.AdaptiveAvgPool2d(output_size=(1))
        fc = nn.Linear(model_test_setting['out_sizes'][3], 10)
        model.eval()
        feat = model(imgs)
        pred = fc(gap(feat[3]).flatten(1))
        model.switch_to_deploy()
        for m in model.modules():
            if isinstance(m, RepVGGBlock):
                assert m.deploy is True
        feat_deploy = model(imgs)
        pred_deploy = fc(gap(feat_deploy[3]).flatten(1))
        # NOTE(review): bare torch.allclose calls below discard their return
        # value, so these comparisons assert nothing -- confirm whether an
        # `assert` was intended.
        for i in range(4):
            torch.allclose(feat[i], feat_deploy[i])
        torch.allclose(pred, pred_deploy)
    # Test RepVGG forward with add_ppf
    model = RepVGG('A0', out_indices=(3, ), add_ppf=True)
    model.init_weights()
    model.train()
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    imgs = torch.randn(1, 3, 64, 64)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 1280, 2, 2))
    # Test RepVGG forward with 'stem_channels' not in arch
    arch = dict(
        num_blocks=[2, 4, 14, 1],
        width_factor=[0.75, 0.75, 0.75, 2.5],
        group_layer_map=None,
        se_cfg=None)
    model = RepVGG(arch, add_ppf=True)
    # NOTE(review): this line *assigns* stem.in_channels instead of asserting
    # it, and 64 * 0.75 is a float -- it looks like an assertion on
    # stem.out_channels was intended; confirm.
    model.stem.in_channels = min(64, 64 * 0.75)
    model.init_weights()
    model.train()
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)
    imgs = torch.randn(1, 3, 64, 64)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 1280, 2, 2))
def test_repvgg_load():
    """RepVGG re-parameterization round-trip.

    Checks that outputs of a model switched to deploy mode are reproduced
    exactly after saving its state dict and loading it into a model
    constructed directly with ``deploy=True``.
    """
    # Test output before and load from deploy checkpoint
    model = RepVGG('A1', out_indices=(0, 1, 2, 3))
    inputs = torch.randn((1, 3, 32, 32))
    model.switch_to_deploy()
    model.eval()
    outputs = model(inputs)
    model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
    model_deploy.eval()
    # Use a private temporary directory instead of a fixed 'ckpt.pth' in the
    # shared temp dir: the fixed name collides with other tests in this file
    # under parallel execution and was never cleaned up.
    with tempfile.TemporaryDirectory() as tmpdir:
        ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
        save_checkpoint(model.state_dict(), ckpt_path)
        load_checkpoint(model_deploy, ckpt_path, strict=True)
    outputs_load = model_deploy(inputs)
    for feat, feat_load in zip(outputs, outputs_load):
        assert torch.allclose(feat, feat_load)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.models.backbones import Res2Net
def check_norm_state(modules, train_state):
    """Check that all batch-norm layers among ``modules`` are in the
    expected ``train_state``; other module types are ignored."""
    for layer in modules:
        if not isinstance(layer, _BatchNorm):
            continue
        if layer.training != train_state:
            return False
    return True
def test_resnet_cifar():
    """Res2Net backbone unit test.

    NOTE(review): despite the name, this function tests Res2Net (depth
    validation, stage feature-map sizes for deep_stem/avg_down variants,
    and frozen-stage behaviour), not a CIFAR ResNet -- consider renaming.
    """
    # Only support depth 50, 101 and 152
    with pytest.raises(KeyError):
        Res2Net(depth=18)
    # test the feature map size when depth is 50
    # and deep_stem=True, avg_down=True
    model = Res2Net(
        depth=50, out_indices=(0, 1, 2, 3), deep_stem=True, avg_down=True)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model.stem(imgs)
    assert feat.shape == (1, 64, 112, 112)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)
    # test the feature map size when depth is 101
    # and deep_stem=False, avg_down=False
    model = Res2Net(
        depth=101, out_indices=(0, 1, 2, 3), deep_stem=False, avg_down=False)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model.conv1(imgs)
    assert feat.shape == (1, 64, 112, 112)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)
    # Test Res2Net with first stage frozen
    frozen_stages = 1
    model = Res2Net(depth=50, frozen_stages=frozen_stages, deep_stem=False)
    model.init_weights()
    model.train()
    assert check_norm_state([model.norm1], False)
    for param in model.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmpretrain.models.backbones import ResNeSt
from mmpretrain.models.backbones.resnest import Bottleneck as BottleneckS
def test_bottleneck():
    """ResNeSt Bottleneck: style validation, structure, and forward shape.

    NOTE(review): this merged file defines another ``test_bottleneck``
    further down (ResNet); in a single module the later definition shadows
    this one -- confirm these are meant to live in separate files.
    """
    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        BottleneckS(64, 64, radix=2, reduction_factor=4, style='tensorflow')
    # Test ResNeSt Bottleneck structure
    block = BottleneckS(
        64, 256, radix=2, reduction_factor=4, stride=2, style='pytorch')
    assert block.avd_layer.stride == 2
    assert block.conv2.channels == 64
    # Test ResNeSt Bottleneck forward
    block = BottleneckS(64, 64, radix=2, reduction_factor=4)
    x = torch.randn(2, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 64, 56, 56])
def test_resnest():
    """ResNeSt backbone: depth validation and stage feature-map shapes."""
    with pytest.raises(KeyError):
        # ResNeSt depth should be in [50, 101, 152, 200]
        ResNeSt(depth=18)
    # Test ResNeSt with radix 2, reduction_factor 4
    model = ResNeSt(
        depth=50, radix=2, reduction_factor=4, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()
    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([2, 256, 56, 56])
    assert feat[1].shape == torch.Size([2, 512, 28, 28])
    assert feat[2].shape == torch.Size([2, 1024, 14, 14])
    assert feat[3].shape == torch.Size([2, 2048, 7, 7])
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.models.backbones import ResNet, ResNetV1c, ResNetV1d
from mmpretrain.models.backbones.resnet import (BasicBlock, Bottleneck,
ResLayer, get_expansion)
def is_block(modules):
    """Check if is ResNet building block."""
    return isinstance(modules, (BasicBlock, Bottleneck))
def all_zeros(modules):
    """Check if the weight (and bias, if present) of a module is all zero.

    Fix: a module constructed with ``bias=False`` (e.g.
    ``nn.Conv2d(..., bias=False)``) still *has* a ``bias`` attribute -- it is
    registered as ``None`` -- so the original ``hasattr`` test passed and
    ``modules.bias.data`` raised ``AttributeError``.  Check the attribute's
    value, not merely its existence.
    """
    weight_zero = torch.equal(modules.weight.data,
                              torch.zeros_like(modules.weight.data))
    if getattr(modules, 'bias', None) is not None:
        bias_zero = torch.equal(modules.bias.data,
                                torch.zeros_like(modules.bias.data))
    else:
        # no bias parameter (absent or registered as None) -> trivially zero
        bias_zero = True
    return weight_zero and bias_zero
def check_norm_state(modules, train_state):
    """Verify the train/eval state of every batch-norm layer in ``modules``.

    Returns False as soon as one ``_BatchNorm`` instance disagrees with
    ``train_state``; modules of other types do not affect the result.
    """
    mismatched = [
        m for m in modules
        if isinstance(m, _BatchNorm) and m.training != train_state
    ]
    return not mismatched
def test_get_expansion():
    """get_expansion: explicit value, class attribute fallback, and errors
    for invalid or undeterminable expansion."""
    assert get_expansion(Bottleneck, 2) == 2
    assert get_expansion(BasicBlock) == 1
    assert get_expansion(Bottleneck) == 4
    class MyResBlock(nn.Module):
        expansion = 8
    assert get_expansion(MyResBlock) == 8
    # expansion must be an integer or None
    with pytest.raises(TypeError):
        get_expansion(Bottleneck, '0')
    # expansion is not specified and cannot be inferred
    with pytest.raises(TypeError):
        class SomeModule(nn.Module):
            pass
        get_expansion(SomeModule)
def test_basic_block():
    """ResNet BasicBlock: channel bookkeeping, conv structure, forward
    shapes with/without downsample, and checkpointed forward."""
    # expansion must be 1
    with pytest.raises(AssertionError):
        BasicBlock(64, 64, expansion=2)
    # BasicBlock with stride 1, out_channels == in_channels
    block = BasicBlock(64, 64)
    assert block.in_channels == 64
    assert block.mid_channels == 64
    assert block.out_channels == 64
    assert block.conv1.in_channels == 64
    assert block.conv1.out_channels == 64
    assert block.conv1.kernel_size == (3, 3)
    assert block.conv1.stride == (1, 1)
    assert block.conv2.in_channels == 64
    assert block.conv2.out_channels == 64
    assert block.conv2.kernel_size == (3, 3)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])
    # BasicBlock with stride 1 and downsample
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128))
    block = BasicBlock(64, 128, downsample=downsample)
    assert block.in_channels == 64
    assert block.mid_channels == 128
    assert block.out_channels == 128
    assert block.conv1.in_channels == 64
    assert block.conv1.out_channels == 128
    assert block.conv1.kernel_size == (3, 3)
    assert block.conv1.stride == (1, 1)
    assert block.conv2.in_channels == 128
    assert block.conv2.out_channels == 128
    assert block.conv2.kernel_size == (3, 3)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 128, 56, 56])
    # BasicBlock with stride 2 and downsample
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(128))
    block = BasicBlock(64, 128, stride=2, downsample=downsample)
    assert block.in_channels == 64
    assert block.mid_channels == 128
    assert block.out_channels == 128
    assert block.conv1.in_channels == 64
    assert block.conv1.out_channels == 128
    assert block.conv1.kernel_size == (3, 3)
    assert block.conv1.stride == (2, 2)
    assert block.conv2.in_channels == 128
    assert block.conv2.out_channels == 128
    assert block.conv2.kernel_size == (3, 3)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 128, 28, 28])
    # forward with checkpointing
    block = BasicBlock(64, 64, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 64, 56, 56, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_bottleneck():
    """ResNet Bottleneck: style handling, channel bookkeeping, forward
    shapes with/without downsample, custom expansion, and checkpointing.

    NOTE(review): this shadows the earlier ``test_bottleneck`` (ResNeSt)
    in this merged file; in one module only this later definition runs --
    confirm the two tests are meant to live in separate files.
    """
    # style must be in ['pytorch', 'caffe']
    with pytest.raises(AssertionError):
        Bottleneck(64, 64, style='tensorflow')
    # expansion must be divisible by out_channels
    with pytest.raises(AssertionError):
        Bottleneck(64, 64, expansion=3)
    # Test Bottleneck style
    block = Bottleneck(64, 64, stride=2, style='pytorch')
    assert block.conv1.stride == (1, 1)
    assert block.conv2.stride == (2, 2)
    block = Bottleneck(64, 64, stride=2, style='caffe')
    assert block.conv1.stride == (2, 2)
    assert block.conv2.stride == (1, 1)
    # Bottleneck with stride 1
    block = Bottleneck(64, 64, style='pytorch')
    assert block.in_channels == 64
    assert block.mid_channels == 16
    assert block.out_channels == 64
    assert block.conv1.in_channels == 64
    assert block.conv1.out_channels == 16
    assert block.conv1.kernel_size == (1, 1)
    assert block.conv2.in_channels == 16
    assert block.conv2.out_channels == 16
    assert block.conv2.kernel_size == (3, 3)
    assert block.conv3.in_channels == 16
    assert block.conv3.out_channels == 64
    assert block.conv3.kernel_size == (1, 1)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == (1, 64, 56, 56)
    # Bottleneck with stride 1 and downsample
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1), nn.BatchNorm2d(128))
    block = Bottleneck(64, 128, style='pytorch', downsample=downsample)
    assert block.in_channels == 64
    assert block.mid_channels == 32
    assert block.out_channels == 128
    assert block.conv1.in_channels == 64
    assert block.conv1.out_channels == 32
    assert block.conv1.kernel_size == (1, 1)
    assert block.conv2.in_channels == 32
    assert block.conv2.out_channels == 32
    assert block.conv2.kernel_size == (3, 3)
    assert block.conv3.in_channels == 32
    assert block.conv3.out_channels == 128
    assert block.conv3.kernel_size == (1, 1)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == (1, 128, 56, 56)
    # Bottleneck with stride 2 and downsample
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1, stride=2), nn.BatchNorm2d(128))
    block = Bottleneck(
        64, 128, stride=2, style='pytorch', downsample=downsample)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == (1, 128, 28, 28)
    # Bottleneck with expansion 2
    block = Bottleneck(64, 64, style='pytorch', expansion=2)
    assert block.in_channels == 64
    assert block.mid_channels == 32
    assert block.out_channels == 64
    assert block.conv1.in_channels == 64
    assert block.conv1.out_channels == 32
    assert block.conv1.kernel_size == (1, 1)
    assert block.conv2.in_channels == 32
    assert block.conv2.out_channels == 32
    assert block.conv2.kernel_size == (3, 3)
    assert block.conv3.in_channels == 32
    assert block.conv3.out_channels == 64
    assert block.conv3.kernel_size == (1, 1)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == (1, 64, 56, 56)
    # Test Bottleneck with checkpointing
    block = Bottleneck(64, 64, with_cp=True)
    block.train()
    assert block.with_cp
    x = torch.randn(1, 64, 56, 56, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_basicblock_reslayer():
    """ResLayer built from BasicBlock: downsample creation rules (none,
    1x1-conv, strided, avg-pool variant) and output shapes."""
    # 3 BasicBlock w/o downsample
    layer = ResLayer(BasicBlock, 3, 32, 32)
    assert len(layer) == 3
    for i in range(3):
        assert layer[i].in_channels == 32
        assert layer[i].out_channels == 32
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 32, 56, 56)
    # 3 BasicBlock w/ stride 1 and downsample
    layer = ResLayer(BasicBlock, 3, 32, 64)
    assert len(layer) == 3
    assert layer[0].in_channels == 32
    assert layer[0].out_channels == 64
    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
    assert isinstance(layer[0].downsample[0], nn.Conv2d)
    assert layer[0].downsample[0].stride == (1, 1)
    for i in range(1, 3):
        assert layer[i].in_channels == 64
        assert layer[i].out_channels == 64
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 64, 56, 56)
    # 3 BasicBlock w/ stride 2 and downsample
    layer = ResLayer(BasicBlock, 3, 32, 64, stride=2)
    assert len(layer) == 3
    assert layer[0].in_channels == 32
    assert layer[0].out_channels == 64
    assert layer[0].stride == 2
    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
    assert isinstance(layer[0].downsample[0], nn.Conv2d)
    assert layer[0].downsample[0].stride == (2, 2)
    for i in range(1, 3):
        assert layer[i].in_channels == 64
        assert layer[i].out_channels == 64
        assert layer[i].stride == 1
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 64, 28, 28)
    # 3 BasicBlock w/ stride 2 and downsample with avg pool
    layer = ResLayer(BasicBlock, 3, 32, 64, stride=2, avg_down=True)
    assert len(layer) == 3
    assert layer[0].in_channels == 32
    assert layer[0].out_channels == 64
    assert layer[0].stride == 2
    assert layer[0].downsample is not None and len(layer[0].downsample) == 3
    assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
    assert layer[0].downsample[0].stride == 2
    for i in range(1, 3):
        assert layer[i].in_channels == 64
        assert layer[i].out_channels == 64
        assert layer[i].stride == 1
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 64, 28, 28)
def test_bottleneck_reslayer():
    """ResLayer built from Bottleneck: downsample creation rules, custom
    expansion, and output shapes."""
    # 3 Bottleneck w/o downsample
    layer = ResLayer(Bottleneck, 3, 32, 32)
    assert len(layer) == 3
    for i in range(3):
        assert layer[i].in_channels == 32
        assert layer[i].out_channels == 32
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 32, 56, 56)
    # 3 Bottleneck w/ stride 1 and downsample
    layer = ResLayer(Bottleneck, 3, 32, 64)
    assert len(layer) == 3
    assert layer[0].in_channels == 32
    assert layer[0].out_channels == 64
    assert layer[0].stride == 1
    assert layer[0].conv1.out_channels == 16
    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
    assert isinstance(layer[0].downsample[0], nn.Conv2d)
    assert layer[0].downsample[0].stride == (1, 1)
    for i in range(1, 3):
        assert layer[i].in_channels == 64
        assert layer[i].out_channels == 64
        assert layer[i].conv1.out_channels == 16
        assert layer[i].stride == 1
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 64, 56, 56)
    # 3 Bottleneck w/ stride 2 and downsample
    layer = ResLayer(Bottleneck, 3, 32, 64, stride=2)
    assert len(layer) == 3
    assert layer[0].in_channels == 32
    assert layer[0].out_channels == 64
    assert layer[0].stride == 2
    assert layer[0].conv1.out_channels == 16
    assert layer[0].downsample is not None and len(layer[0].downsample) == 2
    assert isinstance(layer[0].downsample[0], nn.Conv2d)
    assert layer[0].downsample[0].stride == (2, 2)
    for i in range(1, 3):
        assert layer[i].in_channels == 64
        assert layer[i].out_channels == 64
        assert layer[i].conv1.out_channels == 16
        assert layer[i].stride == 1
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 64, 28, 28)
    # 3 Bottleneck w/ stride 2 and downsample with avg pool
    layer = ResLayer(Bottleneck, 3, 32, 64, stride=2, avg_down=True)
    assert len(layer) == 3
    assert layer[0].in_channels == 32
    assert layer[0].out_channels == 64
    assert layer[0].stride == 2
    assert layer[0].conv1.out_channels == 16
    assert layer[0].downsample is not None and len(layer[0].downsample) == 3
    assert isinstance(layer[0].downsample[0], nn.AvgPool2d)
    assert layer[0].downsample[0].stride == 2
    for i in range(1, 3):
        assert layer[i].in_channels == 64
        assert layer[i].out_channels == 64
        assert layer[i].conv1.out_channels == 16
        assert layer[i].stride == 1
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 64, 28, 28)
    # 3 Bottleneck with custom expansion
    layer = ResLayer(Bottleneck, 3, 32, 32, expansion=2)
    assert len(layer) == 3
    for i in range(3):
        assert layer[i].in_channels == 32
        assert layer[i].out_channels == 32
        assert layer[i].stride == 1
        assert layer[i].conv1.out_channels == 16
        assert layer[i].downsample is None
    x = torch.randn(1, 32, 56, 56)
    x_out = layer(x)
    assert x_out.shape == (1, 32, 56, 56)
def test_resnet():
    """Test resnet backbone.

    Covers constructor validation, norm_eval/frozen-stage state, forward
    shapes for several depths and out_indices, checkpointed forward, and
    zero/non-zero residual-branch initialization.
    """
    with pytest.raises(KeyError):
        # ResNet depth should be in [18, 34, 50, 101, 152]
        ResNet(20)
    with pytest.raises(AssertionError):
        # In ResNet: 1 <= num_stages <= 4
        ResNet(50, num_stages=0)
    with pytest.raises(AssertionError):
        # In ResNet: 1 <= num_stages <= 4
        ResNet(50, num_stages=5)
    with pytest.raises(AssertionError):
        # len(strides) == len(dilations) == num_stages
        ResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = ResNet(50)
        model.init_weights(pretrained=0)
    with pytest.raises(AssertionError):
        # Style must be in ['pytorch', 'caffe']
        ResNet(50, style='tensorflow')
    # Test ResNet50 norm_eval=True
    model = ResNet(50, norm_eval=True)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), False)
    # Test ResNet50 with torchvision pretrained weight
    model = ResNet(
        depth=50,
        norm_eval=True,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'))
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), False)
    # Test ResNet50 with first stage frozen
    frozen_stages = 1
    model = ResNet(50, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    assert model.norm1.training is False
    for layer in [model.conv1, model.norm1]:
        for param in layer.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False
    # Test ResNet18 forward
    model = ResNet(18, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 64, 56, 56)
    assert feat[1].shape == (1, 128, 28, 28)
    assert feat[2].shape == (1, 256, 14, 14)
    assert feat[3].shape == (1, 512, 7, 7)
    # Test ResNet50 with BatchNorm forward
    model = ResNet(50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)
    # Test ResNet50 with DropPath forward
    model = ResNet(50, out_indices=(0, 1, 2, 3), drop_path_rate=0.5)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)
    # Test ResNet50 with layers 1, 2, 3 out forward
    model = ResNet(50, out_indices=(0, 1, 2))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    # Test ResNet50 with layers 3 (top feature maps) out forward
    model = ResNet(50, out_indices=(3, ))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == (1, 2048, 7, 7)
    # Test ResNet50 with checkpoint forward
    model = ResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
    for m in model.modules():
        if is_block(m):
            assert m.with_cp
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)
    # zero initialization of residual blocks
    model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
    model.init_weights()
    for m in model.modules():
        if isinstance(m, Bottleneck):
            assert all_zeros(m.norm3)
        elif isinstance(m, BasicBlock):
            assert all_zeros(m.norm2)
    # non-zero initialization of residual blocks
    model = ResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=False)
    model.init_weights()
    for m in model.modules():
        if isinstance(m, Bottleneck):
            assert not all_zeros(m.norm3)
        elif isinstance(m, BasicBlock):
            assert not all_zeros(m.norm2)
def test_resnet_v1c():
    """Test the ResNetV1c backbone: deep-stem forward shapes and stage freezing."""
    model = ResNetV1c(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    # V1c replaces the single 7x7 stem conv with three stacked ConvModules.
    assert len(model.stem) == 3
    for i in range(3):
        assert isinstance(model.stem[i], ConvModule)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model.stem(imgs)
    assert feat.shape == (1, 64, 112, 112)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == (1, 256, 56, 56)
    assert feat[1].shape == (1, 512, 28, 28)
    assert feat[2].shape == (1, 1024, 14, 14)
    assert feat[3].shape == (1, 2048, 7, 7)

    # Test ResNet50V1c with the stem and first stage frozen.
    # NOTE: fixed a copy-paste bug -- this used to build ResNetV1d here,
    # which duplicated test_resnet_v1d and left V1c's frozen_stages
    # behavior untested.
    frozen_stages = 1
    model = ResNetV1c(depth=50, frozen_stages=frozen_stages)
    assert len(model.stem) == 3
    for i in range(3):
        assert isinstance(model.stem[i], ConvModule)
    model.init_weights()
    model.train()
    # Frozen stem: norm layers stay in eval mode and no param needs grad.
    check_norm_state(model.stem, False)
    for param in model.stem.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False
def test_resnet_v1d():
    """Forward-shape and frozen-stage checks for the ResNetV1d backbone."""
    model = ResNetV1d(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    # V1d uses a deep stem built from three ConvModules.
    assert len(model.stem) == 3
    for stem_layer in model.stem:
        assert isinstance(stem_layer, ConvModule)

    imgs = torch.randn(1, 3, 224, 224)
    stem_out = model.stem(imgs)
    assert stem_out.shape == (1, 64, 112, 112)

    outs = model(imgs)
    expected_shapes = [(1, 256, 56, 56), (1, 512, 28, 28),
                       (1, 1024, 14, 14), (1, 2048, 7, 7)]
    assert len(outs) == 4
    for out, shape in zip(outs, expected_shapes):
        assert out.shape == shape

    # Test ResNet50V1d with first stage frozen.
    frozen_stages = 1
    model = ResNetV1d(depth=50, frozen_stages=frozen_stages)
    assert len(model.stem) == 3
    for stem_layer in model.stem:
        assert isinstance(stem_layer, ConvModule)
    model.init_weights()
    model.train()
    # Frozen stem/stages keep norms in eval mode and params grad-free.
    check_norm_state(model.stem, False)
    assert all(not p.requires_grad for p in model.stem.parameters())
    for stage_idx in range(1, frozen_stages + 1):
        stage = getattr(model, f'layer{stage_idx}')
        for mod in stage.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        assert all(not p.requires_grad for p in stage.parameters())
def test_resnet_half_channel():
    """Halving base_channels (64 -> 32) must halve every stage's width."""
    model = ResNet(50, base_channels=32, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()

    inputs = torch.randn(1, 3, 224, 224)
    outs = model(inputs)

    expected_shapes = [(1, 128, 56, 56), (1, 256, 28, 28),
                       (1, 512, 14, 14), (1, 1024, 7, 7)]
    assert len(outs) == 4
    for out, shape in zip(outs, expected_shapes):
        assert out.shape == shape
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.models.backbones import ResNet_CIFAR
def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    # Every batch-norm module must report the expected training flag;
    # modules that are not batch-norm are ignored.
    return all(mod.training == train_state for mod in modules
               if isinstance(mod, _BatchNorm))
def test_resnet_cifar():
    """Smoke tests for the ResNet_CIFAR backbone (32x32 inputs, 3x3 stem)."""
    # The CIFAR variant only supports the plain conv stem; deep_stem is
    # rejected at construction time.
    with pytest.raises(AssertionError):
        ResNet_CIFAR(depth=18, deep_stem=True)

    def _check_forward(depth, expected_shapes):
        # Build the model, push a 32x32 batch through and compare the
        # stem output plus every stage's feature-map shape.
        model = ResNet_CIFAR(depth=depth, out_indices=(0, 1, 2, 3))
        model.init_weights()
        model.train()
        imgs = torch.randn(1, 3, 32, 32)
        stem_out = model.conv1(imgs)
        assert stem_out.shape == (1, 64, 32, 32)
        outs = model(imgs)
        assert len(outs) == 4
        for out, shape in zip(outs, expected_shapes):
            assert out.shape == shape

    # BasicBlock variant (depth 18): stage channels 64/128/256/512.
    _check_forward(18, [(1, 64, 32, 32), (1, 128, 16, 16),
                        (1, 256, 8, 8), (1, 512, 4, 4)])
    # Bottleneck variant (depth 50): 4x channel expansion per stage.
    _check_forward(50, [(1, 256, 32, 32), (1, 512, 16, 16),
                        (1, 1024, 8, 8), (1, 2048, 4, 4)])

    # Test ResNet_CIFAR with first stage frozen: norm layers stay in eval
    # mode and gradients are disabled even after .train().
    frozen_stages = 1
    model = ResNet_CIFAR(depth=50, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    check_norm_state([model.norm1], False)
    for param in model.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        stage = getattr(model, f'layer{i}')
        for mod in stage.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in stage.parameters():
            assert param.requires_grad is False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment