Unverified commit 47ca0eaa authored by Stas Bekman, committed by GitHub

replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm (#9386)

parent 75ff5305
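The change itself is mechanical: each of the touched models previously built its layer norms through a small helper that used apex's FusedLayerNorm when the optional apex package was installed and fell back to torch.nn.LayerNorm otherwise; after this commit the models import and instantiate torch.nn.LayerNorm directly. The sketch below restates that before/after pattern outside the diff. The helper name layer_norm_with_apex_fallback is illustrative only; in the diff the corresponding code is BartLayerNorm, ProphetNetLayerNorm, and a module-level LayerNorm alias in the FSMT file.

    # Before: a per-model helper preferred apex's fused kernel when available.
    import torch

    def layer_norm_with_apex_fallback(normalized_shape, eps=1e-5, elementwise_affine=True):
        try:
            from apex.normalization import FusedLayerNorm  # optional dependency

            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
        return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    # After: the built-in module is imported and used directly.
    from torch.nn import LayerNorm

    self_attn_layer_norm = LayerNorm(768)  # 768 stands in for embed_dim / hidden_size from the config

Both branches construct a module with the same (normalized_shape, eps, elementwise_affine) signature and the same learnable weight and bias, so the swap changes which kernel runs, not the checkpoint format.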
src/transformers/models/bart/modeling_bart.py
@@ -22,7 +22,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)


-def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
 class BartLearnedPositionalEmbedding(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
@@ -321,13 +311,13 @@ class BartEncoderLayer(nn.Module):
             dropout=config.attention_dropout,
         )
         self.normalize_before = config.normalize_before
-        self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = BartLayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
         """
@@ -380,17 +370,17 @@ class BartDecoderLayer(nn.Module):
         self.activation_dropout = config.activation_dropout
         self.normalize_before = config.normalize_before

-        self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = BartAttention(
             self.embed_dim,
             config.decoder_attention_heads,
             dropout=config.attention_dropout,
             is_decoder=True,
         )
-        self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = BartLayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(
         self,
@@ -672,9 +662,9 @@ class BartEncoder(BartPretrainedModel):
             config.extra_pos_embeddings,
         )
         self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
+        self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
-        self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

         self.init_weights()
@@ -812,8 +802,8 @@ class BartDecoder(BartPretrainedModel):
             config.extra_pos_embeddings,
         )
         self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
-        self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+        self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

         self.init_weights()
...
src/transformers/models/fsmt/modeling_fsmt.py
@@ -34,7 +34,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -264,16 +264,6 @@ FSMT_INPUTS_DOCSTRING = r"""
 """


-have_fused_layer_norm = False
-try:
-    from apex.normalization import FusedLayerNorm
-
-    have_fused_layer_norm = True
-except ImportError:
-    pass
-LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
-
-
 def invert_mask(attention_mask):
     """Turns 1->0, 0->1, False->True, True-> False"""
     assert attention_mask.dim() == 2
...
src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -23,6 +23,7 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+from torch.nn import LayerNorm

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput):
     cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


-def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
 class ProphetNetPreTrainedModel(PreTrainedModel):
     config_class = ProphetNetConfig
     base_model_prefix = "prophetnet"
@@ -1044,11 +1035,11 @@ class ProphetNetEncoderLayer(nn.Module):
         super().__init__()
         # 1st residual block
         self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
-        self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

         # 2nd residual block
         self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
-        self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

     def forward(self, hidden_states, attention_mask):
         # 1st residual block
@@ -1073,16 +1064,16 @@ class ProphetNetDecoderLayer(nn.Module):
         super().__init__()
         # 1st residual block
         self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
-        self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

         # 2nd residual block
         if config.add_cross_attention:
             self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
-            self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

         # 3rd residual block
         self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
-        self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

     def forward(
         self,
@@ -1154,7 +1145,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
         self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
-        self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
@@ -1274,7 +1265,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
-        self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

         self.init_weights()
...
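As a quick sanity check (a hypothetical snippet, not part of the commit), instantiating one of the touched models after this change yields plain torch.nn.LayerNorm sub-modules whether or not apex is installed; the attribute path comes from the Bart diff above.

    # Hypothetical check, not part of the commit: BartEncoderLayer.self_attn_layer_norm
    # should now be a plain torch.nn.LayerNorm.
    import torch
    from transformers import BartConfig, BartModel

    model = BartModel(BartConfig())  # randomly initialized model with the default config
    norm = model.encoder.layers[0].self_attn_layer_norm
    print(type(norm))                             # expected: torch.nn.modules.normalization.LayerNorm
    print(isinstance(norm, torch.nn.LayerNorm))   # expected: True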