Unverified commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
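
The common thread in the diff below is that buffers such as `position_ids` are now registered with `persistent=False`, so they never enter the `state_dict`, which is what makes the per-class `_keys_to_ignore_on_load_missing` / `_keys_to_ignore_on_load_unexpected` entries for them redundant. A minimal standalone sketch of the before/after pattern (class and argument names here are illustrative, not taken from the PR):

import torch
from torch import nn


class EmbeddingsBefore(nn.Module):
    """Old pattern: the buffer lands in state_dict, so loading code had to list it
    in an ignore list to silence spurious missing/unexpected key warnings."""

    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))


class EmbeddingsAfter(nn.Module):
    """New pattern: persistent=False keeps the buffer out of state_dict entirely."""

    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )


print("position_ids" in EmbeddingsBefore().state_dict())  # True
print("position_ids" in EmbeddingsAfter().state_dict())   # False
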
@@ -602,7 +602,6 @@ Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/
 class JukeboxVQVAE(PreTrainedModel):
     config_class = JukeboxVQVAEConfig
     base_model_prefix = "vqvae"
-    _keys_to_ignore_on_load_unexpected = [r"priors"]
     def _init_weights(self, module):
         if isinstance(module, nn.Embedding):  # embed_tokens
@@ -1792,7 +1791,6 @@ class JukeboxPrior(PreTrainedModel):
     """
     config_class = JukeboxPriorConfig
-    _keys_to_ignore_on_load_unexpected = ["vqvae"]
     def _init_weights(self, module):
         init_scale = self.config.init_scale
@@ -1832,7 +1830,6 @@ class JukeboxPrior(PreTrainedModel):
         self.level = level if level is not None else config.level
         self.base_model_prefix = f"priors.{self.level}"
-        self._keys_to_ignore_on_load_unexpected += [r"priors.[^%d]." % self.level]
         self.n_ctx = config.n_ctx
...
@@ -68,7 +68,9 @@ class LayoutLMEmbeddings(nn.Module):
         self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
     def forward(
         self,
@@ -619,7 +621,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "layoutlm"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -857,11 +858,6 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
 @add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
 class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        "cls.predictions.decoder.bias",
-        "cls.predictions.decoder.weight",
-        "embeddings.position_ids",
-    ]
     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
     def __init__(self, config):
...
@@ -77,7 +77,9 @@ class LayoutLMv2Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
     def _calc_spatial_position_embeddings(self, bbox):
         try:
@@ -506,7 +508,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
     config_class = LayoutLMv2Config
     pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "layoutlmv2"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -567,8 +568,11 @@ class LayoutLMv2VisualBackbone(nn.Module):
         self.register_buffer(
             "pixel_mean",
             torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1),
+            persistent=False,
+        )
+        self.register_buffer(
+            "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False
         )
-        self.register_buffer("pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1))
         self.out_feature_key = "p2"
         if torch.are_deterministic_algorithms_enabled():
             logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`")
...
@@ -245,7 +245,9 @@ class LayoutLMv3TextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.padding_idx = config.pad_token_id
         self.position_embeddings = nn.Embedding(
@@ -750,8 +752,6 @@ class LayoutLMv3Output(nn.Module):
     LAYOUTLMV3_START_DOCSTRING,
 )
 class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def __init__(self, config):
         super().__init__(config)
         self.config = config
@@ -1038,9 +1038,6 @@ class LayoutLMv3ClassificationHead(nn.Module):
     LAYOUTLMV3_START_DOCSTRING,
 )
 class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1153,9 +1150,6 @@ class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
     LAYOUTLMV3_START_DOCSTRING,
 )
 class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1286,8 +1280,6 @@ class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
     LAYOUTLMV3_START_DOCSTRING,
 )
 class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -2209,7 +2209,6 @@ class LEDDecoder(LEDPreTrainedModel):
     LED_START_DOCSTRING,
 )
 class LEDModel(LEDPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     def __init__(self, config: LEDConfig):
@@ -2335,14 +2334,7 @@ class LEDModel(LEDPreTrainedModel):
 )
 class LEDForConditionalGeneration(LEDPreTrainedModel):
     base_model_prefix = "led"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        "decoder.embed_tokens.weight",
-        "encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
     def __init__(self, config: LEDConfig):
@@ -2530,7 +2522,6 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
     LED_START_DOCSTRING,
 )
 class LEDForSequenceClassification(LEDPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     def __init__(self, config: LEDConfig, **kwargs):
@@ -2667,7 +2658,6 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
     LED_START_DOCSTRING,
 )
 class LEDForQuestionAnswering(LEDPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     def __init__(self, config):
...
@@ -195,7 +195,9 @@ class LevitAttention(nn.Module):
         self.attention_bias_cache = {}
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
-        self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points))
+        self.register_buffer(
+            "attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
+        )
     @torch.no_grad()
     def train(self, mode=True):
@@ -271,7 +273,9 @@ class LevitAttentionSubsample(nn.Module):
                 indices.append(attention_offsets[offset])
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
-        self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points))
+        self.register_buffer(
+            "attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
+        )
     @torch.no_grad()
     def train(self, mode=True):
...
@@ -59,7 +59,9 @@ class LiltTextEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         # End copy
@@ -610,15 +612,6 @@ class LiltPreTrainedModel(PreTrainedModel):
         if isinstance(module, LiltEncoder):
             module.gradient_checkpointing = value
-    def update_keys_to_ignore(self, config, del_keys_to_ignore):
-        """Remove some keys from ignore list"""
-        if not config.tie_word_embeddings:
-            # must make a new list, or the class variable gets modified!
-            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-            self._keys_to_ignore_on_load_missing = [
-                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-            ]
 LILT_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -697,8 +690,6 @@ LILT_INPUTS_DOCSTRING = r"""
     LILT_START_DOCSTRING,
 )
 class LiltModel(LiltPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -847,8 +838,6 @@ class LiltModel(LiltPreTrainedModel):
     LILT_START_DOCSTRING,
 )
 class LiltForSequenceClassification(LiltPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Lilt, roberta->lilt
     def __init__(self, config):
         super().__init__(config)
@@ -967,9 +956,6 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
     LILT_START_DOCSTRING,
 )
 class LiltForTokenClassification(LiltPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Lilt, roberta->lilt
     def __init__(self, config):
         super().__init__(config)
@@ -1096,9 +1082,6 @@ class LiltClassificationHead(nn.Module):
     LILT_START_DOCSTRING,
 )
 class LiltForQuestionAnswering(LiltPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Lilt, roberta->lilt
     def __init__(self, config):
         super().__init__(config)
...
@@ -344,7 +344,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamaDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
-    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -784,8 +783,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     LLAMA_START_DOCSTRING,
 )
 class LlamaForSequenceClassification(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -1421,7 +1421,6 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_unexpected = [r"position_ids"]
     _no_split_modules = ["LongformerSelfAttention"]
     def _init_weights(self, module):
@@ -1770,8 +1769,6 @@ class LongformerModel(LongformerPreTrainedModel):
 @add_start_docstrings("""Longformer Model with a `language modeling` head on top.""", LONGFORMER_START_DOCSTRING)
 class LongformerForMaskedLM(LongformerPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.decoder"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.decoder"]
     def __init__(self, config):
@@ -1886,8 +1883,6 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
     LONGFORMER_START_DOCSTRING,
 )
 class LongformerForSequenceClassification(LongformerPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -2015,8 +2010,6 @@ class LongformerClassificationHead(nn.Module):
     LONGFORMER_START_DOCSTRING,
 )
 class LongformerForQuestionAnswering(LongformerPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -2154,8 +2147,6 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
     LONGFORMER_START_DOCSTRING,
 )
 class LongformerForTokenClassification(LongformerPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -1763,10 +1763,6 @@ num_heads)`.
     LONGT5_START_DOCSTRING,
 )
 class LongT5Model(LongT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-    ]
     _keys_to_ignore_on_load_unexpected = [
         r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
@@ -1917,11 +1913,6 @@ class LongT5Model(LongT5PreTrainedModel):
 @add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
 class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-        r"lm_head.weight",
-    ]
     _keys_to_ignore_on_load_unexpected = [
         r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
@@ -2160,7 +2151,6 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
     LONGT5_START_DOCSTRING,
 )
 class LongT5EncoderModel(LongT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight"]
     def __init__(self, config: LongT5Config):
...
@@ -1022,8 +1022,6 @@ LUKE_INPUTS_DOCSTRING = r"""
     LUKE_START_DOCSTRING,
 )
 class LukeModel(LukePreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def __init__(self, config: LukeConfig, add_pooling_layer: bool = True):
         super().__init__(config)
         self.config = config
@@ -1278,17 +1276,6 @@ class LukeLMHead(nn.Module):
     LUKE_START_DOCSTRING,
 )
 class LukeForMaskedLM(LukePreTrainedModel):
-    _keys_to_ignore_on_save = [
-        r"lm_head.decoder.weight",
-        r"lm_head.decoder.bias",
-        r"entity_predictions.decoder.weight",
-    ]
-    _keys_to_ignore_on_load_missing = [
-        r"position_ids",
-        r"lm_head.decoder.weight",
-        r"lm_head.decoder.bias",
-        r"entity_predictions.decoder.weight",
-    ]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"]
     def __init__(self, config):
...
@@ -1018,7 +1018,6 @@ class LxmertModel(LxmertPreTrainedModel):
     LXMERT_START_DOCSTRING,
 )
 class LxmertForPreTraining(LxmertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight"]
     _tied_weights_keys = ["cls.predictions.decoder.weight"]
     def __init__(self, config):
...
@@ -131,7 +131,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
         # in forward put the weights on the correct dtype and device of the param
         emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
-        self.register_buffer("weights", emb_weights)
+        self.register_buffer("weights", emb_weights, persistent=False)
     @staticmethod
     def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
@@ -1137,14 +1137,6 @@ class M2M100Decoder(M2M100PreTrainedModel):
     M2M_100_START_DOCSTRING,
 )
 class M2M100Model(M2M100PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-        "encoder.embed_positions.weights",
-        "encoder.embed_positions.bias",
-        "decoder.embed_positions.weights",
-        "decoder.embed_positions.bias",
-    ]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     def __init__(self, config: M2M100Config):
@@ -1258,17 +1250,6 @@ class M2M100Model(M2M100PreTrainedModel):
 )
 class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-        r"encoder.embed_positions.weights",
-        r"encoder.embed_positions.bias",
-        r"decoder.embed_positions.weights",
-        r"decoder.embed_positions.bias",
-    ]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
     def __init__(self, config: M2M100Config):
...
@@ -1103,7 +1103,6 @@ class MarianDecoder(MarianPreTrainedModel):
     "The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
 )
 class MarianModel(MarianPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     def __init__(self, config: MarianConfig):
@@ -1292,13 +1291,9 @@ class MarianModel(MarianPreTrainedModel):
 class MarianMTModel(MarianPreTrainedModel):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        r"embed_positions",
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
+        "final_logits_bias",
+        "encoder.embed_positions.weight",
+        "decoder.embed_positions.weight",
     ]
     _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
     _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
@@ -1561,7 +1556,6 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
 class MarianForCausalLM(MarianPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
     def __init__(self, config):
...
@@ -143,7 +143,9 @@ class MarkupLMEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.padding_idx = config.pad_token_id
         self.position_embeddings = nn.Embedding(
@@ -713,7 +715,6 @@ class MarkupLMPreTrainedModel(PreTrainedModel):
     config_class = MarkupLMConfig
     pretrained_model_archive_map = MARKUPLM_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "markuplm"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->MarkupLM
     def _init_weights(self, module):
@@ -971,8 +972,6 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
     MARKUPLM_START_DOCSTRING,
 )
 class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with bert->markuplm, Bert->MarkupLM
     def __init__(self, config):
         super().__init__(config)
...
@@ -1156,7 +1156,6 @@ class MBartDecoder(MBartPreTrainedModel):
     MBART_START_DOCSTRING,
 )
 class MBartModel(MBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     def __init__(self, config: MBartConfig):
@@ -1277,14 +1276,7 @@ class MBartModel(MBartPreTrainedModel):
 )
 class MBartForConditionalGeneration(MBartPreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
     def __init__(self, config: MBartConfig):
@@ -1452,7 +1444,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
     MBART_START_DOCSTRING,
 )
 class MBartForSequenceClassification(MBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
     def __init__(self, config: MBartConfig, **kwargs):
@@ -1582,7 +1573,6 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
     MBART_START_DOCSTRING,
 )
 class MBartForQuestionAnswering(MBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
     def __init__(self, config):
@@ -1716,7 +1706,6 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
 class MBartForCausalLM(MBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
     def __init__(self, config):
...
@@ -149,7 +149,9 @@ class MCTCTEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids",
             torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
@@ -443,7 +445,6 @@ class MCTCTPreTrainedModel(PreTrainedModel):
     config_class = MCTCTConfig
     base_model_prefix = "mctct"
     main_input_name = "input_features"
-    _keys_to_ignore_on_load_missing = ["position_ids"]
     supports_gradient_checkpointing = True
     def _init_weights(self, module):
...
@@ -1387,15 +1387,6 @@ class MegaPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def update_keys_to_ignore(self, config, del_keys_to_ignore):
-        """Remove some keys from ignore list"""
-        if not config.tie_word_embeddings:
-            # must make a new list, or the class variable gets modified!
-            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-            self._keys_to_ignore_on_load_missing = [
-                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-            ]
 MEGA_START_DOCSTRING = r"""
@@ -1474,8 +1465,6 @@ class MegaModel(MegaPreTrainedModel):
     """
-    _keys_to_ignore_on_load_missing = []
     def __init__(self, config: MegaConfig, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -1656,9 +1645,6 @@ class MegaModel(MegaPreTrainedModel):
     """MEGA Model with a `language modeling` head on top for CLM fine-tuning.""", MEGA_START_DOCSTRING
 )
 class MegaForCausalLM(MegaPreTrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.weight", r"lm_head.bias"]
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight", r"lm_head.bias"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.weight"]
     def __init__(self, config: MegaConfig):
@@ -1678,9 +1664,6 @@ class MegaForCausalLM(MegaPreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["lm_head.weight"])
         # Initialize weights and apply final processing
         self.post_init()
@@ -1821,9 +1804,6 @@ class MegaForCausalLM(MegaPreTrainedModel):
 @add_start_docstrings("""MEGA Model with a `language modeling` head on top.""", MEGA_START_DOCSTRING)
 class MegaForMaskedLM(MegaPreTrainedModel):
-    _keys_to_ignore_on_save = [r"mlm_head.weight", r"mlm_head.bias"]
-    _keys_to_ignore_on_load_missing = [r"mlm_head.weight", r"mlm_head.bias"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["mlm_head.weight"]
     def __init__(self, config: MegaConfig):
@@ -1845,9 +1825,6 @@ class MegaForMaskedLM(MegaPreTrainedModel):
         self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
         self.dropout = nn.Dropout(config.dropout_prob)
-        # The LM head weights require special treatment only when they are tied with the word embeddings
-        self.update_keys_to_ignore(config, ["mlm_head.weight"])
         # Initialize weights and apply final processing
         self.post_init()
@@ -1931,8 +1908,6 @@ class MegaForMaskedLM(MegaPreTrainedModel):
     MEGA_START_DOCSTRING,
 )
 class MegaForSequenceClassification(MegaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = []
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -2024,8 +1999,6 @@ class MegaForSequenceClassification(MegaPreTrainedModel):
     MEGA_START_DOCSTRING,
 )
 class MegaForMultipleChoice(MegaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = []
     def __init__(self, config):
         super().__init__(config)
@@ -2111,9 +2084,6 @@ class MegaForMultipleChoice(MegaPreTrainedModel):
     MEGA_START_DOCSTRING,
 )
 class MegaForTokenClassification(MegaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = []
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -2214,9 +2184,6 @@ class MegaClassificationHead(nn.Module):
     MEGA_START_DOCSTRING,
 )
 class MegaForQuestionAnswering(MegaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = []
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -149,7 +149,9 @@ class MegatronBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
     def forward(
@@ -713,7 +715,6 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_megatron_bert
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1014,7 +1015,6 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
     MEGATRON_BERT_START_DOCSTRING,
 )
 class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
     _tied_weights_keys = ["cls.predictions.decoder"]
     def __init__(self, config, add_binary_head=True):
@@ -1121,8 +1121,6 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
     MEGATRON_BERT_START_DOCSTRING,
 )
 class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder"]
     _tied_weights_keys = ["cls.predictions.decoder"]
     def __init__(self, config):
@@ -1267,8 +1265,6 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
 @add_start_docstrings("""MegatronBert Model with a `language modeling` head on top.""", MEGATRON_BERT_START_DOCSTRING)
 class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler", r"seq_relationship"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder"]
     _tied_weights_keys = ["cls.predictions.decoder"]
     def __init__(self, config):
@@ -1376,8 +1372,6 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
     MEGATRON_BERT_START_DOCSTRING,
 )
 class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"predictions"]
     def __init__(self, config):
         super().__init__(config)
@@ -1672,8 +1666,6 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
     MEGATRON_BERT_START_DOCSTRING,
 )
 class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1752,8 +1744,6 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
     MEGATRON_BERT_START_DOCSTRING,
 )
 class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
@@ -191,7 +191,9 @@ class MobileBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
     def forward(
         self,
@@ -686,7 +688,6 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
     load_tf_weights = load_tf_weights_in_mobilebert
     base_model_prefix = "mobilebert"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -923,11 +924,6 @@ class MobileBertModel(MobileBertPreTrainedModel):
     MOBILEBERT_START_DOCSTRING,
 )
 class MobileBertForPreTraining(MobileBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        "cls.predictions.decoder.weight",
-        "cls.predictions.decoder.bias",
-        "embeddings.position_ids",
-    ]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
     def __init__(self, config):
@@ -1036,12 +1032,6 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING)
 class MobileBertForMaskedLM(MobileBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [
-        "cls.predictions.decoder.weight",
-        "cls.predictions.decoder.bias",
-        "embeddings.position_ids",
-    ]
     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
     def __init__(self, config):
@@ -1350,8 +1340,6 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
 )
 # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
 class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1553,8 +1541,6 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
 )
 # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
 class MobileBertForTokenClassification(MobileBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
...
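
To check locally that a touched model now loads without spurious key reports, one can use the existing `output_loading_info=True` flag of `from_pretrained`, which returns the missing/unexpected key lists alongside the model. A hedged sketch of such a check (the checkpoint name is just an example, not taken from the PR's tests):

from transformers import LayoutLMForMaskedLM

# output_loading_info=True also returns the report collected while the checkpoint's
# state_dict was matched against the model's parameters and persistent buffers.
model, info = LayoutLMForMaskedLM.from_pretrained(
    "microsoft/layoutlm-base-uncased", output_loading_info=True
)

# With position_ids now a non-persistent buffer, it should appear in neither list.
assert not any("position_ids" in key for key in info["missing_keys"] + info["unexpected_keys"])
print(info["missing_keys"], info["unexpected_keys"])
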