Unverified Commit 528bf72c authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #61 from hellock/weight-init

Add caffe2_xavier_init()
parents 944da49f 5e3d09bc
...@@ -2,10 +2,10 @@ from .alexnet import AlexNet ...@@ -2,10 +2,10 @@ from .alexnet import AlexNet
from .vgg import VGG, make_vgg_layer from .vgg import VGG, make_vgg_layer
from .resnet import ResNet, make_res_layer from .resnet import ResNet, make_res_layer
from .weight_init import (constant_init, xavier_init, normal_init, from .weight_init import (constant_init, xavier_init, normal_init,
uniform_init, kaiming_init) uniform_init, kaiming_init, caffe2_xavier_init)
__all__ = [ __all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer', 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
'constant_init', 'xavier_init', 'normal_init', 'uniform_init', 'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
'kaiming_init' 'kaiming_init', 'caffe2_xavier_init'
] ]
...@@ -30,6 +30,7 @@ def uniform_init(module, a=0, b=1, bias=0): ...@@ -30,6 +30,7 @@ def uniform_init(module, a=0, b=1, bias=0):
def kaiming_init(module, def kaiming_init(module,
a=0,
mode='fan_out', mode='fan_out',
nonlinearity='relu', nonlinearity='relu',
bias=0, bias=0,
...@@ -37,9 +38,20 @@ def kaiming_init(module, ...@@ -37,9 +38,20 @@ def kaiming_init(module,
assert distribution in ['uniform', 'normal'] assert distribution in ['uniform', 'normal']
if distribution == 'uniform': if distribution == 'uniform':
nn.init.kaiming_uniform_( nn.init.kaiming_uniform_(
module.weight, mode=mode, nonlinearity=nonlinearity) module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else: else:
nn.init.kaiming_normal_( nn.init.kaiming_normal_(
module.weight, mode=mode, nonlinearity=nonlinearity) module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None: if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
module,
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment