Commit 993f926a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add no-grad block

parent ce27a6ca
...@@ -149,26 +149,26 @@ class Linear(nn.Linear): ...@@ -149,26 +149,26 @@ class Linear(nn.Linear):
with torch.no_grad(): with torch.no_grad():
self.bias.fill_(0) self.bias.fill_(0)
if init_fn is not None: with torch.no_grad():
init_fn(self.weight, self.bias) if init_fn is not None:
else: init_fn(self.weight, self.bias)
if init == "default":
lecun_normal_init_(self.weight)
elif init == "relu":
he_normal_init_(self.weight)
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif init == "gating":
gating_init_(self.weight)
if bias:
with torch.no_grad():
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif init == "final":
final_init_(self.weight)
else: else:
raise ValueError("Invalid init string.") if init == "default":
lecun_normal_init_(self.weight)
elif init == "relu":
he_normal_init_(self.weight)
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif init == "gating":
gating_init_(self.weight)
if bias:
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif init == "final":
final_init_(self.weight)
else:
raise ValueError("Invalid init string.")
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment