Unverified Commit 66f89332 authored by Stas Bekman, committed by GitHub
Browse files

normalize keys_to_ignore (#17722)

parent c3c62b5d
......@@ -1252,8 +1252,8 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"encoder\.version",
r"decoder\.version",
r"encoder.version",
r"decoder.version",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
......
......@@ -1266,11 +1266,11 @@ num_heads)`.
)
class T5Model(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"encoder.embed_tokens.weight",
r"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_unexpected = [
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config):
......@@ -1455,12 +1455,12 @@ class T5Model(T5PreTrainedModel):
@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"encoder.embed_tokens.weight",
r"decoder.embed_tokens.weight",
r"lm_head.weight",
]
_keys_to_ignore_on_load_unexpected = [
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
def __init__(self, config: T5Config):
......@@ -1749,7 +1749,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
)
class T5EncoderModel(T5PreTrainedModel):
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"encoder.embed_tokens.weight",
]
def __init__(self, config: T5Config):
......
......@@ -1198,7 +1198,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment