Unverified commit 9fd937ea authored by Eldar Kurtic, committed by GitHub

Replace BertLayerNorm with LayerNorm (#14385)

Running movement pruning experiments with the newest HuggingFace Transformers release would crash because BertLayerNorm no longer exists.
parent a67d47b4
@@ -30,7 +30,7 @@ from emmental import MaskedBertConfig
 from emmental.modules import MaskedLinear
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
-from transformers.models.bert.modeling_bert import ACT2FN, BertLayerNorm, load_tf_weights_in_bert
+from transformers.models.bert.modeling_bert import ACT2FN, load_tf_weights_in_bert

 logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class BertEmbeddings(nn.Module):
         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         # any TensorFlow checkpoint file
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
@@ -182,7 +182,7 @@ class BertSelfOutput(nn.Module):
             mask_init=config.mask_init,
             mask_scale=config.mask_scale,
         )
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor, threshold):
@@ -275,7 +275,7 @@ class BertOutput(nn.Module):
             mask_init=config.mask_init,
             mask_scale=config.mask_scale,
         )
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor, threshold):
@@ -398,7 +398,7 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, BertLayerNorm):
+        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
...
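For context: in earlier Transformers releases, BertLayerNorm was simply an alias for torch.nn.LayerNorm, which is why nn.LayerNorm works as a drop-in replacement here. A minimal sketch of the replacement call, not part of the commit; hidden_size=768 and eps=1e-12 are illustrative BERT-base defaults rather than values taken from this diff:

# Minimal sketch (illustrative, not from the commit): nn.LayerNorm used
# directly, as the diff does, in place of the removed BertLayerNorm alias.
import torch
from torch import nn

hidden_size = 768       # illustrative; BERT-base hidden size
layer_norm_eps = 1e-12  # illustrative; common config.layer_norm_eps default

# Mirrors the replaced line:
#   self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

hidden_states = torch.randn(2, 128, hidden_size)  # (batch, seq_len, hidden)
out = layer_norm(hidden_states)                   # normalizes over the last dim
print(out.shape)  # torch.Size([2, 128, 768])

Note that nn.LayerNorm initializes its weight to 1.0 and its bias to 0.0, which is exactly what the _init_weights branch in the last hunk enforces, so the updated isinstance(module, nn.LayerNorm) check preserves the original initialization behavior.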