Unverified Commit 5f5e8e83 authored by Miao Zheng's avatar Miao Zheng Committed by GitHub
Browse files

[Refactoring] Add Caffe2Xavier Initializer (#902)

* [Refactoring] Add Caffe2Xavier Initializer

* fix lint
parent 933b052d
...@@ -13,8 +13,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, ...@@ -13,8 +13,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
build_upsample_layer, conv_ws_2d, is_norm) build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable # yapf: enable
from .resnet import ResNet, make_res_layer from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit, from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
PretrainedInit, UniformInit, XavierInit, NormalInit, PretrainedInit, UniformInit, XavierInit,
bias_init_with_prob, caffe2_xavier_init, constant_init, bias_init_with_prob, caffe2_xavier_init, constant_init,
fuse_conv_bn, get_model_complexity_info, initialize, fuse_conv_bn, get_model_complexity_info, initialize,
kaiming_init, normal_init, uniform_init, xavier_init) kaiming_init, normal_init, uniform_init, xavier_init)
...@@ -33,5 +33,6 @@ __all__ = [ ...@@ -33,5 +33,6 @@ __all__ = [
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit', 'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit' 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit'
] ]
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .flops_counter import get_model_complexity_info from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn from .fuse_conv_bn import fuse_conv_bn
from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit, from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
PretrainedInit, UniformInit, XavierInit, KaimingInit, NormalInit, PretrainedInit, UniformInit,
bias_init_with_prob, caffe2_xavier_init, XavierInit, bias_init_with_prob, caffe2_xavier_init,
constant_init, initialize, kaiming_init, normal_init, constant_init, initialize, kaiming_init, normal_init,
uniform_init, xavier_init) uniform_init, xavier_init)
...@@ -12,5 +12,5 @@ __all__ = [ ...@@ -12,5 +12,5 @@ __all__ = [
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init', 'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS', 'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
'PretrainedInit' 'PretrainedInit', 'Caffe2XavierInit'
] ]
...@@ -298,6 +298,22 @@ class KaimingInit(BaseInit): ...@@ -298,6 +298,22 @@ class KaimingInit(BaseInit):
module.apply(init) module.apply(init)
@INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
    """Initialize module parameters as Caffe2's ``XavierFill`` does.

    ``XavierFill`` in Caffe2 corresponds to ``kaiming_uniform_`` in PyTorch
    with ``a=1``, ``mode='fan_in'``, ``nonlinearity='leaky_relu'`` and a
    uniform distribution. Acknowledgment to FAIR's internal code.

    Args:
        **kwargs: Extra keyword arguments (e.g. ``bias``, ``layer``)
            forwarded unchanged to ``KaimingInit.__init__``.
    """

    def __init__(self, **kwargs):
        # Fix the Kaiming parameters so the result matches Caffe2's
        # XavierFill; everything else (bias, layer filtering, ...) is
        # handled by the KaimingInit base class.
        super().__init__(
            a=1,
            mode='fan_in',
            nonlinearity='leaky_relu',
            distribution='uniform',
            **kwargs)

    # NOTE: no __call__ override needed — the inherited KaimingInit.__call__
    # already applies the initializer; re-defining it to only call super()
    # was redundant.
@INITIALIZERS.register_module(name='Pretrained') @INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object): class PretrainedInit(object):
"""Initialize module by loading a pretrained model. """Initialize module by loading a pretrained model.
......
...@@ -6,10 +6,11 @@ import pytest ...@@ -6,10 +6,11 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit, from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
UniformInit, XavierInit, bias_init_with_prob, PretrainedInit, UniformInit, XavierInit,
caffe2_xavier_init, constant_init, initialize, bias_init_with_prob, caffe2_xavier_init, constant_init,
kaiming_init, normal_init, uniform_init, xavier_init) initialize, kaiming_init, normal_init, uniform_init,
xavier_init)
def test_constant_init(): def test_constant_init():
...@@ -219,6 +220,15 @@ def test_kaiminginit(): ...@@ -219,6 +220,15 @@ def test_kaiminginit():
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.)) assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
def test_caffe2xavierinit():
    """Test Caffe2XavierInit: bias is set only on the targeted layer type."""
    net = nn.Sequential(nn.Conv2d(3, 1, 3), nn.ReLU(), nn.Linear(1, 2))
    init_fn = Caffe2XavierInit(bias=0.1, layer='Conv2d')
    init_fn(net)
    # The Conv2d bias must be filled with 0.1; the Linear layer is not
    # listed in `layer`, so its bias must be left untouched.
    filled_conv_bias = torch.full(net[0].bias.shape, 0.1)
    filled_linear_bias = torch.full(net[2].bias.shape, 0.1)
    assert torch.equal(net[0].bias, filled_conv_bias)
    assert not torch.equal(net[2].bias, filled_linear_bias)
class FooModule(nn.Module): class FooModule(nn.Module):
def __init__(self): def __init__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment