"tests/vscode:/vscode.git/clone" did not exist on "3fc10ded000d673768bf03195b0600f69af96a50"
Unverified Commit 41f35d0b authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1089 from dhpollack/dhp/use_pytorch_layernorm

change layernorm code to pytorch's native layer norm
parents 01ad55f8 e13465fb
......@@ -224,20 +224,7 @@ try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError) as e:
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
BertLayerNorm = torch.nn.LayerNorm
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment