Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
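The diff below repeats one pattern across the L* and M* models: buffers such as `position_ids` are now registered with `persistent=False`, so they never enter the `state_dict`, and the `_keys_to_ignore_on_load_missing` / `_keys_to_ignore_on_load_unexpected` entries that existed only to silence warnings about those buffers (or about weights already covered by `_tied_weights_keys`) are deleted. Below is a minimal sketch of the buffer half of that pattern; the `ToyEmbeddings` module is hypothetical and not part of the repository.

```python
# Illustrative only: why a non-persistent buffer no longer needs an entry in
# _keys_to_ignore_on_load_missing / _keys_to_ignore_on_load_unexpected.
import torch
from torch import nn


class ToyEmbeddings(nn.Module):  # hypothetical module, not from transformers
    def __init__(self, max_position_embeddings: int = 512, hidden_size: int = 16):
        super().__init__()
        self.word_embeddings = nn.Embedding(1000, hidden_size)
        # persistent=False keeps the buffer out of state_dict(), so checkpoints
        # saved with or without it load without missing/unexpected-key warnings.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)),
            persistent=False,
        )


module = ToyEmbeddings()
print("position_ids" in module.state_dict())  # False: the buffer is never serialized
```

Tied or derived weights (`lm_head.weight`, `final_logits_bias`, and similar) follow the same logic in the hunks below: they are either covered by `_tied_weights_keys` or kept in a much shorter ignore list.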
......@@ -602,7 +602,6 @@ Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/
class JukeboxVQVAE(PreTrainedModel):
config_class = JukeboxVQVAEConfig
base_model_prefix = "vqvae"
- _keys_to_ignore_on_load_unexpected = [r"priors"]
def _init_weights(self, module):
if isinstance(module, nn.Embedding): # embed_tokens
......@@ -1792,7 +1791,6 @@ class JukeboxPrior(PreTrainedModel):
"""
config_class = JukeboxPriorConfig
- _keys_to_ignore_on_load_unexpected = ["vqvae"]
def _init_weights(self, module):
init_scale = self.config.init_scale
......@@ -1832,7 +1830,6 @@ class JukeboxPrior(PreTrainedModel):
self.level = level if level is not None else config.level
self.base_model_prefix = f"priors.{self.level}"
- self._keys_to_ignore_on_load_unexpected += [r"priors.[^%d]." % self.level]
self.n_ctx = config.n_ctx
......
......@@ -68,7 +68,9 @@ class LayoutLMEmbeddings(nn.Module):
self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -619,7 +621,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlm"
supports_gradient_checkpointing = True
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -857,11 +858,6 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
- _keys_to_ignore_on_load_missing = [
-     "cls.predictions.decoder.bias",
-     "cls.predictions.decoder.weight",
-     "embeddings.position_ids",
- ]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
......
......@@ -77,7 +77,9 @@ class LayoutLMv2Embeddings(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def _calc_spatial_position_embeddings(self, bbox):
try:
......@@ -506,7 +508,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
config_class = LayoutLMv2Config
pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlmv2"
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -567,8 +568,11 @@ class LayoutLMv2VisualBackbone(nn.Module):
self.register_buffer(
"pixel_mean",
torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1),
+ persistent=False,
+ )
+ self.register_buffer(
+     "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False
)
- self.register_buffer("pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1))
self.out_feature_key = "p2"
if torch.are_deterministic_algorithms_enabled():
logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`")
......
......@@ -245,7 +245,9 @@ class LayoutLMv3TextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
......@@ -750,8 +752,6 @@ class LayoutLMv3Output(nn.Module):
LAYOUTLMV3_START_DOCSTRING,
)
class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.config = config
......@@ -1038,9 +1038,6 @@ class LayoutLMv3ClassificationHead(nn.Module):
LAYOUTLMV3_START_DOCSTRING,
)
class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1153,9 +1150,6 @@ class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
LAYOUTLMV3_START_DOCSTRING,
)
class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1286,8 +1280,6 @@ class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
LAYOUTLMV3_START_DOCSTRING,
)
class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -2209,7 +2209,6 @@ class LEDDecoder(LEDPreTrainedModel):
LED_START_DOCSTRING,
)
class LEDModel(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: LEDConfig):
......@@ -2335,14 +2334,7 @@ class LEDModel(LEDPreTrainedModel):
)
class LEDForConditionalGeneration(LEDPreTrainedModel):
base_model_prefix = "led"
- _keys_to_ignore_on_load_missing = [
-     r"final_logits_bias",
-     r"encoder.version",
-     r"decoder.version",
-     r"lm_head.weight",
-     "decoder.embed_tokens.weight",
-     "encoder.embed_tokens.weight",
- ]
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: LEDConfig):
......@@ -2530,7 +2522,6 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
LED_START_DOCSTRING,
)
class LEDForSequenceClassification(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: LEDConfig, **kwargs):
......@@ -2667,7 +2658,6 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
LED_START_DOCSTRING,
)
class LEDForQuestionAnswering(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config):
......
......@@ -195,7 +195,9 @@ class LevitAttention(nn.Module):
self.attention_bias_cache = {}
self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points))
self.register_buffer(
"attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
)
@torch.no_grad()
def train(self, mode=True):
......@@ -271,7 +273,9 @@ class LevitAttentionSubsample(nn.Module):
indices.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
self.register_buffer("attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points))
self.register_buffer(
"attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
)
@torch.no_grad()
def train(self, mode=True):
......
......@@ -59,7 +59,9 @@ class LiltTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
# End copy
......@@ -610,15 +612,6 @@ class LiltPreTrainedModel(PreTrainedModel):
if isinstance(module, LiltEncoder):
module.gradient_checkpointing = value
- def update_keys_to_ignore(self, config, del_keys_to_ignore):
-     """Remove some keys from ignore list"""
-     if not config.tie_word_embeddings:
-         # must make a new list, or the class variable gets modified!
-         self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-         self._keys_to_ignore_on_load_missing = [
-             k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-         ]
LILT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
......@@ -697,8 +690,6 @@ LILT_INPUTS_DOCSTRING = r"""
LILT_START_DOCSTRING,
)
class LiltModel(LiltPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
......@@ -847,8 +838,6 @@ class LiltModel(LiltPreTrainedModel):
LILT_START_DOCSTRING,
)
class LiltForSequenceClassification(LiltPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Lilt, roberta->lilt
def __init__(self, config):
super().__init__(config)
......@@ -967,9 +956,6 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
LILT_START_DOCSTRING,
)
class LiltForTokenClassification(LiltPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Lilt, roberta->lilt
def __init__(self, config):
super().__init__(config)
......@@ -1096,9 +1082,6 @@ class LiltClassificationHead(nn.Module):
LILT_START_DOCSTRING,
)
class LiltForQuestionAnswering(LiltPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Lilt, roberta->lilt
def __init__(self, config):
super().__init__(config)
......
......@@ -344,7 +344,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def _init_weights(self, module):
std = self.config.initializer_range
......@@ -784,8 +783,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -1421,7 +1421,6 @@ class LongformerPreTrainedModel(PreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
supports_gradient_checkpointing = True
- _keys_to_ignore_on_load_unexpected = [r"position_ids"]
_no_split_modules = ["LongformerSelfAttention"]
def _init_weights(self, module):
......@@ -1770,8 +1769,6 @@ class LongformerModel(LongformerPreTrainedModel):
@add_start_docstrings("""Longformer Model with a `language modeling` head on top.""", LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(LongformerPreTrainedModel):
- _keys_to_ignore_on_load_missing = ["lm_head.decoder"]
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder"]
def __init__(self, config):
......@@ -1886,8 +1883,6 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
LONGFORMER_START_DOCSTRING,
)
class LongformerForSequenceClassification(LongformerPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -2015,8 +2010,6 @@ class LongformerClassificationHead(nn.Module):
LONGFORMER_START_DOCSTRING,
)
class LongformerForQuestionAnswering(LongformerPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -2154,8 +2147,6 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
LONGFORMER_START_DOCSTRING,
)
class LongformerForTokenClassification(LongformerPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -1763,10 +1763,6 @@ num_heads)`.
LONGT5_START_DOCSTRING,
)
class LongT5Model(LongT5PreTrainedModel):
- _keys_to_ignore_on_load_missing = [
-     r"encoder.embed_tokens.weight",
-     r"decoder.embed_tokens.weight",
- ]
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
......@@ -1917,11 +1913,6 @@ class LongT5Model(LongT5PreTrainedModel):
@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
- _keys_to_ignore_on_load_missing = [
-     r"encoder.embed_tokens.weight",
-     r"decoder.embed_tokens.weight",
-     r"lm_head.weight",
- ]
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
......@@ -2160,7 +2151,6 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
LONGT5_START_DOCSTRING,
)
class LongT5EncoderModel(LongT5PreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight"]
def __init__(self, config: LongT5Config):
......
......@@ -1022,8 +1022,6 @@ LUKE_INPUTS_DOCSTRING = r"""
LUKE_START_DOCSTRING,
)
class LukeModel(LukePreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config: LukeConfig, add_pooling_layer: bool = True):
super().__init__(config)
self.config = config
......@@ -1278,17 +1276,6 @@ class LukeLMHead(nn.Module):
LUKE_START_DOCSTRING,
)
class LukeForMaskedLM(LukePreTrainedModel):
- _keys_to_ignore_on_save = [
-     r"lm_head.decoder.weight",
-     r"lm_head.decoder.bias",
-     r"entity_predictions.decoder.weight",
- ]
- _keys_to_ignore_on_load_missing = [
-     r"position_ids",
-     r"lm_head.decoder.weight",
-     r"lm_head.decoder.bias",
-     r"entity_predictions.decoder.weight",
- ]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"]
def __init__(self, config):
......
......@@ -1018,7 +1018,6 @@ class LxmertModel(LxmertPreTrainedModel):
LXMERT_START_DOCSTRING,
)
class LxmertForPreTraining(LxmertPreTrainedModel):
- _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.weight"]
def __init__(self, config):
......
......@@ -131,7 +131,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.register_buffer("weights", emb_weights)
self.register_buffer("weights", emb_weights, persistent=False)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
......@@ -1137,14 +1137,6 @@ class M2M100Decoder(M2M100PreTrainedModel):
M2M_100_START_DOCSTRING,
)
class M2M100Model(M2M100PreTrainedModel):
- _keys_to_ignore_on_load_missing = [
-     "encoder.embed_tokens.weight",
-     "decoder.embed_tokens.weight",
-     "encoder.embed_positions.weights",
-     "encoder.embed_positions.bias",
-     "decoder.embed_positions.weights",
-     "decoder.embed_positions.bias",
- ]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: M2M100Config):
......@@ -1258,17 +1250,6 @@ class M2M100Model(M2M100PreTrainedModel):
)
class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
base_model_prefix = "model"
- _keys_to_ignore_on_load_missing = [
-     r"encoder.version",
-     r"decoder.version",
-     r"lm_head.weight",
-     r"encoder.embed_tokens.weight",
-     r"decoder.embed_tokens.weight",
-     r"encoder.embed_positions.weights",
-     r"encoder.embed_positions.bias",
-     r"decoder.embed_positions.weights",
-     r"decoder.embed_positions.bias",
- ]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: M2M100Config):
......
......@@ -1103,7 +1103,6 @@ class MarianDecoder(MarianPreTrainedModel):
"The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
)
class MarianModel(MarianPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: MarianConfig):
......@@ -1292,13 +1291,9 @@ class MarianModel(MarianPreTrainedModel):
class MarianMTModel(MarianPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
r"embed_positions",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
"final_logits_bias",
"encoder.embed_positions.weight",
"decoder.embed_positions.weight",
]
_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
......@@ -1561,7 +1556,6 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
class MarianForCausalLM(MarianPreTrainedModel):
- _keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -143,7 +143,9 @@ class MarkupLMEmbeddings(nn.Module):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
......@@ -713,7 +715,6 @@ class MarkupLMPreTrainedModel(PreTrainedModel):
config_class = MarkupLMConfig
pretrained_model_archive_map = MARKUPLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "markuplm"
- _keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->MarkupLM
def _init_weights(self, module):
......@@ -971,8 +972,6 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
MARKUPLM_START_DOCSTRING,
)
class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with bert->markuplm, Bert->MarkupLM
def __init__(self, config):
super().__init__(config)
......
......@@ -1156,7 +1156,6 @@ class MBartDecoder(MBartPreTrainedModel):
MBART_START_DOCSTRING,
)
class MBartModel(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: MBartConfig):
......@@ -1277,14 +1276,7 @@ class MBartModel(MBartPreTrainedModel):
)
class MBartForConditionalGeneration(MBartPreTrainedModel):
base_model_prefix = "model"
- _keys_to_ignore_on_load_missing = [
-     r"final_logits_bias",
-     r"encoder.version",
-     r"decoder.version",
-     r"lm_head.weight",
-     "encoder.embed_tokens.weight",
-     "decoder.embed_tokens.weight",
- ]
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: MBartConfig):
......@@ -1452,7 +1444,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
MBART_START_DOCSTRING,
)
class MBartForSequenceClassification(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
def __init__(self, config: MBartConfig, **kwargs):
......@@ -1582,7 +1573,6 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
MBART_START_DOCSTRING,
)
class MBartForQuestionAnswering(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
def __init__(self, config):
......@@ -1716,7 +1706,6 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
class MBartForCausalLM(MBartPreTrainedModel):
- _keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......
......@@ -149,7 +149,9 @@ class MCTCTEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
......@@ -443,7 +445,6 @@ class MCTCTPreTrainedModel(PreTrainedModel):
config_class = MCTCTConfig
base_model_prefix = "mctct"
main_input_name = "input_features"
- _keys_to_ignore_on_load_missing = ["position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
......
......@@ -1387,15 +1387,6 @@ class MegaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
- def update_keys_to_ignore(self, config, del_keys_to_ignore):
-     """Remove some keys from ignore list"""
-     if not config.tie_word_embeddings:
-         # must make a new list, or the class variable gets modified!
-         self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
-         self._keys_to_ignore_on_load_missing = [
-             k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
-         ]
MEGA_START_DOCSTRING = r"""
......@@ -1474,8 +1465,6 @@ class MegaModel(MegaPreTrainedModel):
"""
- _keys_to_ignore_on_load_missing = []
def __init__(self, config: MegaConfig, add_pooling_layer=True):
super().__init__(config)
self.config = config
......@@ -1656,9 +1645,6 @@ class MegaModel(MegaPreTrainedModel):
"""MEGA Model with a `language modeling` head on top for CLM fine-tuning.""", MEGA_START_DOCSTRING
)
class MegaForCausalLM(MegaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.weight", r"lm_head.bias"]
_keys_to_ignore_on_load_missing = [r"lm_head.weight", r"lm_head.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: MegaConfig):
......@@ -1678,9 +1664,6 @@ class MegaForCausalLM(MegaPreTrainedModel):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
- # The LM head weights require special treatment only when they are tied with the word embeddings
- self.update_keys_to_ignore(config, ["lm_head.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1821,9 +1804,6 @@ class MegaForCausalLM(MegaPreTrainedModel):
@add_start_docstrings("""MEGA Model with a `language modeling` head on top.""", MEGA_START_DOCSTRING)
class MegaForMaskedLM(MegaPreTrainedModel):
_keys_to_ignore_on_save = [r"mlm_head.weight", r"mlm_head.bias"]
_keys_to_ignore_on_load_missing = [r"mlm_head.weight", r"mlm_head.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["mlm_head.weight"]
def __init__(self, config: MegaConfig):
......@@ -1845,9 +1825,6 @@ class MegaForMaskedLM(MegaPreTrainedModel):
self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.dropout = nn.Dropout(config.dropout_prob)
- # The LM head weights require special treatment only when they are tied with the word embeddings
- self.update_keys_to_ignore(config, ["mlm_head.weight"])
# Initialize weights and apply final processing
self.post_init()
......@@ -1931,8 +1908,6 @@ class MegaForMaskedLM(MegaPreTrainedModel):
MEGA_START_DOCSTRING,
)
class MegaForSequenceClassification(MegaPreTrainedModel):
- _keys_to_ignore_on_load_missing = []
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -2024,8 +1999,6 @@ class MegaForSequenceClassification(MegaPreTrainedModel):
MEGA_START_DOCSTRING,
)
class MegaForMultipleChoice(MegaPreTrainedModel):
- _keys_to_ignore_on_load_missing = []
def __init__(self, config):
super().__init__(config)
......@@ -2111,9 +2084,6 @@ class MegaForMultipleChoice(MegaPreTrainedModel):
MEGA_START_DOCSTRING,
)
class MegaForTokenClassification(MegaPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = []
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -2214,9 +2184,6 @@ class MegaClassificationHead(nn.Module):
MEGA_START_DOCSTRING,
)
class MegaForQuestionAnswering(MegaPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = []
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -149,7 +149,9 @@ class MegatronBertEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward(
......@@ -713,7 +715,6 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_megatron_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -1014,7 +1015,6 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
- _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config, add_binary_head=True):
......@@ -1121,8 +1121,6 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
......@@ -1267,8 +1265,6 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
@add_start_docstrings("""MegatronBert Model with a `language modeling` head on top.""", MEGATRON_BERT_START_DOCSTRING)
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler", r"seq_relationship"]
- _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
......@@ -1376,8 +1372,6 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForNextSentencePrediction(MegatronBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"predictions"]
def __init__(self, config):
super().__init__(config)
......@@ -1672,8 +1666,6 @@ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1752,8 +1744,6 @@ class MegatronBertForTokenClassification(MegatronBertPreTrainedModel):
MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForQuestionAnswering(MegatronBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -191,7 +191,9 @@ class MobileBertEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -686,7 +688,6 @@ class MobileBertPreTrainedModel(PreTrainedModel):
pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
load_tf_weights = load_tf_weights_in_mobilebert
base_model_prefix = "mobilebert"
- _keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -923,11 +924,6 @@ class MobileBertModel(MobileBertPreTrainedModel):
MOBILEBERT_START_DOCSTRING,
)
class MobileBertForPreTraining(MobileBertPreTrainedModel):
- _keys_to_ignore_on_load_missing = [
-     "cls.predictions.decoder.weight",
-     "cls.predictions.decoder.bias",
-     "embeddings.position_ids",
- ]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -1036,12 +1032,6 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING)
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
- _keys_to_ignore_on_load_missing = [
-     "cls.predictions.decoder.weight",
-     "cls.predictions.decoder.bias",
-     "embeddings.position_ids",
- ]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
......@@ -1350,8 +1340,6 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
)
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1553,8 +1541,6 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
)
# Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......