Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
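The recurring pattern in the diffs below is that buffers such as `position_ids` are now registered with `persistent=False`, so they never enter the `state_dict`, and the per-model `_keys_to_ignore_on_load_missing` entries that existed only to silence warnings about them can be dropped. A minimal sketch of that behavior in plain PyTorch (the module here is illustrative, not one of the classes touched by this commit):

import torch
from torch import nn

class Embeddings(nn.Module):
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        # Non-persistent buffers are recreated in __init__ and excluded from the
        # state_dict, so checkpoints neither contain them nor report them as missing.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

module = Embeddings()
print("position_ids" in module.state_dict())  # False: the buffer is not serialized
module.load_state_dict({}, strict=True)       # no error: nothing persistent to load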
......@@ -249,7 +249,9 @@ class TextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -886,7 +888,6 @@ class ViltPooler(nn.Module):
VILT_START_DOCSTRING,
)
class ViltForMaskedLM(ViltPreTrainedModel):
_keys_to_ignore_on_load_missing = ["mlm_score.decoder.bias"]
_tied_weights_keys = ["mlm_score.decoder.weight", "mlm_score.decoder.bias"]
def __init__(self, config):
......@@ -1419,8 +1420,6 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
VILT_START_DOCSTRING,
)
class ViltForTokenClassification(ViltPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
......
......@@ -78,7 +78,9 @@ class VisualBertEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
# For Visual Features
# Token type and position embedding for image features
......@@ -531,7 +533,6 @@ class VisualBertPreTrainedModel(PreTrainedModel):
config_class = VisualBertConfig
base_model_prefix = "visual_bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -871,7 +872,6 @@ class VisualBertModel(VisualBertPreTrainedModel):
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForPreTraining(VisualBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -1462,7 +1462,6 @@ class VisualBertRegionToPhraseAttention(nn.Module):
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.bias"]
def __init__(self, config):
......
......@@ -1089,7 +1089,6 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
......
......@@ -1087,7 +1087,6 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
config_class = Wav2Vec2ConformerConfig
base_model_prefix = "wav2vec2_conformer"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
......
......@@ -974,7 +974,6 @@ class WavLMPreTrainedModel(PreTrainedModel):
config_class = WavLMConfig
base_model_prefix = "wavlm"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
......
......@@ -1225,8 +1225,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
WHISPER_START_DOCSTRING,
)
class WhisperModel(WhisperPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"proj_out.weight"]
def __init__(self, config: WhisperConfig):
super().__init__(config)
......@@ -1396,14 +1394,6 @@ class WhisperModel(WhisperPreTrainedModel):
)
class WhisperForConditionalGeneration(WhisperPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"encoder.version",
r"decoder.version",
r"proj_out.weight",
]
_keys_to_ignore_on_save = [
r"proj_out.weight",
]
_tied_weights_keys = ["proj_out.weight"]
def __init__(self, config: WhisperConfig):
......
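For genuinely shared parameters (for example Whisper's `proj_out.weight`, which reuses the decoder token embedding matrix), the classes keep a single `_tied_weights_keys` declaration instead of listing the same keys in `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save`. A hedged sketch of why a tied weight needs no separate checkpoint entry (plain PyTorch; the sizes are made up):

import torch
from torch import nn

embed = nn.Embedding(100, 16)
proj_out = nn.Linear(16, 100, bias=False)
proj_out.weight = embed.weight  # tie: both modules now hold the same Parameter

model = nn.ModuleDict({"embed": embed, "proj_out": proj_out})
state = model.state_dict()
# Both keys appear in the state_dict, but they point at the same storage, so the
# tied key can always be reconstructed by re-tying after load; `_tied_weights_keys`
# is how a model class tells the loading code which keys those are.
print(state["embed.weight"].data_ptr() == state["proj_out.weight"].data_ptr())  # True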
......@@ -139,7 +139,7 @@ class XCLIPVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......@@ -162,7 +162,9 @@ class XCLIPTextEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -481,7 +483,6 @@ class XCLIPPreTrainedModel(PreTrainedModel):
config_class = XCLIPConfig
base_model_prefix = "x_clip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -749,14 +749,6 @@ class XGLMModel(XGLMPreTrainedModel):
)
class XGLMForCausalLM(XGLMPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"model.embed_positions.weights",
r"embed_positions.weights",
r"lm_head.weight",
]
_keys_to_ignore_on_save = [
r"model.embed_positions.weights",
]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -391,8 +391,6 @@ XLM_INPUTS_DOCSTRING = r"""
XLM_START_DOCSTRING,
)
class XLMModel(XLMPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -461,7 +459,9 @@ class XLMModel(XLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def get_input_embeddings(self):
return self.embeddings
......@@ -670,7 +670,6 @@ class XLMPredLayer(nn.Module):
XLM_START_DOCSTRING,
)
class XLMWithLMHeadModel(XLMPreTrainedModel):
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
_tied_weights_keys = ["pred_layer.proj.weight"]
def __init__(self, config):
......
......@@ -1768,7 +1768,6 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
_tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
def __init__(self, config: XLMProphetNetConfig):
......@@ -1899,11 +1898,6 @@ class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"decoder.word_embeddings.weight",
"encoder.word_embeddings.weight",
"lm_head.weight",
]
_tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
def __init__(self, config: XLMProphetNetConfig):
......@@ -2119,7 +2113,6 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: XLMProphetNetConfig):
......
......@@ -81,7 +81,9 @@ class XLMRobertaEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -616,15 +618,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
if isinstance(module, XLMRobertaEncoder):
module.gradient_checkpointing = value
def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]
XLM_ROBERTA_START_DOCSTRING = r"""
......@@ -713,8 +706,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
......@@ -885,9 +876,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -899,9 +887,6 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
self.lm_head = XLMRobertaLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1044,9 +1029,6 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -1061,9 +1043,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
self.lm_head = XLMRobertaLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1183,8 +1162,6 @@ class XLMRobertaLMHead(nn.Module):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1285,8 +1262,6 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaForMultipleChoice(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1382,9 +1357,6 @@ class XLMRobertaForMultipleChoice(XLMRobertaPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaForTokenClassification(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1494,9 +1466,6 @@ class XLMRobertaClassificationHead(nn.Module):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
class XLMRobertaForQuestionAnswering(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
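With the `update_keys_to_ignore` helper and the per-class ignore lists removed, whether the LM head is actually tied is governed only by `config.tie_word_embeddings` and handled by the generic loading code. A hedged sketch of the two cases on a tiny, randomly initialized model (the config sizes are arbitrary and only for illustration):

from transformers import RobertaConfig, RobertaForMaskedLM

def head_is_tied(tie):
    config = RobertaConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
        intermediate_size=64, max_position_embeddings=64, tie_word_embeddings=tie,
    )
    model = RobertaForMaskedLM(config)
    # Tied weights share storage; untied ones are independent parameters that must
    # be saved and loaded like any other weight -- no ignore list involved.
    return (
        model.get_input_embeddings().weight.data_ptr()
        == model.get_output_embeddings().weight.data_ptr()
    )

print(head_is_tied(True), head_is_tied(False))  # True False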
......@@ -73,7 +73,9 @@ class XLMRobertaXLEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -599,15 +601,6 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]
XLM_ROBERTA_XL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
......@@ -679,8 +672,6 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRobertaXL
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
......@@ -850,9 +841,6 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
XLM_ROBERTA_XL_START_DOCSTRING,
)
class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -864,9 +852,6 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False)
self.lm_head = XLMRobertaXLLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
self.init_weights()
def get_output_embeddings(self):
......@@ -1001,9 +986,6 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
"""XLM-RoBERTa-xlarge Model with a `language modeling` head on top.""", XLM_ROBERTA_XL_START_DOCSTRING
)
class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -1018,9 +1000,6 @@ class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel):
self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False)
self.lm_head = XLMRobertaXLLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
self.init_weights()
def get_output_embeddings(self):
......@@ -1129,8 +1108,6 @@ class XLMRobertaXLLMHead(nn.Module):
XLM_ROBERTA_XL_START_DOCSTRING,
)
class XLMRobertaXLForSequenceClassification(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1225,8 +1202,6 @@ class XLMRobertaXLForSequenceClassification(XLMRobertaXLPreTrainedModel):
XLM_ROBERTA_XL_START_DOCSTRING,
)
class XLMRobertaXLForMultipleChoice(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1318,9 +1293,6 @@ class XLMRobertaXLForMultipleChoice(XLMRobertaXLPreTrainedModel):
XLM_ROBERTA_XL_START_DOCSTRING,
)
class XLMRobertaXLForTokenClassification(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1432,9 +1404,6 @@ class XLMRobertaXLClassificationHead(nn.Module):
XLM_ROBERTA_XL_START_DOCSTRING,
)
class XLMRobertaXLForQuestionAnswering(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -1292,7 +1292,6 @@ class XLNetModel(XLNetPreTrainedModel):
XLNET_START_DOCSTRING,
)
class XLNetLMHeadModel(XLNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_loss.weight"]
_tied_weights_keys = ["lm_loss.weight"]
def __init__(self, config):
......
......@@ -74,7 +74,9 @@ class XmodEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -682,16 +684,6 @@ class XmodPreTrainedModel(PreTrainedModel):
if isinstance(module, XmodEncoder):
module.gradient_checkpointing = value
# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel.update_keys_to_ignore
def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]
def set_default_language(self, language: str):
"""
Set the default language code for the model. This is used when the language is not specified in the input.
......@@ -811,8 +803,6 @@ class XmodModel(XmodPreTrainedModel):
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Xmod
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
......@@ -989,9 +979,6 @@ class XmodModel(XmodPreTrainedModel):
XMOD_START_DOCSTRING,
)
class XmodForCausalLM(XmodPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod
......@@ -1004,9 +991,6 @@ class XmodForCausalLM(XmodPreTrainedModel):
self.roberta = XmodModel(config, add_pooling_layer=False)
self.lm_head = XmodLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1152,9 +1136,6 @@ class XmodForCausalLM(XmodPreTrainedModel):
XMOD_START_DOCSTRING,
)
class XmodForMaskedLM(XmodPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with Roberta->Xmod
......@@ -1170,9 +1151,6 @@ class XmodForMaskedLM(XmodPreTrainedModel):
self.roberta = XmodModel(config, add_pooling_layer=False)
self.lm_head = XmodLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1285,8 +1263,6 @@ class XmodLMHead(nn.Module):
XMOD_START_DOCSTRING,
)
class XmodForSequenceClassification(XmodPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Xmod
def __init__(self, config):
super().__init__(config)
......@@ -1380,8 +1356,6 @@ class XmodForSequenceClassification(XmodPreTrainedModel):
XMOD_START_DOCSTRING,
)
class XmodForMultipleChoice(XmodPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice.__init__ with Roberta->Xmod
def __init__(self, config):
super().__init__(config)
......@@ -1471,9 +1445,6 @@ class XmodForMultipleChoice(XmodPreTrainedModel):
XMOD_START_DOCSTRING,
)
class XmodForTokenClassification(XmodPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Xmod
def __init__(self, config):
super().__init__(config)
......@@ -1576,9 +1547,6 @@ class XmodClassificationHead(nn.Module):
XMOD_START_DOCSTRING,
)
class XmodForQuestionAnswering(XmodPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Xmod
def __init__(self, config):
super().__init__(config)
......
......@@ -252,7 +252,9 @@ class YosoEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"token_type_ids",
......@@ -649,7 +651,6 @@ class YosoPreTrainedModel(PreTrainedModel):
config_class = YosoConfig
base_model_prefix = "yoso"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -849,11 +850,6 @@ class YosoModel(YosoPreTrainedModel):
@add_start_docstrings("""YOSO Model with a `language modeling` head on top.""", YOSO_START_DOCSTRING)
class YosoForMaskedLM(YosoPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"cls.predictions.decoder.bias",
"cls.predictions.decoder.weight",
"embeddings.position_ids",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......
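YOSO's old ignore list also covered `embeddings.position_ids`; with the buffer registered as non-persistent it never reaches the saved file, so a save/load round-trip is clean without any special-casing. A small self-contained sketch of that round-trip (plain PyTorch, using the same `+ 2` offset as the YOSO embeddings):

import os
import tempfile
import torch
from torch import nn

class TinyEmbeddings(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(10, 4)
        self.register_buffer("position_ids", torch.arange(8).expand((1, -1)) + 2, persistent=False)

src, dst = TinyEmbeddings(), TinyEmbeddings()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.bin")
    torch.save(src.state_dict(), path)
    saved = torch.load(path)
    print("position_ids" in saved)     # False: nothing to ignore on save
    missing, unexpected = dst.load_state_dict(saved)
    print(missing, unexpected)         # [] []: nothing to ignore on load
print(dst.position_ids.shape)          # torch.Size([1, 8]); recreated in __init__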
......@@ -15,7 +15,6 @@
import unittest
from copy import deepcopy
from transformers import RobertaConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
......@@ -579,23 +578,3 @@ class RobertaModelIntegrationTest(TestCasePlus):
# expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
# XXX: this might be a candidate for common tests if we have many of those
def test_lm_head_ignore_keys(self):
keys_to_ignore_on_save_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
keys_to_ignore_on_save_untied = [r"lm_head.decoder.bias"]
config = RobertaConfig.from_pretrained(ROBERTA_TINY)
config_tied = deepcopy(config)
config_tied.tie_word_embeddings = True
config_untied = deepcopy(config)
config_untied.tie_word_embeddings = False
for cls in [RobertaForMaskedLM, RobertaForCausalLM]:
model = cls(config_tied)
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_tied, cls)
# the keys should be different when embeddings aren't tied
model = cls(config_untied)
self.assertEqual(model._keys_to_ignore_on_save, keys_to_ignore_on_save_untied, cls)
# test that saving works with updated ignore keys - just testing that it doesn't fail
model.save_pretrained(self.get_auto_remove_tmp_dir())
......@@ -1562,7 +1562,7 @@ class ModelTesterMixin:
@require_safetensors
def test_can_use_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
......@@ -1579,6 +1579,8 @@ class ModelTesterMixin:
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking there was no complaint of missing weights
self.assertEqual(infos["missing_keys"], [])
# Checking the tensor sharing are correct
ptrs = defaultdict(list)
......@@ -1595,6 +1597,25 @@ class ModelTesterMixin:
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)
def test_load_save_without_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = False
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as d:
model.save_pretrained(d)
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# Checking the state dicts are correct
reloaded_state = model_reloaded.state_dict()
for k, v in model.state_dict().items():
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking there was no complaint of missing weights
self.assertEqual(infos["missing_keys"], [])
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = True
......@@ -1620,55 +1641,72 @@ class ModelTesterMixin:
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(tied_params, [])
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
def test_tied_model_weights_key_ignore(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def test_model_weights_reload_no_missing_tied_weights(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
model_tied.save_pretrained(d)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
# We are nuking ALL weights on file, so every parameter should
# yell on load. We're going to detect if we yell too much, or too little.
with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f:
torch.save({}, f)
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# ! Actually we could use `state_dict()` and check iteratively the tensors which are the same (for instance using `tensor.data_ptr()`). to detect the duplicates.
# ```python
# model = GPT2LMHeadModel.from_pretrained("gpt2")
# "lm_head.weight" in model.state_dict().keys() # True
# "lm_head.weight" in model.named_parameters() # False
# In [6]: model.lm_head.weight.data_ptr()
# Out[6]: 139901378371648
# In [9]: model.transformer.wte.weight.data_ptr()
# Out[9]: 139901378371648 # Same PTR, it's the same DATA ! we would need to check for stride too to be 100% accurate.
# ```
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
prefix = f"{model_reloaded.base_model_prefix}."
params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers()))
# param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
missing_keys = set(infos["missing_keys"])
extra_missing = missing_keys - param_names
# missed_missing = param_names - missing_keys
# Remove tied weights from extra missing: they are normally not warned as missing if their tied
# counterpart is present but here there are no weights at all so we do get the warning.
ptrs = collections.defaultdict(list)
for name, tensor in model_reloaded.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params:
group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
# We remove the group from extra_missing if not all weights from group are in it
if len(group - extra_missing) > 0:
extra_missing = extra_missing - set(group)
self.assertEqual(
extra_missing,
set(),
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}",
f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}. "
f"For debugging, tied parameters are {tied_params}",
)
# self.assertEqual(
# missed_missing,
# set(),
# f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
# " parameters",
# )
missed_missing = param_names - missing_keys
# Remove nonpersistent buffers from missed_missing
buffers = [n for n, _ in model_reloaded.named_buffers()]
nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
nonpersistent_buffers = {
k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
}
missed_missing = missed_missing - nonpersistent_buffers
if model_reloaded._keys_to_ignore_on_load_missing is None:
expected_missing = set()
else:
expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
self.assertEqual(
missed_missing,
expected_missing,
f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
" parameters. If they are non persistent buffers make sure to instantiate them with"
" `persistent=False`",
)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
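The reworked common tests lean on `output_loading_info=True`, which makes `from_pretrained` return a second value containing, among other things, the `missing_keys` and `unexpected_keys` lists; after this change a clean save/reload round-trip is expected to report no missing keys. A hedged sketch of that pattern on a tiny random model (config sizes are arbitrary):

import tempfile
from transformers import BertConfig, BertModel

config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
model = BertModel(config)
with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)
    reloaded, infos = BertModel.from_pretrained(tmp, output_loading_info=True)
    # Non-persistent buffers are not expected in the checkpoint, and tied weights
    # are accounted for via `_tied_weights_keys`, so nothing should be missing.
    print(infos["missing_keys"])  # []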
......@@ -500,8 +500,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(os.path.isfile(weights_index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
for i in range(1, 6):
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
for i in range(1, 5):
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["bin"])
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
......@@ -546,8 +546,8 @@ class ModelUtilsTest(TestCasePlus):
self.assertTrue(os.path.isfile(weights_index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
for i in range(1, 6):
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
for i in range(1, 5):
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00005"] + ["safetensors"])
weights_name_file = os.path.join(tmp_dir, weights_name)
self.assertTrue(os.path.isfile(weights_name_file))
......
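The final hunks update the expected shard count (from six to five) when a checkpoint is saved with a `variant`; the file names splice the variant in front of the shard counter. A small sketch that reproduces the naming scheme the assertions check (the constants mirror `WEIGHTS_NAME` and `SAFE_WEIGHTS_NAME` from `transformers.utils`; the shard count of 5 is just the value the updated test uses):

WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_NAME = "model.safetensors"

def variant_shard_names(base_name, variant, num_shards):
    # "pytorch_model.bin" + variant "v2" -> "pytorch_model.v2-00001-of-00005.bin", ...
    stem, ext = base_name.rsplit(".", 1)
    return [f"{stem}.{variant}-{i:05d}-of-{num_shards:05d}.{ext}" for i in range(1, num_shards + 1)]

print(variant_shard_names(WEIGHTS_NAME, "v2", 5)[0])        # pytorch_model.v2-00001-of-00005.bin
print(variant_shard_names(SAFE_WEIGHTS_NAME, "v2", 5)[-1])  # model.v2-00005-of-00005.safetensors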