Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
@@ -320,8 +320,9 @@ def shard_checkpoint(
         weight_size = weight.numel() * dtype_byte_size(weight.dtype)

-        # If this weight is going to tip up over the maximal size, we split.
-        if last_block_size + weight_size > max_shard_size:
+        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
+        # weight in the current shard.
+        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
             sharded_state_dicts.append({})
             last_block_size = 0
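The added `len(sharded_state_dicts[-1]) > 0` guard means a weight that is larger than `max_shard_size` on its own now becomes its own shard instead of leaving an empty shard behind it. A minimal sketch of that packing rule, using a simplified `shard_by_size` helper (a hypothetical name) that measures raw byte sizes and skips the library's `weight_map` bookkeeping:

```python
import torch

def shard_by_size(state_dict, max_shard_size):
    """Greedily pack tensors into shards of at most max_shard_size bytes (simplified sketch)."""
    shards, current_size = [{}], 0
    for name, weight in state_dict.items():
        weight_size = weight.numel() * weight.element_size()
        # Only open a new shard if the current one already holds something;
        # an oversized weight therefore simply becomes its own shard.
        if current_size + weight_size > max_shard_size and len(shards[-1]) > 0:
            shards.append({})
            current_size = 0
        shards[-1][name] = weight
        current_size += weight_size
    return shards

state_dict = {"big": torch.zeros(1000), "small": torch.zeros(10)}
print([list(s) for s in shard_by_size(state_dict, max_shard_size=2000)])
# [['big'], ['small']] -- "big" (4000 bytes) exceeds the limit but still lands in a shard
```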
@@ -3044,15 +3045,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             expected_keys = [".".join([prefix, s]) for s in expected_keys]

         missing_keys = list(set(expected_keys) - set(loaded_keys))
-        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+        unexpected_keys = set(loaded_keys) - set(expected_keys)
+        # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
+        # buffers
+        model_buffers = {n for n, _ in model.named_buffers()}
+        if remove_prefix_from_model:
+            model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
+        elif add_prefix_to_model:
+            model_buffers = {".".join([prefix, key]) for key in model_buffers}
+        unexpected_keys = list(unexpected_keys - model_buffers)

-        if is_accelerate_available():
-            model.tie_weights()
-            tied_params = find_tied_parameters(model)
-        else:
-            tied_params = []
+        model.tie_weights()
+        ptrs = collections.defaultdict(list)
+        for name, tensor in model.state_dict().items():
+            id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
+            ptrs[id_tensor].append(name)
+
+        # These are all the pointers of shared tensors.
+        tied_params = [names for _, names in ptrs.items() if len(names) > 1]

         for group in tied_params:
+            if remove_prefix_from_model:
+                group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
+            elif add_prefix_to_model:
+                group = [".".join([prefix, key]) for key in group]
             missing_in_group = [k for k in missing_keys if k in group]
             if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
                 missing_keys = [k for k in missing_keys if k not in missing_in_group]
...
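Instead of relying on accelerate's `find_tied_parameters`, the loader now reads tied parameters straight off the model: every group of `state_dict` entries backed by the same storage forms one tied group (`id_tensor_storage` is transformers' helper for that, with `id(tensor)` as the fallback on the meta device). A rough standalone approximation of the same idea with plain PyTorch, assuming real non-meta tensors and PyTorch 2.x for `untyped_storage()`:

```python
import collections
import torch
from torch import nn

class TiedLM(nn.Module):
    """Toy model with a weight-tied output head (example name only)."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two weights

model = TiedLM()

# Group state_dict keys by the data pointer of their underlying storage:
# keys that point to the same storage are views of one shared tensor.
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[tensor.untyped_storage().data_ptr()].append(name)

tied_groups = [names for names in ptrs.values() if len(names) > 1]
print(tied_groups)  # [['embed.weight', 'lm_head.weight']]
```

Because the groups are computed from the instantiated model, tied heads such as `lm_head.weight` no longer have to be enumerated in `_keys_to_ignore_on_load_missing`, which is what most of the per-model deletions below amount to.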
@@ -208,7 +208,9 @@ class AlbertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
@@ -507,7 +509,6 @@ class AlbertPreTrainedModel(PreTrainedModel):
     config_class = AlbertConfig
     load_tf_weights = load_tf_weights_in_albert
     base_model_prefix = "albert"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -760,11 +761,6 @@ class AlbertModel(AlbertPreTrainedModel):
 )
 class AlbertForPreTraining(AlbertPreTrainedModel):
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
-    _keys_to_ignore_on_load_missing = [
-        "predictions.decoder.weight",
-        "predictions.decoder.bias",
-        "embeddings.position_ids",
-    ]

     def __init__(self, config: AlbertConfig):
         super().__init__(config)
@@ -912,13 +908,7 @@ class AlbertSOPHead(nn.Module):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
-    _keys_to_ignore_on_load_missing = [
-        "predictions.decoder.weight",
-        "predictions.decoder.bias",
-        "embeddings.position_ids",
-    ]

     def __init__(self, config):
         super().__init__(config)
@@ -1133,8 +1123,6 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForTokenClassification(AlbertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1218,8 +1206,6 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config: AlbertConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
...
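The recurring change in this and the following model files is the one shown for ALBERT above: index buffers such as `position_ids` are registered with `persistent=False`, so they are neither written to checkpoints nor expected from them, and the matching `_keys_to_ignore_on_load_missing` / `_keys_to_ignore_on_load_unexpected` entries can be dropped. A small stand-alone illustration of the difference (the `WithBuffers` module is just an example name):

```python
import torch
from torch import nn

class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        # Saved into the state_dict (default behaviour).
        self.register_buffer("persistent_ids", torch.arange(4).expand((1, -1)))
        # Kept on the module but excluded from the state_dict.
        self.register_buffer("transient_ids", torch.arange(4).expand((1, -1)), persistent=False)

m = WithBuffers()
print(list(m.state_dict()))               # ['persistent_ids']
print([n for n, _ in m.named_buffers()])  # ['persistent_ids', 'transient_ids']
```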
@@ -687,7 +687,9 @@ class AlignTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -1176,7 +1178,6 @@ class AlignPreTrainedModel(PreTrainedModel):
     config_class = AlignConfig
     base_model_prefix = "align"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -216,7 +216,9 @@ class AltRobertaEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -1016,7 +1018,7 @@ class AltCLIPVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
@@ -1038,7 +1040,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
     config_class = AltCLIPConfig
     base_model_prefix = "altclip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -506,7 +506,7 @@ class BartPretrainedModel(PreTrainedModel):
     config_class = BartConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
+    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
     _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
@@ -1170,7 +1170,6 @@ class BartDecoder(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartModel(BartPretrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config: BartConfig):
@@ -1300,12 +1299,7 @@ class BartModel(BartPretrainedModel):
 class BartForConditionalGeneration(BartPretrainedModel):
     base_model_prefix = "model"
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
-    _keys_to_ignore_on_load_missing = [
-        "final_logits_bias",
-        "lm_head.weight",
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

     def __init__(self, config: BartConfig):
         super().__init__(config)
@@ -1478,7 +1472,6 @@ class BartForConditionalGeneration(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartForSequenceClassification(BartPretrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config: BartConfig, **kwargs):
@@ -1609,7 +1602,6 @@ class BartForSequenceClassification(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartForQuestionAnswering(BartPretrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config):
@@ -1748,7 +1740,6 @@ class BartDecoderWrapper(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartForCausalLM(BartPretrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
...
@@ -459,7 +459,7 @@ class BeitRelativePositionBias(nn.Module):
         relative_position_index[0:, 0] = self.num_relative_distance - 2
         relative_position_index[0, 0] = self.num_relative_distance - 1

-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

     def forward(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
...
@@ -192,7 +192,9 @@ class BertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -743,7 +745,6 @@ class BertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1053,7 +1054,6 @@ class BertModel(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForPreTraining(BertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
@@ -1160,8 +1160,6 @@ class BertForPreTraining(BertPreTrainedModel):
     """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
 )
 class BertLMHeadModel(BertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
@@ -1301,8 +1299,6 @@ class BertLMHeadModel(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
@@ -1715,8 +1711,6 @@ class BertForMultipleChoice(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForTokenClassification(BertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1800,8 +1794,6 @@ class BertForTokenClassification(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForQuestionAnswering(BertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -556,7 +556,9 @@ class BertGenerationEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
         if input_ids is not None:
@@ -588,7 +590,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
     config_class = BertGenerationConfig
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -860,7 +861,6 @@ class BertGenerationOnlyLMHead(nn.Module):
     BERT_GENERATION_START_DOCSTRING,
 )
 class BertGenerationDecoder(BertGenerationPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
...
@@ -257,7 +257,9 @@ class BigBirdEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -1765,7 +1767,6 @@ class BigBirdPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_big_bird
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -2261,7 +2262,6 @@ class BigBirdModel(BigBirdPreTrainedModel):
 class BigBirdForPreTraining(BigBirdPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
@@ -2368,7 +2368,6 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
 @add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
 class BigBirdForMaskedLM(BigBirdPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
@@ -2513,12 +2512,6 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
     """BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
 )
 class BigBirdForCausalLM(BigBirdPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        r"position_ids",
-        r"predictions.decoder.bias",
-        "cls.predictions.decoder.weight",
-        "cls.predictions.decoder.bias",
-    ]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
...
@@ -2358,7 +2358,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
     BIGBIRD_PEGASUS_START_DOCSTRING,
 )
 class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config: BigBirdPegasusConfig):
@@ -2491,12 +2490,7 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
 class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     base_model_prefix = "model"
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
-    _keys_to_ignore_on_load_missing = [
-        "final_logits_bias",
-        "lm_head.weight",
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

     def __init__(self, config: BigBirdPegasusConfig):
         super().__init__(config)
@@ -2669,7 +2663,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     BIGBIRD_PEGASUS_START_DOCSTRING,
 )
 class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config: BigBirdPegasusConfig, **kwargs):
@@ -2799,7 +2792,6 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
     BIGBIRD_PEGASUS_START_DOCSTRING,
 )
 class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config):
@@ -2932,7 +2924,6 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
 class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
...
@@ -646,7 +646,6 @@ class BioGptModel(BioGptPreTrainedModel):
     """BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING
 )
 class BioGptForCausalLM(BioGptPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["output_projection.weight"]
     _tied_weights_keys = ["output_projection.weight"]

     def __init__(self, config):
...
@@ -1102,7 +1102,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
     BLENDERBOT_START_DOCSTRING,
 )
 class BlenderbotModel(BlenderbotPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

     def __init__(self, config: BlenderbotConfig):
@@ -1244,14 +1243,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
 )
 class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        "decoder.embed_tokens.weight",
-        "encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]

     def __init__(self, config: BlenderbotConfig):
@@ -1441,7 +1433,6 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
 class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
...
@@ -1096,7 +1096,6 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
     BLENDERBOT_SMALL_START_DOCSTRING,
 )
 class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

     def __init__(self, config: BlenderbotSmallConfig):
@@ -1226,14 +1225,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
 )
 class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]

     def __init__(self, config: BlenderbotSmallConfig):
@@ -1408,7 +1400,6 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
 class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
...
@@ -255,7 +255,9 @@ class BlipTextEmbeddings(nn.Module):
         self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(
         self,
@@ -419,7 +421,6 @@ class BlipPreTrainedModel(PreTrainedModel):
     config_class = BlipConfig
     base_model_prefix = "blip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -927,7 +928,6 @@ class BlipModel(BlipPreTrainedModel):
 )
 class BlipForConditionalGeneration(BlipPreTrainedModel):
     config_class = BlipConfig
-    _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
     _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
     main_input_name = "pixel_values"
@@ -1100,7 +1100,6 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
 )
 class BlipForQuestionAnswering(BlipPreTrainedModel):
     config_class = BlipConfig
-    _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
     _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]

     def __init__(self, config: BlipConfig):
...
@@ -56,7 +56,9 @@ class BlipTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

         self.config = config
@@ -552,7 +554,6 @@ class BlipTextPreTrainedModel(PreTrainedModel):
     config_class = BlipTextConfig
     base_model_prefix = "bert"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -808,9 +809,6 @@ class BlipTextModel(BlipTextPreTrainedModel):
 # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
 class BlipTextLMHeadModel(BlipTextPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
-
     def __init__(self, config):
         super().__init__(config)
...
@@ -273,12 +273,6 @@ class Blip2PreTrainedModel(PreTrainedModel):
     config_class = Blip2Config
     base_model_prefix = "blip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [
-        r"position_ids",
-        r"language_model.encoder.embed_tokens.weight",
-        r"language_model.decoder.embed_tokens.weight",
-        r"language_model.lm_head.weight",
-    ]
     _no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _keep_in_fp32_modules = ["wo"]
...
@@ -471,12 +471,6 @@ class BloomBlock(nn.Module):
 class BloomPreTrainedModel(PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
     config_class = BloomConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
@@ -826,7 +820,6 @@ class BloomModel(BloomPreTrainedModel):
     BLOOM_START_DOCSTRING,
 )
 class BloomForCausalLM(BloomPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: BloomConfig):
@@ -995,8 +988,6 @@ class BloomForCausalLM(BloomPreTrainedModel):
     BLOOM_START_DOCSTRING,
 )
 class BloomForSequenceClassification(BloomPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
     def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1123,8 +1114,6 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
     BLOOM_START_DOCSTRING,
 )
 class BloomForTokenClassification(BloomPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
     def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1226,8 +1215,6 @@ class BloomForTokenClassification(BloomPreTrainedModel):
     BLOOM_START_DOCSTRING,
 )
 class BloomForQuestionAnswering(BloomPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
...
@@ -280,7 +280,7 @@ class BridgeTowerVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
@@ -880,7 +880,9 @@ class BridgeTowerTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -1038,8 +1040,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
     config_class = BridgeTowerTextConfig

-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
...
@@ -94,7 +94,9 @@ class CamembertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -627,15 +629,6 @@ class CamembertPreTrainedModel(PreTrainedModel):
         if isinstance(module, CamembertEncoder):
             module.gradient_checkpointing = value

-    def update_keys_to_ignore(self, config, del_keys_to_ignore):
-        """Remove some keys from ignore list"""
-        if not config.tie_word_embeddings:
-            # must make a new list, or the class variable gets modified!
-            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-            self._keys_to_ignore_on_load_missing = [
-                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-            ]
-

 CAMEMBERT_INPUTS_DOCSTRING = r"""
     Args:
@@ -762,7 +755,6 @@ class CamembertModel(CamembertPreTrainedModel):
     """

-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     _no_split_modules = []

     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
@@ -935,9 +927,6 @@ class CamembertModel(CamembertPreTrainedModel):
 )
 # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT
 class CamembertForMaskedLM(CamembertPreTrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
@@ -952,9 +941,6 @@ class CamembertForMaskedLM(CamembertPreTrainedModel):
         self.roberta = CamembertModel(config, add_pooling_layer=False)
         self.lm_head = CamembertLMHead(config)

-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-
         # Initialize weights and apply final processing
         self.post_init()
@@ -1042,8 +1028,6 @@ class CamembertForMaskedLM(CamembertPreTrainedModel):
 )
 # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
 class CamembertForSequenceClassification(CamembertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1144,8 +1128,6 @@ class CamembertForSequenceClassification(CamembertPreTrainedModel):
 )
 # Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice with Roberta->Camembert, ROBERTA->CAMEMBERT
 class CamembertForMultipleChoice(CamembertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
@@ -1241,9 +1223,6 @@ class CamembertForMultipleChoice(CamembertPreTrainedModel):
 )
 # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification with Roberta->Camembert, ROBERTA->CAMEMBERT
 class CamembertForTokenClassification(CamembertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1330,9 +1309,6 @@ class CamembertForTokenClassification(CamembertPreTrainedModel):
 )
 # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering with Roberta->Camembert, ROBERTA->CAMEMBERT
 class CamembertForQuestionAnswering(CamembertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1431,9 +1407,6 @@ class CamembertForQuestionAnswering(CamembertPreTrainedModel):
 )
 # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, roberta-base->camembert-base
 class CamembertForCausalLM(CamembertPreTrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
@@ -1445,9 +1418,6 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
         self.roberta = CamembertModel(config, add_pooling_layer=False)
         self.lm_head = CamembertLMHead(config)

-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-
         # Initialize weights and apply final processing
         self.post_init()
...
@@ -216,7 +216,9 @@ class CanineEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

     def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
@@ -900,7 +902,6 @@ class CaninePreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_canine
     base_model_prefix = "canine"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights"""
...
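Taken together, the loader-side filtering and the `persistent=False` buffers mean the per-class ignore lists above are no longer needed for loading to stay quiet. A hedged sketch of how the effect can be inspected from user code, assuming the usual `output_loading_info=True` flag of `from_pretrained`:

```python
from transformers import BertForMaskedLM

# output_loading_info=True also returns a dict describing what happened during
# loading: keys the checkpoint did not provide and keys the model did not use.
model, loading_info = BertForMaskedLM.from_pretrained(
    "bert-base-uncased", output_loading_info=True
)
print(loading_info["missing_keys"])
print(loading_info["unexpected_keys"])
```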