"vscode:/vscode.git/clone" did not exist on "97ef6ff8b3ad0b485173a99a7b1960536bc22437"
Unverified commit fa84b264, authored by Jerry Jiarui XU, committed by GitHub

add bias_init_with_prob (#213)

* add bias_init_with_prob

* add test_weight_init
parent 12e5913b
mmcv/cnn/__init__.py
@@ -2,11 +2,12 @@
 from .alexnet import AlexNet
 from .resnet import ResNet, make_res_layer
 from .vgg import VGG, make_vgg_layer
-from .weight_init import (caffe2_xavier_init, constant_init, kaiming_init,
-                          normal_init, uniform_init, xavier_init)
+from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
+                          constant_init, kaiming_init, normal_init,
+                          uniform_init, xavier_init)

 __all__ = [
     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
     'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
-    'kaiming_init', 'caffe2_xavier_init'
+    'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob'
 ]
mmcv/cnn/weight_init.py
 # Copyright (c) Open-MMLab. All rights reserved.
+import numpy as np
 import torch.nn as nn
@@ -57,3 +58,9 @@ def caffe2_xavier_init(module, bias=0):
         mode='fan_in',
         nonlinearity='leaky_relu',
         distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
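The formula is the inverse sigmoid (logit) of prior_prob: it returns the bias b satisfying sigmoid(b) = prior_prob, so a conv/fc layer whose output feeds a sigmoid starts out predicting positives with probability prior_prob. This is the classification-head initialization popularized by the focal loss (RetinaNet) paper. A quick sanity check of that identity (a standalone sketch, not part of the commit; the helper below only mirrors the committed formula):

import numpy as np

def bias_init_with_prob(prior_prob):
    # mirrors the committed helper: logit of the prior probability
    return float(-np.log((1 - prior_prob) / prior_prob))

def sigmoid(x):
    return 1. / (1. + np.exp(-x))

b = bias_init_with_prob(0.01)
assert np.isclose(sigmoid(b), 0.01)  # sigmoid inverts the logit
print(b)  # approximately -4.595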
tests/test_weight_init.py (new file)
# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import pytest
import torch
from torch import nn

from mmcv.cnn import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                      kaiming_init, normal_init, uniform_init, xavier_init)


def test_constant_init():
    conv_module = nn.Conv2d(3, 16, 3)
    constant_init(conv_module, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))
    assert conv_module.bias.allclose(torch.zeros_like(conv_module.bias))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    constant_init(conv_module_no_bias, 0.1)
    assert conv_module.weight.allclose(
        torch.full_like(conv_module.weight, 0.1))


def test_xavier_init():
    conv_module = nn.Conv2d(3, 16, 3)
    xavier_init(conv_module, bias=0.1)
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    xavier_init(conv_module, distribution='uniform')
    # TODO: sanity check of weight distribution, e.g. mean, std
    with pytest.raises(AssertionError):
        xavier_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    xavier_init(conv_module_no_bias)


def test_normal_init():
    conv_module = nn.Conv2d(3, 16, 3)
    normal_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    normal_init(conv_module_no_bias)
    # TODO: sanity check of weight distribution, e.g. mean, std


def test_uniform_init():
    conv_module = nn.Conv2d(3, 16, 3)
    uniform_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    uniform_init(conv_module_no_bias)


def test_kaiming_init():
    conv_module = nn.Conv2d(3, 16, 3)
    kaiming_init(conv_module, bias=0.1)
    # TODO: sanity check of weight distribution, e.g. mean, std
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, 0.1))
    kaiming_init(conv_module, distribution='uniform')
    with pytest.raises(AssertionError):
        kaiming_init(conv_module, distribution='student-t')
    conv_module_no_bias = nn.Conv2d(3, 16, 3, bias=False)
    kaiming_init(conv_module_no_bias)


def test_caffe_xavier_init():
    conv_module = nn.Conv2d(3, 16, 3)
    caffe2_xavier_init(conv_module)


def test_bias_init_with_prob():
    conv_module = nn.Conv2d(3, 16, 3)
    prior_prob = 0.1
    normal_init(conv_module, bias=bias_init_with_prob(prior_prob))
    # TODO: sanity check of weight distribution, e.g. mean, std
    bias = float(-np.log((1 - prior_prob) / prior_prob))
    assert conv_module.bias.allclose(torch.full_like(conv_module.bias, bias))
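For context, a sketch of how this helper is typically used: a detector's sigmoid classification head initializes its bias from the expected positive-class prior so that early training is not swamped by easy negatives. The head class, channel sizes, and the 0.01 prior below are illustrative assumptions, not part of this commit:

import torch.nn as nn
from mmcv.cnn import bias_init_with_prob, normal_init

class ClsHead(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, in_channels=256, num_classes=80):
        super().__init__()
        self.cls_conv = nn.Conv2d(in_channels, num_classes, 3, padding=1)

    def init_weights(self):
        # start sigmoid outputs near the assumed 0.01 positive prior
        normal_init(self.cls_conv, std=0.01, bias=bias_init_with_prob(0.01))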