Unverified commit 8e5d1619, authored by Sylvain Gugger and committed by GitHub

Clean load keys (#24505)

* Preliminary work on some models

* Fix test load missing and make sure nonpersistent buffers are tested

* Always ignore nonpersistent buffers if in state_dict

* Treat models

* More models

* Treat remaining models

* Fix quality

* Fix tests

* Remove draft

* This test is not needed anymore

* Fix copies

* Fix last test

* Newly added models

* Fix last tests

* Address review comments
parent 53194991
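
The recurring change across the files below is to register deterministic, non-trainable tensors (position ids, rotary inverse frequencies, causal masks) with `persistent=False`, so they never enter the `state_dict` and the per-model `_keys_to_ignore_on_load_missing` entries that used to cover them can be dropped. A minimal sketch of the difference, in plain PyTorch rather than code from this PR (module and buffer names are illustrative):

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        # Persistent (the default): saved in state_dict and expected back when loading.
        self.register_buffer("saved_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # Non-persistent: rebuilt in __init__, skipped on save and load.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )


module = ToyEmbeddings()
print("saved_ids" in module.state_dict())     # True
print("position_ids" in module.state_dict())  # False
```
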
......@@ -412,7 +412,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):
config_class = ErnieMConfig
base_model_prefix = "ernie_m"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -96,7 +96,7 @@ class RotaryEmbedding(torch.nn.Module):
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = None
self._cos_cached = None
......@@ -178,7 +178,9 @@ class EsmEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
......@@ -783,7 +785,6 @@ class EsmModel(EsmPreTrainedModel):
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = False
def __init__(self, config, add_pooling_layer=True):
......@@ -960,8 +961,6 @@ class EsmModel(EsmPreTrainedModel):
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", "lm_head.decoder.weight"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight"]
def __init__(self, config):
......@@ -1081,8 +1080,6 @@ class EsmLMHead(nn.Module):
ESM_START_DOCSTRING,
)
class EsmForSequenceClassification(EsmPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1177,9 +1174,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
ESM_START_DOCSTRING,
)
class EsmForTokenClassification(EsmPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
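
For the head models above (e.g. `EsmForMaskedLM`), keys such as `lm_head.decoder.weight` no longer need a `_keys_to_ignore_on_load_missing` entry because they are tied weights, already tracked through `_tied_weights_keys`. A hedged sketch of why a tied weight can be absent from a checkpoint without causing a problem (illustrative names, not the library's internals):

```python
import torch
from torch import nn

# The decoder projection shares its weight with the input embedding, so there is
# only one underlying parameter; a checkpoint that stores just the embedding
# weight still fully determines the decoder weight once the tie is applied.
embeddings = nn.Embedding(100, 16)
decoder = nn.Linear(16, 100, bias=False)
decoder.weight = embeddings.weight  # tie: two modules, one parameter

assert decoder.weight.data_ptr() == embeddings.weight.data_ptr()
```
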
......@@ -378,8 +378,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
class FlaubertModel(FlaubertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config): # , dico, is_encoder, with_output):
super().__init__(config)
......@@ -448,7 +446,6 @@ class FlaubertModel(FlaubertPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.layerdrop = getattr(config, "layerdrop", 0.0)
self.pre_norm = getattr(config, "pre_norm", False)
......@@ -654,7 +651,6 @@ class FlaubertModel(FlaubertPreTrainedModel):
)
# Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertWithLMHeadModel(FlaubertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
_tied_weights_keys = ["pred_layer.proj.weight"]
def __init__(self, config):
......
......@@ -387,7 +387,9 @@ class FlavaTextEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
......@@ -1724,12 +1726,6 @@ class FlavaGlobalContrastiveHead(nn.Module):
)
class FlavaForPreTraining(FlavaPreTrainedModel):
# Those are linked to xxx.bias
_keys_to_ignore_on_load_missing = [
"mmm_text_head.decoder.bias",
"mmm_image_head.decoder.bias",
"mlm_head.decoder.bias",
"mim_head.decoder.bias",
]
_tied_weights_keys = [
"mmm_text_head.decoder.bias",
"mmm_image_head.decoder.bias",
......
......@@ -114,7 +114,9 @@ class FNetEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
......@@ -411,7 +413,6 @@ class FNetPreTrainedModel(PreTrainedModel):
config_class = FNetConfig
base_model_prefix = "fnet"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -621,7 +622,6 @@ class FNetModel(FNetPreTrainedModel):
FNET_START_DOCSTRING,
)
class FNetForPreTraining(FNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
......@@ -716,7 +716,6 @@ class FNetForPreTraining(FNetPreTrainedModel):
@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
class FNetForMaskedLM(FNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
......
......@@ -1034,7 +1034,6 @@ def _get_shape(t):
FSMT_START_DOCSTRING,
)
class FSMTModel(PretrainedFSMTModel):
_keys_to_ignore_on_load_missing = ["decoder.output_projection.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight"]
def __init__(self, config: FSMTConfig):
......@@ -1172,15 +1171,6 @@ class FSMTModel(PretrainedFSMTModel):
)
class FSMTForConditionalGeneration(PretrainedFSMTModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
"model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight",
"decoder.output_projection.weight",
]
_keys_to_ignore_on_save = [
"model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight",
]
_tied_weights_keys = ["model.decoder.embed_tokens.weight"]
def __init__(self, config: FSMTConfig):
......
......@@ -1190,7 +1190,6 @@ class FunnelForPreTraining(FunnelPreTrainedModel):
@add_start_docstrings("""Funnel Transformer Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
class FunnelForMaskedLM(FunnelPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: FunnelConfig) -> None:
......
......@@ -109,7 +109,9 @@ class GitEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -510,7 +512,6 @@ class GitPreTrainedModel(PreTrainedModel):
config_class = GitConfig
base_model_prefix = "git"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -623,7 +624,7 @@ class GitVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
......
......@@ -668,9 +668,6 @@ DEPARALLELIZE_DOCSTRING = r"""
GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -957,8 +954,6 @@ class GPT2Model(GPT2PreTrainedModel):
GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......@@ -1151,8 +1146,6 @@ input sequence).
GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......@@ -1381,9 +1374,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1605,9 +1595,6 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
GPT2_START_DOCSTRING,
)
class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
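
The GPT-2 classes above drop both `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_load_unexpected` for `attn.bias` / `attn.masked_bias`. Once those tensors are non-persistent buffers they are never reported as missing; if an older checkpoint still contains them they only surface as unexpected keys, which the loading code can ignore (per the "Always ignore nonpersistent buffers if in state_dict" commit above). A small illustration with plain `load_state_dict(strict=False)`, an approximation of the idea rather than the actual `from_pretrained` logic (class and key names are made up):

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_attn = nn.Linear(8, 24)
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)


# An "old" checkpoint that still serialized the buffer.
old_state = {
    "c_attn.weight": torch.zeros(24, 8),
    "c_attn.bias": torch.zeros(24),
    "masked_bias": torch.tensor(-1e4),
}
result = ToyAttention().load_state_dict(old_state, strict=False)
print(result.missing_keys)     # [] -- the buffer is no longer expected
print(result.unexpected_keys)  # ['masked_bias'] -- stale entry from the old checkpoint
```
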
......@@ -500,8 +500,6 @@ GPT_BIGCODE_INPUTS_DOCSTRING = r"""
GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.multi_query = config.multi_query
......@@ -722,7 +720,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......@@ -876,8 +873,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -145,8 +145,8 @@ class GPTNeoSelfAttention(nn.Module):
if attention_type == "local":
bias = torch.bitwise_xor(bias, torch.tril(bias, -config.window_size))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.register_buffer("bias", bias, persistent=False)
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self.attn_dropout = nn.Dropout(float(config.attention_dropout))
self.resid_dropout = nn.Dropout(float(config.resid_dropout))
......@@ -663,12 +663,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
GPT_NEO_START_DOCSTRING,
)
class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"h\.\d+\.attn\.masked_bias",
r"lm_head.weight",
r"h\.\d+\.attn\.attention\.bias",
]
_keys_to_ignore_on_save = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......@@ -820,8 +814,6 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
GPT_NEO_START_DOCSTRING,
)
class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1025,8 +1017,6 @@ class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
GPT_NEO_START_DOCSTRING,
)
class GPTNeoForQuestionAnswering(GPTNeoPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
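
The GPT-Neo attention change (and the similar ones in GPT-NeoX and ImageGPT below) makes the causal-mask `bias` and the `masked_bias` constant non-persistent. Both are fully determined by the config, so rebuilding them in `__init__` is enough and nothing needs to be saved or ignored at load time. A simplified, stand-alone sketch of that pattern, not the library's attention implementation:

```python
import torch
from torch import nn


class ToyCausalSelfAttention(nn.Module):
    def __init__(self, max_positions: int = 8):
        super().__init__()
        causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
            1, 1, max_positions, max_positions
        )
        # Both buffers are deterministic given the config, so neither needs to persist.
        self.register_buffer("bias", causal_mask, persistent=False)
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)

    def forward(self, attn_weights: torch.Tensor) -> torch.Tensor:
        # attn_weights: (batch, heads, seq_len, seq_len)
        seq_len = attn_weights.size(-1)
        mask = self.bias[:, :, :seq_len, :seq_len]
        return torch.where(mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
```
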
......@@ -100,8 +100,9 @@ class GPTNeoXAttention(nn.Module):
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
......@@ -600,7 +601,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
"""GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
)
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_tied_weights_keys = ["embed_out.weight"]
def __init__(self, config):
......@@ -775,8 +775,6 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
GPT_NEOX_START_DOCSTRING,
)
class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -971,8 +969,6 @@ class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
GPT_NEOX_START_DOCSTRING,
)
class GPTNeoXForQuestionAnswering(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -591,7 +591,6 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
GPT_NEOX_JAPANESE_START_DOCSTRING,
)
class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "embed_out.weight"]
_tied_weights_keys = ["embed_out.weight"]
def __init__(self, config):
......
......@@ -734,7 +734,6 @@ class GPTJModel(GPTJPreTrainedModel):
GPTJ_START_DOCSTRING,
)
class GPTJForCausalLM(GPTJPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
......@@ -933,8 +932,6 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
GPTJ_START_DOCSTRING,
)
class GPTJForSequenceClassification(GPTJPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1059,8 +1056,6 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
GPTJ_START_DOCSTRING,
)
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -1111,7 +1111,6 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GPTSanJapaneseConfig):
......
......@@ -714,7 +714,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):
config_class = GraphormerConfig
base_model_prefix = "graphormer"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
main_input_name_nodes = "input_nodes"
main_input_name_edges = "input_edges"
......
......@@ -450,7 +450,9 @@ class GroupViTTextEmbeddings(nn.Module):
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
......@@ -767,7 +769,6 @@ class GroupViTPreTrainedModel(PreTrainedModel):
config_class = GroupViTConfig
base_model_prefix = "groupvit"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -869,7 +869,6 @@ class HubertPreTrainedModel(PreTrainedModel):
base_model_prefix = "hubert"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -80,7 +80,9 @@ class IBertEmbeddings(nn.Module):
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
# End copy
......@@ -740,8 +742,6 @@ class IBertModel(IBertPreTrainedModel):
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
......@@ -854,8 +854,6 @@ class IBertModel(IBertPreTrainedModel):
@add_start_docstrings("""I-BERT Model with a `language modeling` head on top.""", IBERT_START_DOCSTRING)
class IBertForMaskedLM(IBertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias", "lm_head.decoder.weight"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
def __init__(self, config):
......@@ -969,8 +967,6 @@ class IBertLMHead(nn.Module):
IBERT_START_DOCSTRING,
)
class IBertForSequenceClassification(IBertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1064,8 +1060,6 @@ class IBertForSequenceClassification(IBertPreTrainedModel):
IBERT_START_DOCSTRING,
)
class IBertForMultipleChoice(IBertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1156,9 +1150,6 @@ class IBertForMultipleChoice(IBertPreTrainedModel):
IBERT_START_DOCSTRING,
)
class IBertForTokenClassification(IBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1256,9 +1247,6 @@ class IBertClassificationHead(nn.Module):
IBERT_START_DOCSTRING,
)
class IBertForQuestionAnswering(IBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......
......@@ -183,8 +183,9 @@ class ImageGPTAttention(nn.Module):
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
......@@ -613,8 +614,6 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
IMAGEGPT_START_DOCSTRING,
)
class ImageGPTModel(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
......@@ -893,7 +892,6 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
IMAGEGPT_START_DOCSTRING,
)
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: ImageGPTConfig):
......@@ -1085,8 +1083,6 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
IMAGEGPT_START_DOCSTRING,
)
class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
self.num_labels = config.num_labels
......