Commit 1d438f15 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[XLNet] Use pytorch's layernorm like in BERT

See #1089

cc @thomwolf @lysandrejik

Also @dhpollack
parent 574c5b3a
...@@ -337,20 +337,7 @@ try: ...@@ -337,20 +337,7 @@ try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .") logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
class XLNetLayerNorm(nn.Module):
    """Layer normalization in the TF style (epsilon inside the square root).

    Pure-PyTorch fallback used when apex's FusedLayerNorm is unavailable.
    Normalizes over the last dimension of the input, then applies a learnable
    per-feature gain and shift; numerically equivalent to
    ``torch.nn.LayerNorm(d_model, eps=eps)``.
    """

    def __init__(self, d_model, eps=1e-12):
        """
        Args:
            d_model: size of the last (feature) dimension to normalize over.
            eps: small constant added to the variance for numerical stability.
        """
        super(XLNetLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(d_model))  # learnable gain, init 1
        self.bias = nn.Parameter(torch.zeros(d_model))   # learnable shift, init 0
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension, then scale and shift."""
        u = x.mean(-1, keepdim=True)
        # Biased variance (mean of squared deviations), matching TF layernorm.
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
class XLNetRelativeAttention(nn.Module): class XLNetRelativeAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.