"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7c3a75a1ce7bdae1632dbcfd0bf2fda349a6fa36"
Commit 9c38b35f authored by Kai Chen's avatar Kai Chen
Browse files

bug fix for checking bias

parent ab670420
...@@ -3,7 +3,7 @@ import torch.nn as nn ...@@ -3,7 +3,7 @@ import torch.nn as nn
def constant_init(module, val, bias=0): def constant_init(module, val, bias=0):
nn.init.constant_(module.weight, val) nn.init.constant_(module.weight, val)
if hasattr(module, 'bias'): if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) nn.init.constant_(module.bias, bias)
...@@ -13,19 +13,19 @@ def xavier_init(module, gain=1, bias=0, distribution='normal'): ...@@ -13,19 +13,19 @@ def xavier_init(module, gain=1, bias=0, distribution='normal'):
nn.init.xavier_uniform_(module.weight, gain=gain) nn.init.xavier_uniform_(module.weight, gain=gain)
else: else:
nn.init.xavier_normal_(module.weight, gain=gain) nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias'): if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0): def normal_init(module, mean=0, std=1, bias=0):
nn.init.normal_(module.weight, mean, std) nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias'): if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0): def uniform_init(module, a=0, b=1, bias=0):
nn.init.uniform_(module.weight, a, b) nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias'): if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) nn.init.constant_(module.bias, bias)
...@@ -41,5 +41,5 @@ def kaiming_init(module, ...@@ -41,5 +41,5 @@ def kaiming_init(module,
else: else:
nn.init.kaiming_normal_( nn.init.kaiming_normal_(
module.weight, mode=mode, nonlinearity=nonlinearity) module.weight, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias'): if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias) nn.init.constant_(module.bias, bias)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment