Unverified Commit 89b6ee49 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Finishing tidying keys to ignore on load (#24535)

parent 04f46a22
......@@ -275,12 +275,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
config_class = InstructBlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"language_model.encoder.embed_tokens.weight",
r"language_model.decoder.embed_tokens.weight",
r"language_model.lm_head.weight",
]
_no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]
_keep_in_fp32_modules = []
......@@ -1011,7 +1005,9 @@ class InstructBlipQFormerEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
......
......@@ -176,6 +176,14 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
def test_load_save_without_tied_weights(self):
pass
@unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
def test_model_weights_reload_no_missing_tied_weights(self):
pass
@unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.")
def test_channels(self):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment