Unverified commit a4c37026, authored by Miao Zheng, committed by GitHub

add initializers and BaseModule for unified parameter initialization (#780)

* add initializers and BaseModule for unified parameter initialization

* fix circle import

* bug fix

* add is_init flag in BaseModule

* fix docstring

* sort import and fix doc format

* fix bug

* fix docformat and double quote string

* fix import sort

* import sort

* sort import

* revise according to comments

* fix doc format

* revise according to comments

* revise import and fix typo

* polish code

* revise minors

* revise minors

* revise apply function

* revise bias initialization with probability

* add type test for bias_prob

* revise minors
parent 11b92640
@@ -13,9 +13,11 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                      build_upsample_layer, conv_ws_2d, is_norm)
 # yapf: enable
 from .resnet import ResNet, make_res_layer
-from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
-                    fuse_conv_bn, get_model_complexity_info, kaiming_init,
-                    normal_init, uniform_init, xavier_init)
+from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
+                    PretrainedInit, UniformInit, XavierInit,
+                    bias_init_with_prob, caffe2_xavier_init, constant_init,
+                    fuse_conv_bn, get_model_complexity_info, initialize,
+                    kaiming_init, normal_init, uniform_init, xavier_init)
 from .vgg import VGG, make_vgg_layer

 __all__ = [
@@ -30,5 +32,6 @@ __all__ = [
     'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
     'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
     'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
-    'MaxPool3d', 'Conv3d'
+    'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
+    'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
 ]
@@ -3,8 +3,6 @@ import logging
 import torch.nn as nn

-from ..runner import load_checkpoint

 class AlexNet(nn.Module):
     """AlexNet backbone.
@@ -45,6 +43,7 @@ class AlexNet(nn.Module):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = logging.getLogger()
+            from ..runner import load_checkpoint
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             # use default initializer
...
@@ -4,7 +4,6 @@ import logging
 import torch.nn as nn
 import torch.utils.checkpoint as cp

-from ..runner import load_checkpoint
 from .utils import constant_init, kaiming_init
@@ -266,6 +265,7 @@ class ResNet(nn.Module):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = logging.getLogger()
+            from ..runner import load_checkpoint
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
...
 # Copyright (c) Open-MMLab. All rights reserved.
 from .flops_counter import get_model_complexity_info
 from .fuse_conv_bn import fuse_conv_bn
-from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
-                          constant_init, kaiming_init, normal_init,
-                          uniform_init, xavier_init)
+from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
+                          PretrainedInit, UniformInit, XavierInit,
+                          bias_init_with_prob, caffe2_xavier_init,
+                          constant_init, initialize, kaiming_init, normal_init,
+                          uniform_init, xavier_init)

 __all__ = [
     'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
     'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
-    'xavier_init', 'fuse_conv_bn'
+    'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
+    'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
+    'PretrainedInit'
 ]
@@ -2,6 +2,10 @@
 import numpy as np
 import torch.nn as nn

+from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+

 def constant_init(module, val, bias=0):
     if hasattr(module, 'weight') and module.weight is not None:
@@ -12,22 +16,25 @@ def constant_init(module, val, bias=0):
 def xavier_init(module, gain=1, bias=0, distribution='normal'):
     assert distribution in ['uniform', 'normal']
-    if distribution == 'uniform':
-        nn.init.xavier_uniform_(module.weight, gain=gain)
-    else:
-        nn.init.xavier_normal_(module.weight, gain=gain)
+    if hasattr(module, 'weight') and module.weight is not None:
+        if distribution == 'uniform':
+            nn.init.xavier_uniform_(module.weight, gain=gain)
+        else:
+            nn.init.xavier_normal_(module.weight, gain=gain)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)


 def normal_init(module, mean=0, std=1, bias=0):
-    nn.init.normal_(module.weight, mean, std)
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.normal_(module.weight, mean, std)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)


 def uniform_init(module, a=0, b=1, bias=0):
-    nn.init.uniform_(module.weight, a, b)
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.uniform_(module.weight, a, b)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
@@ -39,12 +46,13 @@ def kaiming_init(module,
                  bias=0,
                  distribution='normal'):
     assert distribution in ['uniform', 'normal']
-    if distribution == 'uniform':
-        nn.init.kaiming_uniform_(
-            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-    else:
-        nn.init.kaiming_normal_(
-            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    if hasattr(module, 'weight') and module.weight is not None:
+        if distribution == 'uniform':
+            nn.init.kaiming_uniform_(
+                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+        else:
+            nn.init.kaiming_normal_(
+                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
     if hasattr(module, 'bias') and module.bias is not None:
         nn.init.constant_(module.bias, bias)
@@ -57,10 +65,367 @@ def caffe2_xavier_init(module, bias=0):
         a=1,
         mode='fan_in',
         nonlinearity='leaky_relu',
+        bias=bias,
         distribution='uniform')


 def bias_init_with_prob(prior_prob):
-    """initialize conv/fc bias value according to giving probablity."""
+    """Initialize conv/fc bias value according to a given probability value."""
     bias_init = float(-np.log((1 - prior_prob) / prior_prob))
     return bias_init
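For reference, the focal-loss style prior of 0.01 gives a bias of -log(0.99 / 0.01) ≈ -4.595. A tiny illustrative check:

conv = nn.Conv2d(3, 1, 3)
nn.init.constant_(conv.bias, bias_init_with_prob(0.01))  # every bias entry ≈ -4.595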
class BaseInit(object):
def __init__(self, bias, bias_prob, layer):
if not isinstance(bias, (int, float)):
raise TypeError(f'bias must be a number, but got a {type(bias)}')
if bias_prob is not None:
if not isinstance(bias_prob, float):
raise TypeError(f'bias_prob type must be float, \
but got {type(bias_prob)}')
if layer is not None:
if not isinstance(layer, (str, list)):
raise TypeError(f'layer must be str or list[str], \
but got a {type(layer)}')
if bias_prob is not None:
self.bias = bias_init_with_prob(bias_prob)
else:
self.bias = bias
self.layer = [layer] if isinstance(layer, str) else layer
@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
"""Initialize module parameters with constant values.
Args:
val (int | float): the value to fill the weights in the module with
bias (int | float): the value to fill the bias or
define initialization type for bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, val, bias=0, bias_prob=None, layer=None):
super().__init__(bias, bias_prob, layer)
self.val = val
def __call__(self, module):
def init(m):
if self.layer is None:
constant_init(m, self.val, self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
constant_init(m, self.val, self.bias)
module.apply(init)
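A short usage sketch (the Sequential model here is illustrative): with layer given, only modules whose class name matches are touched; without it, every module that has a weight or bias is filled.

model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
ConstantInit(val=1, bias=2, layer='Conv2d')(model)  # only the Conv2d gets val=1, bias=2
ConstantInit(val=3, bias=4)(model)  # no layer filter: Conv2d and Linear are both filled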
@INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
r"""Initialize module parameters with values according to the method
described in `Understanding the difficulty of training deep feedforward
neural networks - Glorot, X. & Bengio, Y. (2010).
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
Args:
gain (int | float): an optional scaling factor. Defaults to 1.
bias (int | float): the value to fill the bias or define
initialization type for bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'``
or ``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self,
gain=1,
bias=0,
bias_prob=None,
distribution='normal',
layer=None):
super().__init__(bias, bias_prob, layer)
self.gain = gain
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.layer is None:
xavier_init(m, self.gain, self.bias, self.distribution)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
xavier_init(m, self.gain, self.bias, self.distribution)
module.apply(init)
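For example (illustrative), Conv2d weights drawn from the Xavier uniform distribution while the bias is derived from a probability:

model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
XavierInit(gain=1, distribution='uniform', bias_prob=0.01, layer='Conv2d')(model)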
@INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
r"""Initialize module parameters with the values drawn from the normal
distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
mean (int | float): the mean of the normal distribution. Defaults to 0.
std (int | float): the standard deviation of the normal distribution.
Defaults to 1.
bias (int | float): the value to fill the bias or define
initialization type for bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, mean=0, std=1, bias=0, bias_prob=None, layer=None):
super().__init__(bias, bias_prob, layer)
self.mean = mean
self.std = std
def __call__(self, module):
def init(m):
if self.layer is None:
normal_init(m, self.mean, self.std, self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
normal_init(m, self.mean, self.std, self.bias)
module.apply(init)
@INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
r"""Initialize module parameters with values drawn from the uniform
distribution :math:`\mathcal{U}(a, b)`.
Args:
a (int | float): the lower bound of the uniform distribution.
Defaults to 0.
b (int | float): the upper bound of the uniform distribution.
Defaults to 1.
bias (int | float): the value to fill the bias or define
initialization type for bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self, a=0, b=1, bias=0, bias_prob=None, layer=None):
super().__init__(bias, bias_prob, layer)
self.a = a
self.b = b
def __call__(self, module):
def init(m):
if self.layer is None:
uniform_init(m, self.a, self.b, self.bias)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
uniform_init(m, self.a, self.b, self.bias)
module.apply(init)
@INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
r"""Initialize module paramters with the valuse according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification - He, K. et al. (2015).
<https://www.cv-foundation.org/openaccess/content_iccv_2015/
papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
Args:
a (int | float): the negative slope of the rectifier used after this
layer (only used with ``'leaky_relu'``). Defaults to 0.
mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
``'fan_in'`` preserves the magnitude of the variance of the weights
in the forward pass. Choosing ``'fan_out'`` preserves the
magnitudes in the backwards pass. Defaults to ``'fan_out'``.
nonlinearity (str): the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
Defaults to 'relu'.
bias (int | float): the value to fill the bias or define
initialization type for bias. Defaults to 0.
bias_prob (float, optional): the probability for bias initialization.
Defaults to None.
distribution (str): distribution either be ``'normal'`` or
``'uniform'``. Defaults to ``'normal'``.
layer (str | list[str], optional): the layer will be initialized.
Defaults to None.
"""
def __init__(self,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
bias_prob=None,
distribution='normal',
layer=None):
super().__init__(bias, bias_prob, layer)
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
self.distribution = distribution
def __call__(self, module):
def init(m):
if self.layer is None:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
else:
layername = m.__class__.__name__
for layer_ in self.layer:
if layername == layer_:
kaiming_init(m, self.a, self.mode, self.nonlinearity,
self.bias, self.distribution)
module.apply(init)
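A sketch of the common setting for convolutions followed by ReLU (the model is illustrative; mode='fan_out' and nonlinearity='relu' are already the defaults):

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
KaimingInit(mode='fan_out', nonlinearity='relu', layer='Conv2d')(model)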
@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
"""Initialize module by loading a pretrained model
Args:
checkpoint (str): the file should be load
prefix (str, optional): the prefix to indicate the sub-module.
Defaults to None.
"""
def __init__(self, checkpoint, prefix=None, map_location=None):
self.checkpoint = checkpoint
self.prefix = prefix
self.map_location = map_location
def __call__(self, module):
from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict)
logger = get_logger('mmcv')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
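Two illustrative calls, assuming 'checkpoint.pth' exists on disk and, for the second call, stores the sub-module's parameters under the 'backbone.' prefix (the model and attribute names are hypothetical):

PretrainedInit(checkpoint='checkpoint.pth')(model)  # load the whole state dict
PretrainedInit(checkpoint='checkpoint.pth', prefix='backbone.')(model.backbone)  # only backbone.* keys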
def _initialize(module, cfg):
func = build_from_cfg(cfg, INITIALIZERS)
func(module)
def _initialize_override(module, override):
if not isinstance(override, (dict, list)):
raise TypeError(
f'override must be a dict or list, but got {type(override)}')
override = [override] if isinstance(override, dict) else override
for override_ in override:
name = override_.pop('name', None)
if hasattr(module, name):
_initialize(getattr(module, name), override_)
else:
raise RuntimeError(f'module does not have attribute {name}')
def initialize(module, init_cfg):
"""Initialize a module.
Args:
module (``torch.nn.Module``): the module to be initialized.
init_cfg (dict | list[dict]): initialization configuration dict to
define initializer. OpenMMLab has implemented six registered
initializers: ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
``Kaiming`` and ``Pretrained``; a layer's bias can additionally be
set from a probability via ``bias_prob``.
Example:
>>> module = nn.Linear(2, 3, bias=True)
>>> init_cfg = dict(type='Constant', val=1, bias=2)
>>> initialize(module, init_cfg)
>>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
>>> # define key ``'layer'`` to initialize layers with different
>>> # configurations
>>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
dict(type='Constant', layer='Linear', val=2)]
>>> initialize(module, init_cfg)
>>> # Omitting ``'layer'`` initializes the whole module with the same configuration
>>> init_cfg = dict(type='Constant', val=1, bias=2)
>>> initialize(module, init_cfg)
>>> # define key ``'override'`` to initialize a specific named part of the
>>> # module
>>> class FooNet(nn.Module):
>>> def __init__(self):
>>> super().__init__()
>>> self.feat = nn.Conv2d(3, 16, 3)
>>> self.reg = nn.Conv2d(16, 10, 3)
>>> self.cls = nn.Conv2d(16, 5, 3)
>>> model = FooNet()
>>> init_cfg = dict(type='Constant', val=1, bias=2,
>>> override=dict(type='Constant', name='reg', val=3, bias=4))
>>> initialize(model, init_cfg)
>>> model = ResNet(depth=50)
>>> # Initialize weights with the pretrained model.
>>> init_cfg = dict(type='Pretrained',
checkpoint='torchvision://resnet50')
>>> initialize(model, init_cfg)
>>> # Initialize weights of a sub-module with the specific part of
>>> # a pretrained model by using "prefix".
>>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
>>> 'retinanet_r50_fpn_1x_coco/'\
>>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
>>> init_cfg = dict(type='Pretrained',
checkpoint=url, prefix='backbone.')
"""
if not isinstance(init_cfg, (dict, list)):
raise TypeError(f'init_cfg must be a dict or a list of dict, but got {type(init_cfg)}')
if isinstance(init_cfg, dict):
init_cfg = [init_cfg]
for cfg in init_cfg:
override = cfg.pop('override', None)
_initialize(module, cfg)
if override is not None:
_initialize_override(module, override)
else:
# All attributes in the module share the same initialization.
pass
@@ -3,7 +3,6 @@ import logging
 import torch.nn as nn

-from ..runner import load_checkpoint
 from .utils import constant_init, kaiming_init, normal_init
@@ -126,6 +125,7 @@ class VGG(nn.Module):
     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
             logger = logging.getLogger()
+            from ..runner import load_checkpoint
             load_checkpoint(self, pretrained, strict=False, logger=logger)
         elif pretrained is None:
             for m in self.modules():
...
 # Copyright (c) Open-MMLab. All rights reserved.
+from .base_module import BaseModule
 from .base_runner import BaseRunner
 from .builder import RUNNERS, build_runner
-from .checkpoint import (CheckpointLoader, _load_checkpoint, load_checkpoint,
-                         load_state_dict, save_checkpoint, weights_to_cpu)
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+                         _load_checkpoint_with_prefix, load_checkpoint,
+                         load_state_dict, save_checkpoint, weights_to_cpu)
 from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
                          init_dist, master_only)
@@ -34,5 +36,5 @@ __all__ = [
     'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model',
     'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner',
     'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler',
-    'CheckpointLoader'
+    'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix'
 ]
# Copyright (c) Open-MMLab. All rights reserved.
import warnings
from abc import ABCMeta
import torch.nn as nn
class BaseModule(nn.Module, metaclass=ABCMeta):
"""Base module for all modules in openmmlab."""
def __init__(self, init_cfg=None):
"""Initialize BaseModule, inherited from `torch.nn.Module`
Args:
init_cfg (dict, optional): Initialization config dict.
"""
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super(BaseModule, self).__init__()
# define default value of init_cfg instead of hard-coding
# it in the init_weight() function
self._is_init = False
if init_cfg is not None:
self.init_cfg = init_cfg
# Backward compatibility in derived classes
# if pretrained is not None:
# warnings.warn('DeprecationWarning: pretrained is a deprecated \
# key, please consider using init_cfg')
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@property
def is_init(self):
return self._is_init
def init_weight(self):
"""Initialize the weights."""
from ..cnn import initialize
if not self._is_init:
if hasattr(self, 'init_cfg'):
initialize(self, self.init_cfg)
self._is_init = True
for module in self.children():
if 'init_weight' in dir(module):
module.init_weight()
else:
warnings.warn('This module has been initialized, \
please call initialize(module, init_cfg) to reinitialize it')
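A minimal sketch of a module built on BaseModule carrying its own init_cfg (the ToyHead class and its config are illustrative, not part of this change):

import torch.nn as nn
from mmcv.runner import BaseModule

class ToyHead(BaseModule):
    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.fc = nn.Linear(8, 2)

head = ToyHead(init_cfg=dict(type='Normal', layer='Linear', mean=0, std=0.01))
head.init_weight()  # applies init_cfg once; the is_init flag guards against re-running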
@@ -464,6 +464,39 @@ def _load_checkpoint(filename, map_location=None, logger=None):
     return CheckpointLoader.load_checkpoint(filename, map_location, logger)
def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
"""Load partial pretrained model with specific prefix.
Args:
prefix (str): The prefix of sub-module.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location=map_location)
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if not prefix.endswith('.'):
prefix += '.'
prefix_len = len(prefix)
state_dict = {
k[prefix_len:]: v
for k, v in state_dict.items() if k.startswith(prefix)
}
assert state_dict, f'{prefix} is not in the pretrained model'
return state_dict
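An illustrative call, assuming 'whole_model.pth' stores keys such as 'backbone.layer1.0.conv1.weight':

# the returned keys have the prefix stripped, e.g. 'layer1.0.conv1.weight'
state_dict = _load_checkpoint_with_prefix('backbone', 'whole_model.pth')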
 def load_checkpoint(model,
                     filename,
                     map_location=None,
...
 # Copyright (c) Open-MMLab. All rights reserved.
+from tempfile import TemporaryDirectory
+
 import numpy as np
 import pytest
 import torch
 from torch import nn

-from mmcv.cnn import (bias_init_with_prob, caffe2_xavier_init, constant_init,
-                      kaiming_init, normal_init, uniform_init, xavier_init)
+from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit,
+                      UniformInit, XavierInit, bias_init_with_prob,
+                      caffe2_xavier_init, constant_init, initialize,
+                      kaiming_init, normal_init, uniform_init, xavier_init)
@@ -75,3 +79,283 @@ def test_bias_init_with_prob():
     # TODO: sanity check of weight distribution, e.g. mean, std
     bias = float(-np.log((1 - prior_prob) / prior_prob))
     assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
def test_constaninit():
"""test ConstantInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = ConstantInit(val=1, bias=2, layer='Conv2d')
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 1.))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
func = ConstantInit(val=3, bias_prob=0.01, layer='Linear')
func(model)
res = bias_init_with_prob(0.01)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
func = ConstantInit(val=4, bias=5)
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 4.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 4.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 5.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 5.))
# test bias input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias='1')
# test bias_prob type
with pytest.raises(TypeError):
func = ConstantInit(val=1, bias_prob='1')
# test layer input type
with pytest.raises(TypeError):
func = ConstantInit(val=1, layer=1)
def test_xavierinit():
"""test XavierInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = XavierInit(bias=0.1, layer='Conv2d')
func(model)
assert model[0].bias.allclose(torch.full_like(model[0].bias, 0.1))
assert not model[2].bias.allclose(torch.full_like(model[2].bias, 0.1))
constant_func = ConstantInit(val=0, bias=0)
func = XavierInit(gain=100, bias_prob=0.01)
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
res = bias_init_with_prob(0.01)
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, res))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, res))
# test bias input type
with pytest.raises(TypeError):
func = XavierInit(bias='0.1', layer='Conv2d')
# test layer input type
with pytest.raises(TypeError):
func = XavierInit(bias=0.1, layer=1)
def test_normalinit():
"""test Normalinit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = NormalInit(mean=100, std=1e-5, bias=200)
func(model)
assert model[0].weight.allclose(torch.tensor(100.))
assert model[2].weight.allclose(torch.tensor(100.))
assert model[0].bias.allclose(torch.tensor(200.))
assert model[2].bias.allclose(torch.tensor(200.))
func = NormalInit(
mean=300, std=1e-5, bias_prob=0.01, layer=['Conv2d', 'Linear'])
res = bias_init_with_prob(0.01)
func(model)
assert model[0].weight.allclose(torch.tensor(300.))
assert model[2].weight.allclose(torch.tensor(300.))
assert model[0].bias.allclose(torch.tensor(res))
assert model[2].bias.allclose(torch.tensor(res))
def test_uniforminit():
""""test UniformInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = UniformInit(a=1, b=1, bias=2)
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
func = UniformInit(a=100, b=100, layer=['Conv2d', 'Linear'], bias=10)
func(model)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape,
100.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape,
100.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
def test_kaiminginit():
"""test KaimingInit class."""
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
func = KaimingInit(bias=0.1, layer='Conv2d')
func(model)
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.1))
assert not torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.1))
func = KaimingInit(a=100, bias=10)
constant_func = ConstantInit(val=0, bias=0)
model.apply(constant_func)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 0.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 0.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 0.))
func(model)
assert not torch.equal(model[0].weight,
torch.full(model[0].weight.shape, 0.))
assert not torch.equal(model[2].weight,
torch.full(model[2].weight.shape, 0.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 10.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
class FooModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 2)
self.conv2d = nn.Conv2d(3, 1, 3)
self.conv2d_2 = nn.Conv2d(3, 2, 3)
def test_pretrainedinit():
"""test PretrainedInit class."""
modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2)
modelA.apply(constant_func)
modelB = FooModule()
funcB = PretrainedInit(checkpoint='modelA.pth')
modelC = nn.Linear(1, 2)
funcC = PretrainedInit(checkpoint='modelA.pth', prefix='linear.')
with TemporaryDirectory():
torch.save(modelA.state_dict(), 'modelA.pth')
funcB(modelB)
assert torch.equal(modelB.linear.weight,
torch.full(modelB.linear.weight.shape, 1.))
assert torch.equal(modelB.linear.bias,
torch.full(modelB.linear.bias.shape, 2.))
assert torch.equal(modelB.conv2d.weight,
torch.full(modelB.conv2d.weight.shape, 1.))
assert torch.equal(modelB.conv2d.bias,
torch.full(modelB.conv2d.bias.shape, 2.))
assert torch.equal(modelB.conv2d_2.weight,
torch.full(modelB.conv2d_2.weight.shape, 1.))
assert torch.equal(modelB.conv2d_2.bias,
torch.full(modelB.conv2d_2.bias.shape, 2.))
funcC(modelC)
assert torch.equal(modelC.weight, torch.full(modelC.weight.shape, 1.))
assert torch.equal(modelC.bias, torch.full(modelC.bias.shape, 2.))
def test_initialize():
model = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
foonet = FooModule()
init_cfg = dict(type='Constant', val=1, bias=2)
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 1.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 2.))
init_cfg = [
dict(type='Constant', layer='Conv1d', val=1, bias=2),
dict(type='Constant', layer='Linear', val=3, bias=4)
]
initialize(model, init_cfg)
assert torch.equal(model[0].weight, torch.full(model[0].weight.shape, 1.))
assert torch.equal(model[2].weight, torch.full(model[2].weight.shape, 3.))
assert torch.equal(model[0].bias, torch.full(model[0].bias.shape, 2.))
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 4.))
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
initialize(foonet, init_cfg)
assert torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 1.))
assert torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 2.))
assert torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 1.))
assert torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 2.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
init_cfg = dict(
type='Pretrained',
checkpoint='modelA.pth',
override=dict(type='Constant', name='conv2d_2', val=3, bias=4))
modelA = FooModule()
constant_func = ConstantInit(val=1, bias=2)
modelA.apply(constant_func)
with TemporaryDirectory():
torch.save(modelA.state_dict(), 'modelA.pth')
initialize(foonet, init_cfg)
assert torch.equal(foonet.linear.weight,
torch.full(foonet.linear.weight.shape, 1.))
assert torch.equal(foonet.linear.bias,
torch.full(foonet.linear.bias.shape, 2.))
assert torch.equal(foonet.conv2d.weight,
torch.full(foonet.conv2d.weight.shape, 1.))
assert torch.equal(foonet.conv2d.bias,
torch.full(foonet.conv2d.bias.shape, 2.))
assert torch.equal(foonet.conv2d_2.weight,
torch.full(foonet.conv2d_2.weight.shape, 3.))
assert torch.equal(foonet.conv2d_2.bias,
torch.full(foonet.conv2d_2.bias.shape, 4.))
# test init_cfg type
with pytest.raises(TypeError):
init_cfg = 'init_cfg'
initialize(foonet, init_cfg)
# test override value type
with pytest.raises(TypeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override='conv')
initialize(foonet, init_cfg)
# test override name
with pytest.raises(RuntimeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=dict(type='Constant', name='conv2d_3', val=3, bias=4))
initialize(foonet, init_cfg)
# test list override name
with pytest.raises(RuntimeError):
init_cfg = dict(
type='Constant',
val=1,
bias=2,
layer=['Conv2d', 'Linear'],
override=[
dict(type='Constant', name='conv2d', val=3, bias=4),
dict(type='Constant', name='conv2d_3', val=5, bias=6)
])
initialize(foonet, init_cfg)
import torch
from torch import nn
from mmcv.runner import BaseModule
from mmcv.utils import Registry, build_from_cfg
COMPONENTS = Registry('component')
FOOMODELS = Registry('model')
@COMPONENTS.register_module()
class FooConv1d(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv1d = nn.Conv1d(4, 1, 4)
def forward(self, x):
return self.conv1d(x)
@COMPONENTS.register_module()
class FooConv2d(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.conv2d = nn.Conv2d(3, 1, 3)
def forward(self, x):
return self.conv2d(x)
@COMPONENTS.register_module()
class FooLinear(BaseModule):
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
self.linear = nn.Linear(3, 4)
def forward(self, x):
return self.linear(x)
@COMPONENTS.register_module()
class FooLinearConv1d(BaseModule):
def __init__(self, linear=None, conv1d=None, init_cfg=None):
super().__init__(init_cfg)
if linear is not None:
self.linear = build_from_cfg(linear, COMPONENTS)
if conv1d is not None:
self.conv1d = build_from_cfg(conv1d, COMPONENTS)
def forward(self, x):
x = self.linear(x)
return self.conv1d(x)
@FOOMODELS.register_module()
class FooModel(BaseModule):
def __init__(self,
component1=None,
component2=None,
component3=None,
component4=None,
init_cfg=None) -> None:
super().__init__(init_cfg)
if component1 is not None:
self.component1 = build_from_cfg(component1, COMPONENTS)
if component2 is not None:
self.component2 = build_from_cfg(component2, COMPONENTS)
if component3 is not None:
self.component3 = build_from_cfg(component3, COMPONENTS)
if component4 is not None:
self.component4 = build_from_cfg(component4, COMPONENTS)
# reg is a plain nn.Linear rather than a BaseModule, but it can
# still be initialized through the 'override' key.
self.reg = nn.Linear(3, 4)
def test_model_weight_init():
"""
Config
model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
Conv2d: weight=5, bias=6)
├──component1 (FooConv1d)
├──component2 (FooConv2d)
├──component3 (FooLinear)
├──component4 (FooLinearConv1d)
├──linear (FooLinear)
├──conv1d (FooConv1d)
├──reg (nn.Linear)
Parameters after initialization
model (FooModel)
├──component1 (FooConv1d, weight=3, bias=4)
├──component2 (FooConv2d, weight=5, bias=6)
├──component3 (FooLinear, weight=1, bias=2)
├──component4 (FooLinearConv1d)
├──linear (FooLinear, weight=1, bias=2)
├──conv1d (FooConv1d, weight=3, bias=4)
├──reg (nn.Linear, weight=1, bias=2)
"""
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(type='Constant', val=1, bias=2, layer='Linear'),
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
dict(type='Constant', val=5, bias=6, layer='Conv2d')
],
component1=dict(type='FooConv1d'),
component2=dict(type='FooConv2d'),
component3=dict(type='FooLinear'),
component4=dict(
type='FooLinearConv1d',
linear=dict(type='FooLinear'),
conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 3.0))
assert torch.equal(model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 4.0))
assert torch.equal(model.component2.conv2d.weight,
torch.full(model.component2.conv2d.weight.shape, 5.0))
assert torch.equal(model.component2.conv2d.bias,
torch.full(model.component2.conv2d.bias.shape, 6.0))
assert torch.equal(model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 1.0))
assert torch.equal(model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.linear.linear.weight,
torch.full(model.component4.linear.linear.weight.shape, 1.0))
assert torch.equal(
model.component4.linear.linear.bias,
torch.full(model.component4.linear.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.conv1d.conv1d.weight,
torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
assert torch.equal(
model.component4.conv1d.conv1d.bias,
torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
assert torch.equal(model.reg.weight, torch.full(model.reg.weight.shape,
1.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 2.0))
def test_nest_components_weight_init():
"""
Config
model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
Conv2d: weight=5, bias=6)
├──component1 (FooConv1d, Conv1d: weight=7, bias=8)
├──component2 (FooConv2d, Conv2d: weight=9, bias=10)
├──component3 (FooLinear)
├──component4 (FooLinearConv1d, Linear: weight=11, bias=12)
├──linear (FooLinear, Linear: weight=11, bias=12)
├──conv1d (FooConv1d)
├──reg (nn.Linear, weight=13, bias=14)
Parameters after initialization
model (FooModel)
├──component1 (FooConv1d, weight=7, bias=8)
├──component2 (FooConv2d, weight=9, bias=10)
├──component3 (FooLinear, weight=1, bias=2)
├──component4 (FooLinearConv1d)
├──linear (FooLinear, weight=1, bias=2)
├──conv1d (FooConv1d, weight=3, bias=4)
├──reg (nn.Linear, weight=13, bias=14)
"""
model_cfg = dict(
type='FooModel',
init_cfg=[
dict(
type='Constant',
val=1,
bias=2,
layer='Linear',
override=dict(type='Constant', name='reg', val=13, bias=14)),
dict(type='Constant', val=3, bias=4, layer='Conv1d'),
dict(type='Constant', val=5, bias=6, layer='Conv2d'),
],
component1=dict(
type='FooConv1d', init_cfg=dict(type='Constant', val=7, bias=8)),
component2=dict(
type='FooConv2d', init_cfg=dict(type='Constant', val=9, bias=10)),
component3=dict(type='FooLinear'),
component4=dict(
type='FooLinearConv1d',
linear=dict(type='FooLinear'),
conv1d=dict(type='FooConv1d')))
model = build_from_cfg(model_cfg, FOOMODELS)
model.init_weight()
assert torch.equal(model.component1.conv1d.weight,
torch.full(model.component1.conv1d.weight.shape, 7.0))
assert torch.equal(model.component1.conv1d.bias,
torch.full(model.component1.conv1d.bias.shape, 8.0))
assert torch.equal(model.component2.conv2d.weight,
torch.full(model.component2.conv2d.weight.shape, 9.0))
assert torch.equal(model.component2.conv2d.bias,
torch.full(model.component2.conv2d.bias.shape, 10.0))
assert torch.equal(model.component3.linear.weight,
torch.full(model.component3.linear.weight.shape, 1.0))
assert torch.equal(model.component3.linear.bias,
torch.full(model.component3.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.linear.linear.weight,
torch.full(model.component4.linear.linear.weight.shape, 1.0))
assert torch.equal(
model.component4.linear.linear.bias,
torch.full(model.component4.linear.linear.bias.shape, 2.0))
assert torch.equal(
model.component4.conv1d.conv1d.weight,
torch.full(model.component4.conv1d.conv1d.weight.shape, 3.0))
assert torch.equal(
model.component4.conv1d.conv1d.bias,
torch.full(model.component4.conv1d.conv1d.bias.shape, 4.0))
assert torch.equal(model.reg.weight,
torch.full(model.reg.weight.shape, 13.0))
assert torch.equal(model.reg.bias, torch.full(model.reg.bias.shape, 14.0))
 import sys
 from collections import OrderedDict
+from tempfile import TemporaryDirectory
 from unittest.mock import MagicMock

 import pytest
@@ -8,7 +9,8 @@ import torch.nn as nn
 from torch.nn.parallel import DataParallel

 from mmcv.parallel.registry import MODULE_WRAPPERS
-from mmcv.runner.checkpoint import get_state_dict, load_from_pavi
+from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
+                                    get_state_dict, load_from_pavi)

 @MODULE_WRAPPERS.register_module()
@@ -138,6 +140,7 @@ def test_get_state_dict():
 def test_load_pavimodel_dist():
     sys.modules['pavi'] = MagicMock()
     sys.modules['pavi.modelcloud'] = MagicMock()
     pavimodel = Mockpavimodel()
@@ -152,10 +155,45 @@ def test_load_pavimodel_dist():
         _ = load_from_pavi('pavi://checkpoint.pth')
def test_load_checkpoint_with_prefix():
class FooModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 2)
self.conv2d = nn.Conv2d(3, 1, 3)
self.conv2d_2 = nn.Conv2d(3, 2, 3)
model = FooModule()
nn.init.constant_(model.linear.weight, 1)
nn.init.constant_(model.linear.bias, 2)
nn.init.constant_(model.conv2d.weight, 3)
nn.init.constant_(model.conv2d.bias, 4)
nn.init.constant_(model.conv2d_2.weight, 5)
nn.init.constant_(model.conv2d_2.bias, 6)
with TemporaryDirectory():
torch.save(model.state_dict(), 'model.pth')
prefix = 'conv2d'
state_dict = _load_checkpoint_with_prefix(prefix, 'model.pth')
assert torch.equal(model.conv2d.state_dict()['weight'],
state_dict['weight'])
assert torch.equal(model.conv2d.state_dict()['bias'],
state_dict['bias'])
# test whether prefix is in pretrained model
with pytest.raises(AssertionError):
prefix = 'back'
_load_checkpoint_with_prefix(prefix, 'model.pth')
 def test_load_classes_name():
-    from mmcv.runner import load_checkpoint, save_checkpoint
-    import tempfile
     import os
+    import tempfile
+    from mmcv.runner import load_checkpoint, save_checkpoint
     checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
     model = Model()
     save_checkpoint(model, checkpoint_path)
...