"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "829b9f8cc321aa28396e6203e0f21eed26b132f7"
Unverified Commit 7da995c0 authored by Sylvain Gugger, committed by GitHub

Fix embeddings for PyTorch 1.8 (#10549)

* Fix embeddings for PyTorch 1.8

* Try with PyTorch 1.8.0

* Fix embeddings init

* Fix copies

* Typo

* More typos
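
In short, every affected model gets the same `_init_weights` refactor: `nn.Linear` and `nn.Embedding` are handled in separate branches, the bias check moves inside the Linear branch, and the row at `padding_idx` is re-zeroed after the embedding weights are re-initialized. A minimal sketch of the shared pattern, written as a standalone function with a hypothetical `initializer_range` argument standing in for `self.config.initializer_range`:

```python
import torch.nn as nn


def init_weights(module: nn.Module, initializer_range: float) -> None:
    """Sketch of the init pattern this PR applies to most models."""
    if isinstance(module, nn.Linear):
        # Slightly different from the TF version, which uses truncated_normal.
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.padding_idx is not None:
            # Keep the padding vector at zero after re-initializing the weights.
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
```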
parent 3e056c10
@@ -79,8 +79,7 @@ jobs:
               - v0.4-{{ checksum "setup.py" }}
         - run: pip install --upgrade pip
         - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
-        - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
-        - run: pip install -U torch==1.7.1
+        - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
         - save_cache:
             key: v0.4-{{ checksum "setup.py" }}
             paths:
@@ -107,8 +106,7 @@ jobs:
               - v0.4-{{ checksum "setup.py" }}
         - run: pip install --upgrade pip
         - run: pip install .[sklearn,torch,testing,sentencepiece]
-        - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
-        - run: pip install -U torch==1.7.1
+        - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
         - save_cache:
             key: v0.4-torch-{{ checksum "setup.py" }}
             paths:
@@ -187,8 +185,7 @@ jobs:
               - v0.4-{{ checksum "setup.py" }}
         - run: pip install --upgrade pip
         - run: pip install .[sklearn,torch,testing,sentencepiece]
-        - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
-        - run: pip install -U torch==1.7.1
+        - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
         - save_cache:
             key: v0.4-torch-{{ checksum "setup.py" }}
             paths:
......
@@ -490,12 +490,16 @@ class AlbertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, (nn.Linear)) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
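
The extra `nn.Embedding` branch matters because PyTorch zeroes the `padding_idx` row when the embedding is constructed, but re-initializing the whole weight tensor with `normal_` overwrites it. A quick illustration of the problem and of what the new code restores (the sizes and the 0.02 std here are arbitrary):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
assert torch.all(emb.weight[0] == 0)  # padding row starts out zeroed

emb.weight.data.normal_(mean=0.0, std=0.02)  # blanket re-init clobbers it
assert not torch.all(emb.weight[0] == 0)

if emb.padding_idx is not None:  # what the updated _init_weights does
    emb.weight.data[emb.padding_idx].zero_()
assert torch.all(emb.weight[emb.padding_idx] == 0)
```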
......
@@ -704,15 +704,19 @@ class BertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 @dataclass
......
@@ -177,15 +177,19 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 BERT_GENERATION_START_DOCSTRING = r"""
......
@@ -238,15 +238,19 @@ class ConvBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, (nn.Linear)) and module.bias is not None:
-            module.bias.data.zero_()
 
 class SeparableConv1D(nn.Module):
......
@@ -221,12 +221,16 @@ class CTRLPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
+        if isinstance(module, (nn.Linear, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
......
@@ -767,13 +767,17 @@ class DebertaPreTrainedModel(PreTrainedModel):
         self._register_load_state_dict_pre_hook(self._pre_load_hook)
 
     def _init_weights(self, module):
-        """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, nn.Linear) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
         """
......
@@ -886,13 +886,17 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
         self._register_load_state_dict_pre_hook(self._pre_load_hook)
 
     def _init_weights(self, module):
-        """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, nn.Linear) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
         """
......
@@ -341,16 +341,19 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, nn.Embedding):
-            if module.weight.requires_grad:
-                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
         if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 DISTILBERT_START_DOCSTRING = r"""
......
@@ -653,15 +653,19 @@ class ElectraPreTrainedModel(PreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 @dataclass
......
@@ -779,6 +779,8 @@ class FunnelPreTrainedModel(PreTrainedModel):
         elif classname == "FunnelEmbeddings":
             std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
             nn.init.normal_(module.word_embeddings.weight, std=std)
+            if module.word_embeddings.padding_idx is not None:
+                module.word_embeddings.weight.data[module.word_embeddings.padding_idx].zero_()
 
 class FunnelClassificationHead(nn.Module):
......
@@ -345,12 +345,16 @@ class GPT2PreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
+        if isinstance(module, (nn.Linear, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
......
@@ -645,15 +645,19 @@ class IBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (QuantLinear, QuantEmbedding, nn.Linear, nn.Embedding)):
+        if isinstance(module, (QuantLinear, nn.Linear)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (QuantEmbedding, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, (QuantLinear, nn.Linear)) and module.bias is not None:
-            module.bias.data.zero_()
 
     def resize_token_embeddings(self, new_num_tokens=None):
         raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")
......
@@ -612,15 +612,19 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, LayoutLMLayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 LAYOUTLM_START_DOCSTRING = r"""
......
@@ -1363,15 +1363,19 @@ class LongformerPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 LONGFORMER_START_DOCSTRING = r"""
......
@@ -783,15 +783,19 @@ class LxmertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 LXMERT_START_DOCSTRING = r"""
......
@@ -670,15 +670,19 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, (nn.LayerNorm, NoNorm)):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 @dataclass
......
@@ -56,15 +56,19 @@ class MPNetPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 class MPNetEmbeddings(nn.Module):
......
@@ -283,12 +283,16 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
+        if isinstance(module, (nn.Linear, Conv1D)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
+            if module.bias is not None:
                 module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
......
@@ -1791,16 +1791,17 @@ class ReformerPreTrainedModel(PreTrainedModel):
                 torch.nn.init.normal_(weight, std=self.config.axial_norm_std)
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
 
 @dataclass
......