Unverified commit 47ca0eaa, authored by Stas Bekman, committed by GitHub

replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm (#9386)

parent 75ff5305
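For context: `torch.nn.LayerNorm` accepts the same constructor arguments (`normalized_shape`, `eps`, `elementwise_affine`) that the removed helpers pass to `apex.normalization.FusedLayerNorm`, so the swap is a drop-in replacement. Below is a minimal sketch (not part of the diff) of the optional-apex fallback pattern this commit removes and the direct call that replaces it; the `_LayerNorm` alias and the hidden size are illustrative, not taken from any model config.

```python
import torch
from torch.nn import LayerNorm

# Old pattern (removed in this commit): prefer apex's fused CUDA kernel when the
# package is installed, otherwise fall back to the built-in implementation.
try:
    from apex.normalization import FusedLayerNorm as _LayerNorm
except ImportError:
    _LayerNorm = torch.nn.LayerNorm

# New pattern: instantiate torch.nn.LayerNorm directly; the constructor
# arguments are identical in both cases.
hidden_size = 1024  # illustrative value only
old_style = _LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True)
new_style = LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True)
```

In the files below, the models now simply import `LayerNorm` from `torch.nn` and call it where the apex-backed helpers were used before.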
@@ -22,7 +22,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
 
 
-def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
 class BartLearnedPositionalEmbedding(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
@@ -321,13 +311,13 @@ class BartEncoderLayer(nn.Module):
             dropout=config.attention_dropout,
         )
         self.normalize_before = config.normalize_before
-        self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = BartLayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
         """
@@ -380,17 +370,17 @@ class BartDecoderLayer(nn.Module):
         self.activation_dropout = config.activation_dropout
         self.normalize_before = config.normalize_before
 
-        self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = BartAttention(
             self.embed_dim,
             config.decoder_attention_heads,
             dropout=config.attention_dropout,
             is_decoder=True,
         )
-        self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = BartLayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim)
 
     def forward(
         self,
@@ -672,9 +662,9 @@ class BartEncoder(BartPretrainedModel):
             config.extra_pos_embeddings,
         )
         self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
+        self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
-        self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
 
         self.init_weights()
@@ -812,8 +802,8 @@ class BartDecoder(BartPretrainedModel):
             config.extra_pos_embeddings,
         )
         self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
-        self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+        self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
 
         self.init_weights()
......
@@ -34,7 +34,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -264,16 +264,6 @@ FSMT_INPUTS_DOCSTRING = r"""
 """
 
 
-have_fused_layer_norm = False
-try:
-    from apex.normalization import FusedLayerNorm
-
-    have_fused_layer_norm = True
-except ImportError:
-    pass
-
-LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
-
 def invert_mask(attention_mask):
     """Turns 1->0, 0->1, False->True, True-> False"""
     assert attention_mask.dim() == 2
......
@@ -23,6 +23,7 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+from torch.nn import LayerNorm
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput):
     cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
-def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
 class ProphetNetPreTrainedModel(PreTrainedModel):
     config_class = ProphetNetConfig
     base_model_prefix = "prophetnet"
@@ -1044,11 +1035,11 @@ class ProphetNetEncoderLayer(nn.Module):
         super().__init__()
         # 1st residual block
         self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
-        self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
         self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
-        self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
 
     def forward(self, hidden_states, attention_mask):
         # 1st residual block
@@ -1073,16 +1064,16 @@ class ProphetNetDecoderLayer(nn.Module):
         super().__init__()
         # 1st residual block
         self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
-        self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 2nd residual block
         if config.add_cross_attention:
             self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
-            self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
 
         # 3rd residual block
         self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
-        self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
 
     def forward(
         self,
@@ -1154,7 +1145,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
         self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
-        self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)
 
         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
@@ -1274,7 +1265,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
-        self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)
 
         self.init_weights()
......