"docker/vscode:/vscode.git/clone" did not exist on "0665355a7cc45f70ac90499a536777b492a2927d"
Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
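
Editor's note (not part of the diff below): a minimal, self-contained PyTorch sketch of why switching these buffers to `persistent=False` lets the `_keys_to_ignore_on_load_missing` entries be deleted. `ToyTextEmbeddings` and its buffer names are hypothetical, chosen only for illustration.

```python
import torch
from torch import nn


class ToyTextEmbeddings(nn.Module):
    """Hypothetical module used only to illustrate the buffer change."""

    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        # Default (persistent) buffer: saved into state_dict, so checkpoints
        # created without it are reported as having missing keys on load.
        self.register_buffer(
            "persistent_ids", torch.arange(max_position_embeddings).expand((1, -1))
        )
        # Non-persistent buffer: rebuilt here in __init__ and never serialized,
        # so it cannot be a missing key and needs no ignore-list entry.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )


module = ToyTextEmbeddings()
print("persistent_ids" in module.state_dict())  # True
print("position_ids" in module.state_dict())    # False

# An older checkpoint may still contain the buffer from when it was persistent.
old_checkpoint = {
    "persistent_ids": torch.arange(512).expand((1, -1)),
    "position_ids": torch.arange(512).expand((1, -1)),  # stale serialized copy
}
result = module.load_state_dict(old_checkpoint, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['position_ids']
```

With `persistent=False`, the buffer never reaches the checkpoint, so it cannot show up as a missing key; when an older checkpoint still carries it, it surfaces only as an unexpected key, which the loading changes in this PR ("Always ignore nonpersistent buffers if in state_dict") filter out.
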
......@@ -121,7 +121,9 @@ class ChineseCLIPTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -190,7 +192,7 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......@@ -689,7 +691,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
config_class = ChineseCLIPConfig
base_model_prefix = "chinese_clip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -1166,7 +1166,9 @@ class ClapTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=True
)
......@@ -1677,7 +1679,6 @@ class ClapPreTrainedModel(PreTrainedModel):
config_class = ClapConfig
base_model_prefix = "clap"
supports_gradient_checkpointing = False
_keys_to_ignore_on_load_missing = [r"position_ids", r"logit_scale_a", r"logit_scale_t"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -1781,7 +1782,6 @@ class ClapTextModel(ClapPreTrainedModel):
"""
config_class = ClapTextConfig
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->ClapText
def __init__(self, config, add_pooling_layer=True):
......@@ -1936,7 +1936,6 @@ class ClapTextModel(ClapPreTrainedModel):
@add_start_docstrings(CLAP_START_DOCSTRING)
class ClapModel(ClapPreTrainedModel):
config_class = ClapConfig
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config: ClapConfig):
super().__init__(config)
......
......@@ -188,7 +188,7 @@ class CLIPVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......@@ -210,7 +210,9 @@ class CLIPTextEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -410,7 +412,6 @@ class CLIPPreTrainedModel(PreTrainedModel):
config_class = CLIPConfig
base_model_prefix = "clip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -181,7 +181,7 @@ class CLIPSegVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def interpolate_position_embeddings(self, new_size):
if len(new_size) != 2:
......@@ -230,7 +230,9 @@ class CLIPSegTextEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -433,7 +435,6 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
config_class = CLIPSegConfig
base_model_prefix = "clip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -83,6 +83,7 @@ class CodeGenAttention(nn.Module):
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
......@@ -600,7 +601,6 @@ class CodeGenModel(CodeGenPreTrainedModel):
CODEGEN_START_DOCSTRING,
)
class CodeGenForCausalLM(CodeGenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -191,7 +191,9 @@ class ConvBertEmbeddings(nn.Module):
self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -245,8 +247,6 @@ class ConvBertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_convbert
base_model_prefix = "convbert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"convbert.embeddings_project.weight", r"convbert.embeddings_project.bias"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -765,8 +765,6 @@ CONVBERT_INPUTS_DOCSTRING = r"""
CONVBERT_START_DOCSTRING,
)
class ConvBertModel(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
self.embeddings = ConvBertEmbeddings(config)
......@@ -880,7 +878,6 @@ class ConvBertGeneratorPredictions(nn.Module):
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]
_tied_weights_keys = ["generator.lm_head.weight"]
def __init__(self, config):
......@@ -992,8 +989,6 @@ class ConvBertClassificationHead(nn.Module):
CONVBERT_START_DOCSTRING,
)
class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1089,8 +1084,6 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
CONVBERT_START_DOCSTRING,
)
class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1184,8 +1177,6 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
CONVBERT_START_DOCSTRING,
)
class ConvBertForTokenClassification(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1267,8 +1258,6 @@ class ConvBertForTokenClassification(ConvBertPreTrainedModel):
CONVBERT_START_DOCSTRING,
)
class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
......
......@@ -537,7 +537,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
config_class = CpmAntConfig
base_model_prefix = "cpmant"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -749,7 +748,6 @@ class CpmAntModel(CpmAntPreTrainedModel):
CPMANT_START_DOCSTRING,
)
class CpmAntForCausalLM(CpmAntPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: CpmAntConfig):
......
......@@ -509,7 +509,6 @@ class CTRLModel(CTRLPreTrainedModel):
CTRL_START_DOCSTRING,
)
class CTRLLMHeadModel(CTRLPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -689,7 +689,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
config_class = Data2VecAudioConfig
base_model_prefix = "data2vec_audio"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
......
......@@ -80,7 +80,9 @@ class Data2VecTextForTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -615,15 +617,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
if isinstance(module, Data2VecTextEncoder):
module.gradient_checkpointing = value
def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]
DATA2VECTEXT_START_DOCSTRING = r"""
Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
......@@ -714,8 +707,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
......@@ -883,9 +874,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
"""Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING
)
class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -897,9 +885,6 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
self.lm_head = Data2VecTextLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1038,9 +1023,6 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
@add_start_docstrings("""data2vec Model with a `language modeling` head on top.""", DATA2VECTEXT_START_DOCSTRING)
class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -1055,9 +1037,6 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
self.lm_head = Data2VecTextLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1174,8 +1153,6 @@ class Data2VecTextLMHead(nn.Module):
DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1273,8 +1250,6 @@ class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1369,9 +1344,6 @@ class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1478,9 +1450,6 @@ class Data2VecTextClassificationHead(nn.Module):
DATA2VECTEXT_START_DOCSTRING,
)
class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -470,7 +470,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
def forward(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
......
......@@ -764,7 +764,9 @@ class DebertaEmbeddings(nn.Module):
self.config = config
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
if input_ids is not None:
......@@ -821,7 +823,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
config_class = DebertaConfig
base_model_prefix = "deberta"
_keys_to_ignore_on_load_missing = ["position_ids"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
supports_gradient_checkpointing = True
......@@ -1020,8 +1021,6 @@ class DebertaModel(DebertaPreTrainedModel):
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class DebertaForMaskedLM(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -1277,8 +1276,6 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
DEBERTA_START_DOCSTRING,
)
class DebertaForTokenClassification(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1352,8 +1349,6 @@ class DebertaForTokenClassification(DebertaPreTrainedModel):
DEBERTA_START_DOCSTRING,
)
class DebertaForQuestionAnswering(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -862,7 +862,9 @@ class DebertaV2Embeddings(nn.Module):
self.config = config
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
if input_ids is not None:
......@@ -920,7 +922,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
config_class = DebertaV2Config
base_model_prefix = "deberta"
_keys_to_ignore_on_load_missing = ["position_ids"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
supports_gradient_checkpointing = True
......@@ -1120,8 +1121,6 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -1380,8 +1379,6 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
)
# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1455,8 +1452,6 @@ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
DEBERTA_START_DOCSTRING,
)
class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -476,8 +476,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
......@@ -747,8 +745,6 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel):
base_model_prefix = "decision_transformer"
main_input_name = "states"
supports_gradient_checkpointing = False
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -1823,7 +1823,6 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
)
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
_tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
def __init__(self, config: DeformableDetrConfig):
......
......@@ -1775,7 +1775,6 @@ class DetaModel(DetaPreTrainedModel):
)
class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
_tied_weights_keys = [r"bbox_embed\.\d+"]
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
......
......@@ -595,7 +595,6 @@ class DistilBertModel(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
_tied_weights_keys = ["vocab_projector.weight"]
def __init__(self, config: PretrainedConfig):
......
......@@ -296,8 +296,6 @@ class DPRPretrainedContextEncoder(DPRPreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "ctx_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
......@@ -309,8 +307,6 @@ class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "question_encoder"
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
class DPRPretrainedReader(DPRPreTrainedModel):
......@@ -322,7 +318,6 @@ class DPRPretrainedReader(DPRPreTrainedModel):
config_class = DPRConfig
load_tf_weights = None
base_model_prefix = "span_predictor"
_keys_to_ignore_on_load_missing = [r"position_ids"]
###############
......
......@@ -161,7 +161,9 @@ class ElectraEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
......@@ -672,8 +674,6 @@ class ElectraPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_electra
base_model_prefix = "electra"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"electra.embeddings_project.weight", r"electra.embeddings_project.bias"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
......@@ -1166,7 +1166,6 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
ELECTRA_START_DOCSTRING,
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
_tied_weights_keys = ["generator_lm_head.weight"]
def __init__(self, config):
......@@ -1534,7 +1533,6 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
"""ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
)
class ElectraForCausalLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
_tied_weights_keys = ["generator_lm_head.weight"]
def __init__(self, config):
......
......@@ -89,7 +89,9 @@ class ErnieEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -661,7 +663,6 @@ class ErniePreTrainedModel(PreTrainedModel):
config_class = ErnieConfig
base_model_prefix = "ernie"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -983,7 +984,6 @@ class ErnieModel(ErniePreTrainedModel):
ERNIE_START_DOCSTRING,
)
class ErnieForPreTraining(ErniePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
......@@ -1095,8 +1095,6 @@ class ErnieForPreTraining(ErniePreTrainedModel):
"""Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING
)
class ErnieForCausalLM(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
......@@ -1243,8 +1241,6 @@ class ErnieForCausalLM(ErniePreTrainedModel):
@add_start_docstrings("""Ernie Model with a `language modeling` head on top.""", ERNIE_START_DOCSTRING)
class ErnieForMaskedLM(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
......@@ -1665,8 +1661,6 @@ class ErnieForMultipleChoice(ErniePreTrainedModel):
ERNIE_START_DOCSTRING,
)
class ErnieForTokenClassification(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
# Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->Ernie,bert->ernie
def __init__(self, config):
super().__init__(config)
......@@ -1746,8 +1740,6 @@ class ErnieForTokenClassification(ErniePreTrainedModel):
ERNIE_START_DOCSTRING,
)
class ErnieForQuestionAnswering(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->Ernie,bert->ernie
def __init__(self, config):
super().__init__(config)
......