Commit 3e472b22 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

rm BertLayerNorm

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/608

Differential Revision: D15541220

Pulled By: myleott

fbshipit-source-id: 52a8e4da72cc6e3e25cf98c989d34a269d614c9d
parent ed592ab5
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
from .adaptive_input import AdaptiveInput from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM from .beamable_mm import BeamableMM
from .bert_layer_norm import BertLayerNorm
from .character_token_embedder import CharacterTokenEmbedder from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention from .downsampled_multihead_attention import DownsampledMultiHeadAttention
...@@ -34,7 +33,6 @@ __all__ = [ ...@@ -34,7 +33,6 @@ __all__ = [
'AdaptiveInput', 'AdaptiveInput',
'AdaptiveSoftmax', 'AdaptiveSoftmax',
'BeamableMM', 'BeamableMM',
'BertLayerNorm',
'CharacterTokenEmbedder', 'CharacterTokenEmbedder',
'ConvTBC', 'ConvTBC',
'DownsampledMultiHeadAttention', 'DownsampledMultiHeadAttention',
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""
Construct a layernorm module in the TF style used with BERT
(epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment