Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
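
The change repeated across every file below is the same: constant buffers such as position_ids, token_type_ids, and causal masks are now registered with persistent=False, so they never enter the state_dict and no longer need to be listed in _keys_to_ignore_on_load_missing. A minimal sketch of the PyTorch behavior this relies on (the Embeddings toy module here is illustrative, not part of the commit):

import torch
from torch import nn


class Embeddings(nn.Module):
    """Toy module mirroring the position_ids pattern used throughout the diff."""

    def __init__(self, max_position_embeddings: int = 16, persistent: bool = True):
        super().__init__()
        # A persistent buffer is written into state_dict(); a non-persistent one is not.
        # Either way, the buffer moves with the module through .to(device), .half(), etc.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=persistent
        )


print("position_ids" in Embeddings(persistent=True).state_dict())   # True
print("position_ids" in Embeddings(persistent=False).state_dict())  # False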
@@ -121,7 +121,9 @@ class ChineseCLIPTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -190,7 +192,7 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
@@ -689,7 +691,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
     config_class = ChineseCLIPConfig
     base_model_prefix = "chinese_clip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -1166,7 +1166,9 @@ class ClapTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=True
         )
@@ -1677,7 +1679,6 @@ class ClapPreTrainedModel(PreTrainedModel):
     config_class = ClapConfig
     base_model_prefix = "clap"
     supports_gradient_checkpointing = False
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"logit_scale_a", r"logit_scale_t"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1781,7 +1782,6 @@ class ClapTextModel(ClapPreTrainedModel):
     """

     config_class = ClapTextConfig
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->ClapText
     def __init__(self, config, add_pooling_layer=True):
@@ -1936,7 +1936,6 @@ class ClapTextModel(ClapPreTrainedModel):
 @add_start_docstrings(CLAP_START_DOCSTRING)
 class ClapModel(ClapPreTrainedModel):
     config_class = ClapConfig
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config: ClapConfig):
         super().__init__(config)
...
@@ -188,7 +188,7 @@ class CLIPVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
@@ -210,7 +210,9 @@ class CLIPTextEmbeddings(nn.Module):
         self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(
         self,
@@ -410,7 +412,6 @@ class CLIPPreTrainedModel(PreTrainedModel):
     config_class = CLIPConfig
     base_model_prefix = "clip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -181,7 +181,7 @@ class CLIPSegVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

     def interpolate_position_embeddings(self, new_size):
         if len(new_size) != 2:
@@ -230,7 +230,9 @@ class CLIPSegTextEmbeddings(nn.Module):
         self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(
         self,
@@ -433,7 +435,6 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
     config_class = CLIPSegConfig
     base_model_prefix = "clip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -83,6 +83,7 @@ class CodeGenAttention(nn.Module):
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
+            persistent=False,
         )

         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -600,7 +601,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
     CODEGEN_START_DOCSTRING,
 )
 class CodeGenForCausalLM(CodeGenPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
...
@@ -191,7 +191,9 @@ class ConvBertEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -245,8 +247,6 @@ class ConvBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_convbert
     base_model_prefix = "convbert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-    _keys_to_ignore_on_load_unexpected = [r"convbert.embeddings_project.weight", r"convbert.embeddings_project.bias"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -765,8 +765,6 @@ CONVBERT_INPUTS_DOCSTRING = r"""
     CONVBERT_START_DOCSTRING,
 )
 class ConvBertModel(ConvBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.embeddings = ConvBertEmbeddings(config)
@@ -880,7 +878,6 @@ class ConvBertGeneratorPredictions(nn.Module):
 @add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
 class ConvBertForMaskedLM(ConvBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]
     _tied_weights_keys = ["generator.lm_head.weight"]

     def __init__(self, config):
@@ -992,8 +989,6 @@ class ConvBertClassificationHead(nn.Module):
     CONVBERT_START_DOCSTRING,
 )
 class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1089,8 +1084,6 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
     CONVBERT_START_DOCSTRING,
 )
 class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
@@ -1184,8 +1177,6 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
     CONVBERT_START_DOCSTRING,
 )
 class ConvBertForTokenClassification(ConvBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1267,8 +1258,6 @@ class ConvBertForTokenClassification(ConvBertPreTrainedModel):
     CONVBERT_START_DOCSTRING,
 )
 class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
...
@@ -537,7 +537,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
     config_class = CpmAntConfig
     base_model_prefix = "cpmant"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -749,7 +748,6 @@ class CpmAntModel(CpmAntPreTrainedModel):
     CPMANT_START_DOCSTRING,
 )
 class CpmAntForCausalLM(CpmAntPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: CpmAntConfig):
...
@@ -509,7 +509,6 @@ class CTRLModel(CTRLPreTrainedModel):
     CTRL_START_DOCSTRING,
 )
 class CTRLLMHeadModel(CTRLPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
...
@@ -689,7 +689,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
     config_class = Data2VecAudioConfig
     base_model_prefix = "data2vec_audio"
     main_input_name = "input_values"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

     def _init_weights(self, module):
...
@@ -80,7 +80,9 @@ class Data2VecTextForTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -615,15 +617,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
         if isinstance(module, Data2VecTextEncoder):
             module.gradient_checkpointing = value

-    def update_keys_to_ignore(self, config, del_keys_to_ignore):
-        """Remove some keys from ignore list"""
-        if not config.tie_word_embeddings:
-            # must make a new list, or the class variable gets modified!
-            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-            self._keys_to_ignore_on_load_missing = [
-                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-            ]
-

 DATA2VECTEXT_START_DOCSTRING = r"""
     Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
@@ -714,8 +707,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
     """

-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -883,9 +874,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
     """Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING
 )
 class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
@@ -897,9 +885,6 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
         self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
         self.lm_head = Data2VecTextLMHead(config)

-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-
         # Initialize weights and apply final processing
         self.post_init()
@@ -1038,9 +1023,6 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
 @add_start_docstrings("""data2vec Model with a `language modeling` head on top.""", DATA2VECTEXT_START_DOCSTRING)
 class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
@@ -1055,9 +1037,6 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
         self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
         self.lm_head = Data2VecTextLMHead(config)

-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-
         # Initialize weights and apply final processing
         self.post_init()
@@ -1174,8 +1153,6 @@ class Data2VecTextLMHead(nn.Module):
     DATA2VECTEXT_START_DOCSTRING,
 )
 class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1273,8 +1250,6 @@ class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
     DATA2VECTEXT_START_DOCSTRING,
 )
 class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
@@ -1369,9 +1344,6 @@ class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
     DATA2VECTEXT_START_DOCSTRING,
 )
 class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1478,9 +1450,6 @@ class Data2VecTextClassificationHead(nn.Module):
     DATA2VECTEXT_START_DOCSTRING,
 )
 class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -470,7 +470,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
         relative_position_index[0:, 0] = self.num_relative_distance - 2
         relative_position_index[0, 0] = self.num_relative_distance - 1

-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

     def forward(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
...
@@ -764,7 +764,9 @@ class DebertaEmbeddings(nn.Module):
         self.config = config

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
         if input_ids is not None:
@@ -821,7 +823,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
     config_class = DebertaConfig
     base_model_prefix = "deberta"
-    _keys_to_ignore_on_load_missing = ["position_ids"]
     _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
     supports_gradient_checkpointing = True
@@ -1020,8 +1021,6 @@ class DebertaModel(DebertaPreTrainedModel):
 @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
 class DebertaForMaskedLM(DebertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
@@ -1277,8 +1276,6 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
     DEBERTA_START_DOCSTRING,
 )
 class DebertaForTokenClassification(DebertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1352,8 +1349,6 @@ class DebertaForTokenClassification(DebertaPreTrainedModel):
     DEBERTA_START_DOCSTRING,
 )
 class DebertaForQuestionAnswering(DebertaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -862,7 +862,9 @@ class DebertaV2Embeddings(nn.Module):
         self.config = config

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
         if input_ids is not None:
@@ -920,7 +922,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
     config_class = DebertaV2Config
     base_model_prefix = "deberta"
-    _keys_to_ignore_on_load_missing = ["position_ids"]
     _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
     supports_gradient_checkpointing = True
@@ -1120,8 +1121,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
 @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
 class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
@@ -1380,8 +1379,6 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
 )
 # Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
 class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1455,8 +1452,6 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
     DEBERTA_START_DOCSTRING,
 )
 class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -476,8 +476,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
 class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
-
     def __init__(self, config):
         super().__init__(config)
@@ -747,8 +745,6 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "decision_transformer"
     main_input_name = "states"
     supports_gradient_checkpointing = False
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -1823,7 +1823,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
 )
 class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
     _tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]

     def __init__(self, config: DeformableDetrConfig):
...
@@ -1775,7 +1775,6 @@ class DetaModel(DetaPreTrainedModel):
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
     _tied_weights_keys = [r"bbox_embed\.\d+"]

     # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
...
@@ -595,7 +595,6 @@ class DistilBertModel(DistilBertPreTrainedModel):
     DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForMaskedLM(DistilBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
     _tied_weights_keys = ["vocab_projector.weight"]

     def __init__(self, config: PretrainedConfig):
...
@@ -296,8 +296,6 @@ class DPRPretrainedContextEncoder(DPRPreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "ctx_encoder"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]


 class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
@@ -309,8 +307,6 @@ class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]


 class DPRPretrainedReader(DPRPreTrainedModel):
@@ -322,7 +318,6 @@ class DPRPretrainedReader(DPRPreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "span_predictor"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]


 ###############
...
@@ -161,7 +161,9 @@ class ElectraEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
@@ -672,8 +674,6 @@ class ElectraPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-    _keys_to_ignore_on_load_unexpected = [r"electra.embeddings_project.weight", r"electra.embeddings_project.bias"]

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -1166,7 +1166,6 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
     ELECTRA_START_DOCSTRING,
 )
 class ElectraForMaskedLM(ElectraPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
     _tied_weights_keys = ["generator_lm_head.weight"]

     def __init__(self, config):
@@ -1534,7 +1533,6 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
     """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
 )
 class ElectraForCausalLM(ElectraPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
     _tied_weights_keys = ["generator_lm_head.weight"]

     def __init__(self, config):
...
@@ -89,7 +89,9 @@ class ErnieEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -661,7 +663,6 @@ class ErniePreTrainedModel(PreTrainedModel):
     config_class = ErnieConfig
     base_model_prefix = "ernie"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -983,7 +984,6 @@ class ErnieModel(ErniePreTrainedModel):
     ERNIE_START_DOCSTRING,
 )
 class ErnieForPreTraining(ErniePreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
@@ -1095,8 +1095,6 @@ class ErnieForPreTraining(ErniePreTrainedModel):
     """Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING
 )
 class ErnieForCausalLM(ErniePreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
@@ -1243,8 +1241,6 @@ class ErnieForCausalLM(ErniePreTrainedModel):
 @add_start_docstrings("""Ernie Model with a `language modeling` head on top.""", ERNIE_START_DOCSTRING)
 class ErnieForMaskedLM(ErniePreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
@@ -1665,8 +1661,6 @@ class ErnieForMultipleChoice(ErniePreTrainedModel):
     ERNIE_START_DOCSTRING,
 )
 class ErnieForTokenClassification(ErniePreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->Ernie,bert->ernie
     def __init__(self, config):
         super().__init__(config)
@@ -1746,8 +1740,6 @@ class ErnieForTokenClassification(ErniePreTrainedModel):
     ERNIE_START_DOCSTRING,
 )
 class ErnieForQuestionAnswering(ErniePreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->Ernie,bert->ernie
     def __init__(self, config):
         super().__init__(config)
...
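
As a quick check of the resulting behavior (a sketch against the public transformers API, not part of the commit itself), the position_ids buffer of an affected model should no longer appear in its serialized weights:

from transformers import CLIPConfig, CLIPModel

# Randomly initialized CLIP model built from its default config; no checkpoint download needed.
model = CLIPModel(CLIPConfig())

# The position_ids buffers are now non-persistent, so they are absent from state_dict();
# checkpoints are no longer expected to contain them, which is why the
# _keys_to_ignore_on_load_missing entries above could be dropped.
print(any(key.endswith("position_ids") for key in model.state_dict()))  # expected: False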