Commit 9c35c132 authored by thomwolf

apex LayerNorm

parent b9c77b98
@@ -217,7 +217,7 @@ class PositionwiseFF(nn.Module):
             nn.Dropout(dropout),
         )
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = LayerNorm(d_model)
         self.pre_lnorm = pre_lnorm
@@ -254,7 +254,7 @@ class MultiHeadAttn(nn.Module):
         self.dropatt = nn.Dropout(dropatt)
         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = LayerNorm(d_model)
         self.scale = 1 / (d_head ** 0.5)
@@ -335,7 +335,7 @@ class RelMultiHeadAttn(nn.Module):
         self.dropatt = nn.Dropout(dropatt)
         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = LayerNorm(d_model)
         self.scale = 1 / (d_head ** 0.5)
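The hunks replace nn.LayerNorm with a bare LayerNorm name, which presumably comes from an import guard elsewhere in the file (not shown in these context lines). A minimal sketch of such a guard, assuming NVIDIA apex is an optional dependency:

# Sketch of the import guard the call sites above presumably rely on;
# the guard itself is not part of this commit's visible context.
try:
    # Use apex's fused CUDA LayerNorm kernel when apex is installed.
    from apex.normalization import FusedLayerNorm as LayerNorm
except ImportError:
    # Fall back to the stock PyTorch implementation otherwise.
    from torch.nn import LayerNorm

Because FusedLayerNorm takes the same constructor argument as torch.nn.LayerNorm, the three call sites only need the name swap shown above: LayerNorm(d_model) works under either binding.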