Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
......@@ -320,8 +320,9 @@ def shard_checkpoint(
weight_size = weight.numel() * dtype_byte_size(weight.dtype)
# If this weight is going to tip up over the maximal size, we split.
if last_block_size + weight_size > max_shard_size:
# If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
# weight in the current shard.
if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
sharded_state_dicts.append({})
last_block_size = 0
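This guard only matters in the degenerate case where a single weight is larger than max_shard_size; previously the loop opened a fresh shard even when the current one was still empty. A minimal, self-contained sketch of the new behaviour (weight sizes are passed in directly instead of being computed from tensors, so this is not the real shard_checkpoint):

# Toy version of the sharding loop above; sizes are given directly.
def split_into_shards(weight_sizes, max_shard_size):
    sharded_state_dicts = [{}]
    last_block_size = 0
    for name, weight_size in weight_sizes.items():
        # Only start a new shard if the current one already holds at least one
        # weight; a single weight bigger than max_shard_size no longer leaves
        # an empty shard behind.
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
            sharded_state_dicts.append({})
            last_block_size = 0
        sharded_state_dicts[-1][name] = weight_size
        last_block_size += weight_size
    return sharded_state_dicts

print(split_into_shards({"big": 12, "small": 3}, max_shard_size=10))
# [{'big': 12}, {'small': 3}]  -- instead of [{}, {'big': 12}, {'small': 3}]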
......@@ -3044,15 +3045,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
if remove_prefix_from_model:
model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
elif add_prefix_to_model:
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = list(unexpected_keys - model_buffers)
if is_accelerate_available():
model.tie_weights()
tied_params = find_tied_parameters(model)
else:
tied_params = []
model.tie_weights()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params:
if remove_prefix_from_model:
group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
elif add_prefix_to_model:
group = [".".join([prefix, key]) for key in group]
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
missing_keys = [k for k in missing_keys if k not in missing_in_group]
......
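The replacement for the find_tied_parameters call in the hunk above detects tied or otherwise shared parameters directly, by grouping every state dict entry that points at the same underlying storage. A standalone sketch of that grouping, using Tensor.data_ptr() as a stand-in for the library's id_tensor_storage helper (the real code also special-cases tensors on the meta device, as the hunk shows):

import collections
import torch
from torch import nn

class TiedLM(nn.Module):
    # Hypothetical toy model with weight tying, for illustration only.
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two weights

model = TiedLM()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    # Tensors that share storage map to the same pointer.
    ptrs[tensor.data_ptr()].append(name)

tied_groups = [names for names in ptrs.values() if len(names) > 1]
print(tied_groups)  # [['embed.weight', 'lm_head.weight']]

Any missing key that belongs to such a group is then dropped from missing_keys as long as at least one other member of the group was loaded, which is what the loop over tied_params above does.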
......@@ -208,7 +208,9 @@ class AlbertEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
......@@ -507,7 +509,6 @@ class AlbertPreTrainedModel(PreTrainedModel):
config_class = AlbertConfig
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights."""
......@@ -760,11 +761,6 @@ class AlbertModel(AlbertPreTrainedModel):
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]
def __init__(self, config: AlbertConfig):
super().__init__(config)
......@@ -912,13 +908,7 @@ class AlbertSOPHead(nn.Module):
ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]
def __init__(self, config):
super().__init__(config)
......@@ -1133,8 +1123,6 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1218,8 +1206,6 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -687,7 +687,9 @@ class AlignTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -1176,7 +1178,6 @@ class AlignPreTrainedModel(PreTrainedModel):
config_class = AlignConfig
base_model_prefix = "align"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -216,7 +216,9 @@ class AltRobertaEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -1016,7 +1018,7 @@ class AltCLIPVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......@@ -1038,7 +1040,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
config_class = AltCLIPConfig
base_model_prefix = "altclip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -506,7 +506,7 @@ class BartPretrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
_keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
......@@ -1170,7 +1170,6 @@ class BartDecoder(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartModel(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig):
......@@ -1300,12 +1299,7 @@ class BartModel(BartPretrainedModel):
class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
"final_logits_bias",
"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: BartConfig):
super().__init__(config)
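Note that final_logits_bias is the one key deliberately left in _keys_to_ignore_on_load_missing: unlike the position ids, it is registered as a regular, persistent buffer of zeros, so it is part of the expected state dict, yet a checkpoint that omits it should not trigger a missing-key warning. A rough sketch of that distinction with a toy module (the real registration happens in BartForConditionalGeneration.__init__):

import torch
from torch import nn

class ToyGenerationHead(nn.Module):
    # Toy stand-in for a Bart-style conditional-generation head.
    def __init__(self, vocab_size=32, d_model=16):
        super().__init__()
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # No persistent=False here, so the buffer IS part of the state dict;
        # checkpoints without it would report it as missing unless it stays
        # in _keys_to_ignore_on_load_missing.
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))

print("final_logits_bias" in ToyGenerationHead().state_dict())  # True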
......@@ -1478,7 +1472,6 @@ class BartForConditionalGeneration(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartForSequenceClassification(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
......@@ -1609,7 +1602,6 @@ class BartForSequenceClassification(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
......@@ -1748,7 +1740,6 @@ class BartDecoderWrapper(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartForCausalLM(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -459,7 +459,7 @@ class BeitRelativePositionBias(nn.Module):
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
def forward(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
......
......@@ -192,7 +192,9 @@ class BertEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -743,7 +745,6 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -1053,7 +1054,6 @@ class BertModel(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
......@@ -1160,8 +1160,6 @@ class BertForPreTraining(BertPreTrainedModel):
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
......@@ -1301,8 +1299,6 @@ class BertLMHeadModel(BertPreTrainedModel):
@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
......@@ -1715,8 +1711,6 @@ class BertForMultipleChoice(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1800,8 +1794,6 @@ class BertForTokenClassification(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -556,7 +556,9 @@ class BertGenerationEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
if input_ids is not None:
......@@ -588,7 +590,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
config_class = BertGenerationConfig
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -860,7 +861,6 @@ class BertGenerationOnlyLMHead(nn.Module):
BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......
......@@ -257,7 +257,9 @@ class BigBirdEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -1765,7 +1767,6 @@ class BigBirdPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_big_bird
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -2261,7 +2262,6 @@ class BigBirdModel(BigBirdPreTrainedModel):
class BigBirdForPreTraining(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -2368,7 +2368,6 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -2513,12 +2512,6 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
class BigBirdForCausalLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"predictions.decoder.bias",
"cls.predictions.decoder.weight",
"cls.predictions.decoder.bias",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......
......@@ -2358,7 +2358,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
BIGBIRD_PEGASUS_START_DOCSTRING,
)
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig):
......@@ -2491,12 +2490,7 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
"final_logits_bias",
"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
......@@ -2669,7 +2663,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
BIGBIRD_PEGASUS_START_DOCSTRING,
)
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
......@@ -2799,7 +2792,6 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
BIGBIRD_PEGASUS_START_DOCSTRING,
)
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
......@@ -2932,7 +2924,6 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -646,7 +646,6 @@ class BioGptModel(BioGptPreTrainedModel):
"""BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING
)
class BioGptForCausalLM(BioGptPreTrainedModel):
_keys_to_ignore_on_load_missing = ["output_projection.weight"]
_tied_weights_keys = ["output_projection.weight"]
def __init__(self, config):
......
......@@ -1102,7 +1102,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
BLENDERBOT_START_DOCSTRING,
)
class BlenderbotModel(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotConfig):
......@@ -1244,14 +1243,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
)
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: BlenderbotConfig):
......@@ -1441,7 +1433,6 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -1096,7 +1096,6 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
BLENDERBOT_SMALL_START_DOCSTRING,
)
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotSmallConfig):
......@@ -1226,14 +1225,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
)
class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: BlenderbotSmallConfig):
......@@ -1408,7 +1400,6 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -255,7 +255,9 @@ class BlipTextEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -419,7 +421,6 @@ class BlipPreTrainedModel(PreTrainedModel):
config_class = BlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -927,7 +928,6 @@ class BlipModel(BlipPreTrainedModel):
)
class BlipForConditionalGeneration(BlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
main_input_name = "pixel_values"
......@@ -1100,7 +1100,6 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
)
class BlipForQuestionAnswering(BlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
def __init__(self, config: BlipConfig):
......
......@@ -56,7 +56,9 @@ class BlipTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
......@@ -552,7 +554,6 @@ class BlipTextPreTrainedModel(PreTrainedModel):
config_class = BlipTextConfig
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -808,9 +809,6 @@ class BlipTextModel(BlipTextPreTrainedModel):
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
class BlipTextLMHeadModel(BlipTextPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -273,12 +273,6 @@ class Blip2PreTrainedModel(PreTrainedModel):
config_class = Blip2Config
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"language_model.encoder.embed_tokens.weight",
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_keep_in_fp32_modules = ["wo"]
......
......@@ -471,12 +471,6 @@ class BloomBlock(nn.Module):
class BloomPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BloomConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
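Judging by this hunk's line counts (12 lines down to 6), the misplaced string literal goes away together with the attribute that preceded it. Because _keys_to_ignore_on_load_missing sat between the class statement and that string, the string was never actually the class docstring, which is presumably why it could be dropped outright. A quick illustration of that Python rule, with toy classes unrelated to Bloom:

class AttributeFirst:
    some_attribute = []
    """This is just an expression statement, not the docstring."""

class DocstringFirst:
    """This is the real class docstring."""
    some_attribute = []

print(AttributeFirst.__doc__)   # None
print(DocstringFirst.__doc__)   # This is the real class docstring.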
......@@ -826,7 +820,6 @@ class BloomModel(BloomPreTrainedModel):
BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: BloomConfig):
......@@ -995,8 +988,6 @@ class BloomForCausalLM(BloomPreTrainedModel):
BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config: BloomConfig):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1123,8 +1114,6 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
BLOOM_START_DOCSTRING,
)
class BloomForTokenClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config: BloomConfig):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1226,8 +1215,6 @@ class BloomForTokenClassification(BloomPreTrainedModel):
BLOOM_START_DOCSTRING,
)
class BloomForQuestionAnswering(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = BloomModel(config)
......
......@@ -280,7 +280,7 @@ class BridgeTowerVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......@@ -880,7 +880,9 @@ class BridgeTowerTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -1038,8 +1040,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
config_class = BridgeTowerTextConfig
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
......
......@@ -94,7 +94,9 @@ class CamembertEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -627,15 +629,6 @@ class CamembertPreTrainedModel(PreTrainedModel):
if isinstance(module, CamembertEncoder):
module.gradient_checkpointing = value
def update_keys_to_ignore(self, config, del_keys_to_ignore):
"""Remove some keys from ignore list"""
if not config.tie_word_embeddings:
# must make a new list, or the class variable gets modified!
self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
self._keys_to_ignore_on_load_missing = [
k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
]
CAMEMBERT_INPUTS_DOCSTRING = r"""
Args:
......@@ -762,7 +755,6 @@ class CamembertModel(CamembertPreTrainedModel):
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
......@@ -935,9 +927,6 @@ class CamembertModel(CamembertPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT
class CamembertForMaskedLM(CamembertPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -952,9 +941,6 @@ class CamembertForMaskedLM(CamembertPreTrainedModel):
self.roberta = CamembertModel(config, add_pooling_layer=False)
self.lm_head = CamembertLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1042,8 +1028,6 @@ class CamembertForMaskedLM(CamembertPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
class CamembertForSequenceClassification(CamembertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1144,8 +1128,6 @@ class CamembertForSequenceClassification(CamembertPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT
class CamembertForMultipleChoice(CamembertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1241,9 +1223,6 @@ class CamembertForMultipleChoice(CamembertPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
class CamembertForTokenClassification(CamembertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1330,9 +1309,6 @@ class CamembertForTokenClassification(CamembertPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT
class CamembertForQuestionAnswering(CamembertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1431,9 +1407,6 @@ class CamembertForQuestionAnswering(CamembertPreTrainedModel):
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, roberta-base->camembert-base
class CamembertForCausalLM(CamembertPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
......@@ -1445,9 +1418,6 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
self.roberta = CamembertModel(config, add_pooling_layer=False)
self.lm_head = CamembertLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
......
......@@ -216,7 +216,9 @@ class CanineEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
......@@ -900,7 +902,6 @@ class CaninePreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_canine
base_model_prefix = "canine"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......