Unverified commit 7da995c0 authored by Sylvain Gugger, committed by GitHub

Fix embeddings for PyTorch 1.8 (#10549)

* Fix embeddings for PyTorch 1.8

* Try with PyTorch 1.8.0

* Fix embeddings init

* Fix copies

* Typo

* More typos
parent 3e056c10
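
Every hunk below applies the same pattern: nn.Linear and nn.Embedding are initialized in separate branches, and the embedding row at padding_idx is zeroed again after the weights are re-drawn, since normal_ overwrites the zeros that nn.Embedding sets for the padding row at construction time. A minimal, self-contained sketch of the shared shape of the new _init_weights (the helper name and the 0.02 std are placeholders for this sketch; each model actually reads self.config.initializer_range):

import torch
from torch import nn


def init_weights(module, initializer_range=0.02):
    # Linear: normal weights, zero bias (when present).
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    # Embedding: normal weights, then re-zero the padding row, because
    # normal_ overwrites the zeros nn.Embedding sets for padding_idx.
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    # LayerNorm: zero bias, unit weight.
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


emb = nn.Embedding(10, 4, padding_idx=0)
init_weights(emb)
assert torch.equal(emb.weight[0], torch.zeros(4))  # padding row stays zero
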
@@ -51,13 +51,17 @@ class RetriBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

 RETRIBERT_START_DOCSTRING = r"""
@@ -574,15 +574,19 @@ class RobertaPreTrainedModel(PreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

 ROBERTA_START_DOCSTRING = r"""
@@ -432,15 +432,19 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Embedding)):
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, SqueezeBertLayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
-            module.bias.data.zero_()

 SQUEEZEBERT_START_DOCSTRING = r"""
@@ -700,15 +700,19 @@ class TapasPreTrainedModel(PreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

 TAPAS_START_DOCSTRING = r"""
@@ -254,10 +254,12 @@ class XLMPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Embedding):
             if self.config is not None and self.config.embed_init_std is not None:
                 nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         if isinstance(module, nn.Linear):
             if self.config is not None and self.config.init_std is not None:
                 nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
-            if hasattr(module, "bias") and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
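
Note that in the XLM hunk above the new padding_idx check sits outside the embed_init_std guard, so the padding row is zeroed even when the config does not re-draw the embedding weights. A hedged sketch of that branch in isolation (the helper name and the embed_init_std argument are illustrative, not the real config plumbing):

from torch import nn


def init_xlm_embedding(module, embed_init_std=None):
    # Only re-draw the weights when a std is configured...
    if embed_init_std is not None:
        nn.init.normal_(module.weight, mean=0, std=embed_init_std)
    # ...but always keep the padding row at zero.
    if module.padding_idx is not None:
        module.weight.data[module.padding_idx].zero_()
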
@@ -552,12 +552,16 @@ class XLNetPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, nn.Linear) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -656,15 +656,19 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()

 {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""