"vscode:/vscode.git/clone" did not exist on "447de34ddedfcb7caa2a1e03a5fe74e82f2e377f"
Unverified Commit 7e86cb6c authored by Raushan Turganbay, committed by GitHub

Siglip: add `_no_split_module` (#31566)

* device-map siglip

* move split modules to SiglipPreTrainedModel
parent 74b92c62
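With _no_split_modules declared on SiglipPreTrainedModel, accelerate knows which blocks must stay on a single device, so SigLIP can be loaded with a device_map. A minimal sketch of the usage this enables, assuming a multi-device setup; the checkpoint name is only illustrative:

import torch
from transformers import SiglipModel

# device_map="auto" shards the model across available GPUs/CPU; the classes listed in
# _no_split_modules (embeddings, encoder layers, pooling head) are never split apart.
model = SiglipModel.from_pretrained(
    "google/siglip-base-patch16-224",  # illustrative checkpoint
    torch_dtype=torch.float16,
    device_map="auto",
)
print(model.hf_device_map)  # shows which device each block was placed on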
@@ -496,6 +496,13 @@ class SiglipPreTrainedModel(PreTrainedModel):
config_class = SiglipConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
+_no_split_modules = [
+    "SiglipTextEmbeddings",
+    "SiglipEncoderLayer",
+    "SiglipVisionEmbeddings",
+    "SiglipEncoderLayer",
+    "SiglipMultiheadAttentionPoolingHead",
+]
def _init_weights(self, module):
"""Initialize the weights"""
@@ -816,8 +823,6 @@ class SiglipTextTransformer(nn.Module):
class SiglipTextModel(SiglipPreTrainedModel):
config_class = SiglipTextConfig
_no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
def __init__(self, config: SiglipTextConfig):
super().__init__(config)
self.text_model = SiglipTextTransformer(config)
@@ -959,7 +964,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
class SiglipVisionModel(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"]
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
@@ -1222,7 +1226,10 @@ class SiglipModel(SiglipPreTrainedModel):
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
-logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
+logits_per_text = (
+    torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
+    + self.logit_bias
+)
logits_per_image = logits_per_text.t()
loss = None
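The change to logits_per_text matters because, with a device_map, the text and vision towers may sit on different devices; text_embeds and image_embeds then live on different devices and a plain matmul fails. A hedged sketch of the failure mode and of the cast used in the diff (assumes a CUDA device is available):

import torch

text_embeds = torch.randn(2, 768)                  # e.g. text tower placed on CPU
image_embeds = torch.randn(2, 768, device="cuda")  # e.g. vision tower placed on GPU

# torch.matmul(text_embeds, image_embeds.t())  # would raise a device-mismatch RuntimeError
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))  # cast as in the diff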
@@ -443,6 +443,12 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
+# MP works but offload doesn't work when the MultiheadAttention is offloaded
+# TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+# in the dispatch_model function
+test_cpu_offload = False
+test_disk_offload_safetensors = False
+test_disk_offload_bin = False
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
def setUp(self):
@@ -618,6 +624,12 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
+# MP works but offload doesn't work when the MultiheadAttention is offloaded
+# TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+# in the dispatch_model function
+test_cpu_offload = False
+test_disk_offload_safetensors = False
+test_disk_offload_bin = False
def setUp(self):
self.model_tester = SiglipForImageClassificationModelTester(self)
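On the offload limitation noted in the test comments above: accelerate's dispatch_model accepts a preload_module_classes argument, so the TODO's potential fix would look roughly like the sketch below. This is an untested illustration of that idea, not part of the commit, and the checkpoint name is only an example.

from accelerate import dispatch_model, infer_auto_device_map
from transformers import SiglipModel

model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")  # illustrative checkpoint
device_map = infer_auto_device_map(model, no_split_module_classes=model._no_split_modules)

# Hypothetical workaround from the TODO: preload the pooling head so the offloaded
# nn.MultiheadAttention weights are materialized before its forward reads them directly.
model = dispatch_model(
    model,
    device_map=device_map,
    preload_module_classes=["SiglipMultiheadAttentionPoolingHead"],
)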