"...git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "7013e5e2d779d68d2a9cdafa96f8fdb2645618a5"
Commit 3333bab6 authored by Kai Chen's avatar Kai Chen
Browse files

use weight init methods

parent 62438226
from .alexnet import AlexNet
from .vgg import VGG, make_vgg_layer
from .resnet import ResNet, make_res_layer
from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init
from .weight_init import (constant_init, xavier_init, normal_init,
uniform_init, kaiming_init)
__all__ = [
'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
'xavier_init', 'normal_init', 'uniform_init', 'kaiming_init'
'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
'kaiming_init'
]
......@@ -4,6 +4,7 @@ import math
import torch.nn as nn
import torch.utils.checkpoint as cp
from .weight_init import constant_init, kaiming_init
from ..runner import load_checkpoint
......@@ -268,11 +269,9 @@ class ResNet(nn.Module):
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
......
......@@ -2,6 +2,7 @@ import logging
import torch.nn as nn
from .weight_init import constant_init, normal_init, kaiming_init
from ..runner import load_checkpoint
......@@ -112,16 +113,11 @@ class VGG(nn.Module):
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
constant_init(m, 1)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
normal_init(m, std=0.01)
else:
raise TypeError('pretrained must be a str or None')
......
import torch.nn as nn
def constant_init(module, val, bias=0):
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias'):
nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
assert distribution in ['uniform', 'normal']
if distribution == 'uniform':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment