Unverified Commit 8e5d1619 authored by Sylvain Gugger, committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict (see the sketch below the commit metadata)

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
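
The whole cleanup rests on one PyTorch behavior: a buffer registered with `persistent=False` never enters `state_dict()`, so it can neither be saved into a checkpoint nor be reported as a missing key when one is loaded. That is what turns the blanket `_keys_to_ignore_on_load_missing = [r"position_ids"]` patterns deleted below into dead code. A minimal sketch in plain PyTorch (illustration only, not transformers code):

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self, persistent: bool):
        super().__init__()
        # same pattern as the position_ids buffers in the diffs below
        self.register_buffer("position_ids", torch.arange(8), persistent=persistent)

print(list(WithBuffer(persistent=True).state_dict()))   # ['position_ids']
print(list(WithBuffer(persistent=False).state_dict()))  # []
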
@@ -83,7 +83,9 @@ class MPNetEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
 
     def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, **kwargs):
         if position_ids is None:
@@ -479,8 +481,6 @@ MPNET_INPUTS_DOCSTRING = r"""
     MPNET_START_DOCSTRING,
 )
 class MPNetModel(MPNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -570,8 +570,6 @@ class MPNetModel(MPNetPreTrainedModel):
 class MPNetForMaskedLM(MPNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder"]
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _tied_weights_keys = ["lm_head.decoder"]
 
     def __init__(self, config):
@@ -679,8 +677,6 @@ class MPNetLMHead(nn.Module):
     MPNET_START_DOCSTRING,
 )
 class MPNetForSequenceClassification(MPNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
@@ -773,8 +769,6 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
     MPNET_START_DOCSTRING,
 )
 class MPNetForMultipleChoice(MPNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
@@ -863,9 +857,6 @@ class MPNetForMultipleChoice(MPNetPreTrainedModel):
     MPNET_START_DOCSTRING,
 )
 class MPNetForTokenClassification(MPNetPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -962,9 +953,6 @@ class MPNetClassificationHead(nn.Module):
     MPNET_START_DOCSTRING,
 )
 class MPNetForQuestionAnswering(MPNetPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
-
     def __init__(self, config):
         super().__init__(config)
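
A quick way to verify that removing the MPNet patterns above does not regress loading is the `output_loading_info=True` flag of `from_pretrained`, which returns the missing/unexpected key lists. A sanity-check sketch (not part of the commit; the checkpoint name is the public MPNet base model):

from transformers import MPNetModel

model, loading_info = MPNetModel.from_pretrained("microsoft/mpnet-base", output_loading_info=True)
# position_ids is now a non-persistent buffer, so it can no longer show up here
print(loading_info["missing_keys"])
print(loading_info["unexpected_keys"])
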
@@ -1316,18 +1316,8 @@ class MT5Model(MT5PreTrainedModel):
     ```"""
 
     model_type = "mt5"
     config_class = MT5Config
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
-    ]
-    _keys_to_ignore_on_save = [
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-    ]
-    _keys_to_ignore_on_load_unexpected = [
-        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
+    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
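
The `_tied_weights_keys` line that survives each MT5 hunk is what makes the deleted `encoder.embed_tokens.weight`/`decoder.embed_tokens.weight` patterns redundant: tied weights are a single shared `Parameter`, so the loader can re-tie them explicitly instead of ignoring them by regex. A minimal sketch of the tying mechanism in plain PyTorch (illustrative names):

import torch
from torch import nn

embed_tokens = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed_tokens.weight  # tie: both modules now share one Parameter

assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
# Only one copy needs to live in the checkpoint; from_pretrained restores
# the other side by re-tying instead of reporting a missing key.
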
@@ -1552,15 +1542,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
     model_type = "mt5"
     config_class = MT5Config
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-    ]
-    _keys_to_ignore_on_save = [
-        r"encoder.embed_tokens.weight",
-    ]
-    _keys_to_ignore_on_load_unexpected = [
-        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
-    ]
+    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
@@ -1897,13 +1879,6 @@ class MT5EncoderModel(MT5PreTrainedModel):
     model_type = "mt5"
     config_class = MT5Config
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-    ]
-    _keys_to_ignore_on_save = [
-        r"encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight"]
 
     # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5
@@ -2029,14 +2004,7 @@ class MT5EncoderModel(MT5PreTrainedModel):
     MT5_START_DOCSTRING,
 )
 class MT5ForQuestionAnswering(MT5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-        r"lm_head.weight",
-    ]
-    _keys_to_ignore_on_load_unexpected = [
-        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
-    ]
+    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5
@@ -551,7 +551,6 @@ class MvpPreTrainedModel(PreTrainedModel):
     config_class = MvpConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1300,8 +1299,7 @@ class MvpDecoder(MvpPreTrainedModel):
     MVP_START_DOCSTRING,
 )
 class MvpModel(MvpPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _keys_to_ignore_on_load_unexpected = ["final_logits_bias"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: MvpConfig):
@@ -1438,7 +1436,6 @@ class MvpModel(MvpPreTrainedModel):
     "The MVP Model with a language modeling head. Can be used for various text generation tasks.", MVP_START_DOCSTRING
 )
 class MvpForConditionalGeneration(MvpPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: MvpConfig):
@@ -1611,8 +1608,6 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
     MVP_START_DOCSTRING,
 )
 class MvpForSequenceClassification(MvpPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: MvpConfig, **kwargs):
@@ -1740,8 +1735,6 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
     MVP_START_DOCSTRING,
 )
 class MvpForQuestionAnswering(MvpPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config):
@@ -1873,7 +1866,6 @@ class MvpDecoderWrapper(MvpPreTrainedModel):
 class MvpForCausalLM(MvpPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -163,7 +163,7 @@ class NezhaRelativePositionsEncoding(nn.Module):
         my_shape = list(final_mat.size())
         my_shape.append(depth)
         positions_encoding = positions_encoding.view(my_shape)
-        self.register_buffer("positions_encoding", positions_encoding)
+        self.register_buffer("positions_encoding", positions_encoding, persistent=False)
 
     def forward(self, length):
         return self.positions_encoding[:length, :length, :]
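
Checkpoints saved before this change can still contain `positions_encoding`, since the buffer used to be persistent; the commit bullet "Always ignore nonpersistent buffers if in state_dict" covers exactly that case. A plain-PyTorch sketch of the situation the loader now tolerates (illustrative module, not the Nezha code itself):

import torch
from torch import nn

class Encoding(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent, as in the Nezha diff above
        self.register_buffer("positions_encoding", torch.zeros(4, 4), persistent=False)

old_checkpoint = {"positions_encoding": torch.zeros(4, 4)}  # written by the old code
result = Encoding().load_state_dict(old_checkpoint, strict=False)
print(result.unexpected_keys)  # ['positions_encoding'] -- dropped, not an error
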
@@ -735,7 +735,6 @@ class NezhaPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_nezha
     base_model_prefix = "nezha"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"positions_encoding"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1037,7 +1036,6 @@ class NezhaModel(NezhaPreTrainedModel):
     NEZHA_START_DOCSTRING,
 )
 class NezhaForPreTraining(NezhaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
     _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
@@ -1140,8 +1138,6 @@ class NezhaForPreTraining(NezhaPreTrainedModel):
 @add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING)
 class NezhaForMaskedLM(NezhaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"cls.predictions.decoder", r"positions_encoding"]
     _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
@@ -1542,8 +1538,6 @@ class NezhaForMultipleChoice(NezhaPreTrainedModel):
     NEZHA_START_DOCSTRING,
 )
 class NezhaForTokenClassification(NezhaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1623,8 +1617,6 @@ class NezhaForTokenClassification(NezhaPreTrainedModel):
     NEZHA_START_DOCSTRING,
 )
 class NezhaForQuestionAnswering(NezhaPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -183,7 +183,7 @@ class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
             # in forward put the weights on the correct dtype and device of the param
             emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
 
-        self.register_buffer("weights", emb_weights)
+        self.register_buffer("weights", emb_weights, persistent=False)
 
     @staticmethod
     def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
@@ -1500,14 +1500,6 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
     NLLB_MOE_START_DOCSTRING,
 )
 class NllbMoeModel(NllbMoePreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-        "encoder.embed_positions.weights",
-        "encoder.embed_positions.bias",
-        "decoder.embed_positions.weights",
-        "decoder.embed_positions.bias",
-    ]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: NllbMoeConfig):
@@ -1641,17 +1633,6 @@
 )
 class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-        r"encoder.embed_positions.weights",
-        r"encoder.embed_positions.bias",
-        r"decoder.embed_positions.weights",
-        r"decoder.embed_positions.bias",
-    ]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: NllbMoeConfig):
@@ -64,7 +64,9 @@ class NystromformerEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer(
             "token_type_ids",
@@ -458,7 +460,6 @@ class NystromformerPreTrainedModel(PreTrainedModel):
     config_class = NystromformerConfig
     base_model_prefix = "nystromformer"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -658,7 +659,6 @@ class NystromformerModel(NystromformerPreTrainedModel):
 @add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
 class NystromformerForMaskedLM(NystromformerPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
     _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
@@ -368,7 +368,6 @@ class OpenLlamaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["OpenLlamaDecoderLayer"]
-    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -825,8 +824,6 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
 )
 # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->OPEN_LLAMA,Llama->OpenLlama
 class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -141,7 +141,9 @@ class Attention(nn.Module):
         if n_state % config.n_head != 0:
             raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
         self.register_buffer(
-            "bias", torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
+            "bias",
+            torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions),
+            persistent=False,
         )
         self.n_head = config.n_head
         self.split_size = n_state
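
The `bias` buffer in this hunk is the standard causal mask: it depends only on `n_positions`, so it is cheap to rebuild at init and there is no reason to keep it in checkpoints. A small sketch of what the buffer contains:

import torch

n_positions = 5
bias = torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
print(bias[0, 0])  # row i is 1 through column i: position i attends only to positions <= i
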
@@ -274,7 +276,6 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     config_class = OpenAIGPTConfig
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -407,7 +408,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
 
-        self.register_buffer("position_ids", torch.arange(config.n_positions))
+        self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
@@ -529,7 +530,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     OPENAI_GPT_START_DOCSTRING,
 )
 class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -621,7 +621,6 @@ input sequence).
     OPENAI_GPT_START_DOCSTRING,
 )
 class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -399,7 +399,6 @@ class OPTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["OPTDecoderLayer"]
-    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -817,7 +816,6 @@ class OPTModel(OPTPreTrainedModel):
 class OPTForCausalLM(OPTPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1025,8 +1023,6 @@ class OPTForCausalLM(OPTPreTrainedModel):
     OPT_START_DOCSTRING,
 )
 class OPTForSequenceClassification(OPTPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config: OPTConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1147,8 +1143,6 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
     OPT_START_DOCSTRING,
 )
 class OPTForQuestionAnswering(OPTPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config: OPTConfig):
         super().__init__(config)
         self.model = OPTModel(config)
@@ -304,7 +304,7 @@ class OwlViTVisionEmbeddings(nn.Module):
         self.num_patches = (config.image_size // config.patch_size) ** 2
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
@@ -325,7 +325,9 @@ class OwlViTTextEmbeddings(nn.Module):
         self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
 
     def forward(
         self,
@@ -530,7 +532,6 @@ class OwlViTPreTrainedModel(PreTrainedModel):
     config_class = OwlViTConfig
     base_model_prefix = "owlvit"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
     _no_split_modules = ["OwlViTEncoderLayer"]
 
     def _init_weights(self, module):
@@ -1156,7 +1156,6 @@ class PegasusDecoder(PegasusPreTrainedModel):
     PEGASUS_START_DOCSTRING,
 )
 class PegasusModel(PegasusPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PegasusConfig):
@@ -1309,15 +1308,7 @@
 )
 class PegasusForConditionalGeneration(PegasusPreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        r"embed_positions.weight",
-        "encoder.embed_tokens.weight",
-        "decoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PegasusConfig):
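
Note that `final_logits_bias` stays in `_keys_to_ignore_on_load_missing`: unlike `position_ids` it is a persistent buffer, kept ignorable only because older checkpoints may not contain it. A rough sketch of the Bart/Pegasus/PLBart pattern (simplified module with hypothetical dimensions):

import torch
from torch import nn

class Head(nn.Module):
    def __init__(self, hidden=4, vocab_size=10):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # persistent zeros buffer; absent from old checkpoints, hence the
        # surviving _keys_to_ignore_on_load_missing entry above
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))

    def forward(self, hidden_states):
        return self.lm_head(hidden_states) + self.final_logits_bias
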
@@ -1518,7 +1509,6 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
 class PegasusForCausalLM(PegasusPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1391,7 +1391,6 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
     PEGASUS_X_START_DOCSTRING,
 )
 class PegasusXModel(PegasusXPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PegasusXConfig):
@@ -1536,14 +1535,6 @@ class PegasusXModel(PegasusXPreTrainedModel):
 @add_start_docstrings("The PEGASUS-X for conditional generation (e.g. summarization).", PEGASUS_X_START_DOCSTRING)
 class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        r"embed_positions.weight",
-        "decoder.embed_tokens.weight",
-        "encoder.embed_tokens.weight",
-    ]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PegasusXConfig):
@@ -1597,14 +1597,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
 class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
     config_class = Pix2StructConfig
     main_input_name = "flattened_patches"
-    _keys_to_ignore_on_load_missing = [
-        r"encoder.embed_tokens.weight",
-        r"decoder.embed_tokens.weight",
-    ]
-    _keys_to_ignore_on_load_unexpected = [
-        r"decoder.layer.0.layer.1.EncDecAttention.relative_attention_bias.weight",
-    ]
     _tied_weights_keys = ["decoder.lm_head.weight"]
 
     def __init__(self, config: Pix2StructConfig):
@@ -1132,7 +1132,6 @@ class PLBartDecoder(PLBartPreTrainedModel):
     PLBART_START_DOCSTRING,
 )
 class PLBartModel(PLBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PLBartConfig):
@@ -1251,14 +1250,7 @@
 )
 class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [
-        r"final_logits_bias",
-        r"encoder.version",
-        r"decoder.version",
-        r"lm_head.weight",
-        "decoder.embed_tokens.weight",
-        "encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PLBartConfig):
@@ -1423,7 +1415,6 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     PLBART_START_DOCSTRING,
 )
 class PLBartForSequenceClassification(PLBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
     _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PLBartConfig, **kwargs):
@@ -1562,7 +1553,6 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
 class PLBartForCausalLM(PLBartPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1744,7 +1744,6 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetModel(ProphetNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
     _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
 
     def __init__(self, config: ProphetNetConfig):
@@ -1874,11 +1873,6 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
-        "decoder.word_embeddings.weight",
-        "encoder.word_embeddings.weight",
-        "lm_head.weight",
-    ]
     _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
 
     def __init__(self, config: ProphetNetConfig):
@@ -2091,7 +2085,6 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: ProphetNetConfig):
@@ -164,7 +164,9 @@ class QDQBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -738,7 +740,6 @@ class QDQBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_qdqbert
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1012,8 +1013,6 @@ class QDQBertModel(QDQBertPreTrainedModel):
     """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING
 )
 class QDQBertLMHeadModel(QDQBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
 
     def __init__(self, config):
@@ -1166,8 +1165,6 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
 @add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING)
 class QDQBertForMaskedLM(QDQBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
 
     def __init__(self, config):
@@ -1570,8 +1567,6 @@ class QDQBertForMultipleChoice(QDQBertPreTrainedModel):
     QDQBERT_START_DOCSTRING,
 )
 class QDQBertForTokenClassification(QDQBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -1650,8 +1645,6 @@ class QDQBertForTokenClassification(QDQBertPreTrainedModel):
     QDQBERT_START_DOCSTRING,
 )
 class QDQBertForQuestionAnswering(QDQBertPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -231,7 +231,6 @@ class RagPreTrainedModel(PreTrainedModel):
     """
 
     config_class = RagConfig
     base_model_prefix = "rag"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
@@ -178,7 +178,9 @@ class RealmEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.register_buffer(
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
@@ -968,7 +970,6 @@ class RealmPreTrainedModel(PreTrainedModel):
     config_class = RealmConfig
     load_tf_weights = load_tf_weights_in_realm
     base_model_prefix = "realm"
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1147,7 +1148,6 @@ class RealmBertModel(RealmPreTrainedModel):
     REALM_START_DOCSTRING,
 )
 class RealmEmbedder(RealmPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
     _tied_weights_keys = ["cls.predictions.decoder.bias"]
 
     def __init__(self, config):
@@ -1378,7 +1378,6 @@ class RealmScorer(RealmPreTrainedModel):
     REALM_START_DOCSTRING,
 )
 class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
     _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
@@ -1529,8 +1528,6 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
 @add_start_docstrings("The reader of REALM.", REALM_START_DOCSTRING)
 class RealmReader(RealmPreTrainedModel):
-    _keys_to_ignore_on_load_unexpected = [r"pooler", "cls"]
-
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
@@ -352,10 +352,10 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         self.value = nn.Linear(self.hidden_size, self.all_head_size, bias=False)
 
         # save mask value here. Need fp32 and fp16 mask values
-        self.register_buffer("self_mask_value_float16", torch.tensor(-1e3))
-        self.register_buffer("self_mask_value_float32", torch.tensor(-1e5))
-        self.register_buffer("mask_value_float16", torch.tensor(-1e4))
-        self.register_buffer("mask_value_float32", torch.tensor(-1e9))
+        self.register_buffer("self_mask_value_float16", torch.tensor(-1e3), persistent=False)
+        self.register_buffer("self_mask_value_float32", torch.tensor(-1e5), persistent=False)
+        self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False)
+        self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False)
 
     def forward(
         self,
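
The dtype-specific mask values above exist because fp16 cannot represent -1e9 (its largest finite magnitude is about 65504), while -1e4 is still safely finite. A two-line check:

import torch

print(torch.tensor(-1e9, dtype=torch.float16))  # overflows to -inf
print(torch.tensor(-1e4, dtype=torch.float16))  # tensor(-10000., dtype=torch.float16)
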
@@ -1049,8 +1049,8 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
         self.dropout = config.local_attention_probs_dropout_prob
 
         # save mask value here
-        self.register_buffer("mask_value_float16", torch.tensor(-1e4))
-        self.register_buffer("mask_value_float32", torch.tensor(-1e9))
+        self.register_buffer("mask_value_float16", torch.tensor(-1e4), persistent=False)
+        self.register_buffer("mask_value_float32", torch.tensor(-1e9), persistent=False)
 
     def forward(
         self,
@@ -2185,7 +2185,6 @@ class ReformerModel(ReformerPreTrainedModel):
 @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
 class ReformerModelWithLMHead(ReformerPreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["lm_head.decoder.bias"]
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
@@ -158,7 +158,9 @@ class RemBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
 
     def forward(
         self,
@@ -654,7 +656,6 @@ class RemBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_rembert
     base_model_prefix = "rembert"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1016,7 +1017,6 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
     """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING
 )
 class RemBertForCausalLM(RemBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
     _tied_weights_keys = ["cls.predictions.decoder.weight"]
 
     def __init__(self, config):