Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7e86cb6c
Unverified
Commit
7e86cb6c
authored
Jun 25, 2024
by
Raushan Turganbay
Committed by
GitHub
Jun 25, 2024
Browse files
Siglip: add `_no_split_module` (#31566)
* device-map siglip * move split modules to PretrainedSigLip
parent
74b92c62
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
4 deletions
+23
-4
src/transformers/models/siglip/modeling_siglip.py
src/transformers/models/siglip/modeling_siglip.py
+11
-4
tests/models/siglip/test_modeling_siglip.py
tests/models/siglip/test_modeling_siglip.py
+12
-0
No files found.
src/transformers/models/siglip/modeling_siglip.py
View file @
7e86cb6c
...
@@ -496,6 +496,13 @@ class SiglipPreTrainedModel(PreTrainedModel):
...
@@ -496,6 +496,13 @@ class SiglipPreTrainedModel(PreTrainedModel):
config_class
=
SiglipConfig
config_class
=
SiglipConfig
base_model_prefix
=
"siglip"
base_model_prefix
=
"siglip"
supports_gradient_checkpointing
=
True
supports_gradient_checkpointing
=
True
_no_split_modules
=
[
"SiglipTextEmbeddings"
,
"SiglipEncoderLayer"
,
"SiglipVisionEmbeddings"
,
"SiglipEncoderLayer"
,
"SiglipMultiheadAttentionPoolingHead"
,
]
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
"""Initialize the weights"""
...
@@ -816,8 +823,6 @@ class SiglipTextTransformer(nn.Module):
...
@@ -816,8 +823,6 @@ class SiglipTextTransformer(nn.Module):
class
SiglipTextModel
(
SiglipPreTrainedModel
):
class
SiglipTextModel
(
SiglipPreTrainedModel
):
config_class
=
SiglipTextConfig
config_class
=
SiglipTextConfig
_no_split_modules
=
[
"SiglipTextEmbeddings"
,
"SiglipEncoderLayer"
]
def
__init__
(
self
,
config
:
SiglipTextConfig
):
def
__init__
(
self
,
config
:
SiglipTextConfig
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
text_model
=
SiglipTextTransformer
(
config
)
self
.
text_model
=
SiglipTextTransformer
(
config
)
...
@@ -959,7 +964,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
...
@@ -959,7 +964,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
class
SiglipVisionModel
(
SiglipPreTrainedModel
):
class
SiglipVisionModel
(
SiglipPreTrainedModel
):
config_class
=
SiglipVisionConfig
config_class
=
SiglipVisionConfig
main_input_name
=
"pixel_values"
main_input_name
=
"pixel_values"
_no_split_modules
=
[
"SiglipVisionEmbeddings"
,
"SiglipEncoderLayer"
,
"SiglipMultiheadAttentionPoolingHead"
]
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
...
@@ -1222,7 +1226,10 @@ class SiglipModel(SiglipPreTrainedModel):
...
@@ -1222,7 +1226,10 @@ class SiglipModel(SiglipPreTrainedModel):
text_embeds
=
text_embeds
/
text_embeds
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
text_embeds
=
text_embeds
/
text_embeds
.
norm
(
p
=
2
,
dim
=-
1
,
keepdim
=
True
)
# cosine similarity as logits
# cosine similarity as logits
logits_per_text
=
torch
.
matmul
(
text_embeds
,
image_embeds
.
t
())
*
self
.
logit_scale
.
exp
()
+
self
.
logit_bias
logits_per_text
=
(
torch
.
matmul
(
text_embeds
,
image_embeds
.
t
().
to
(
text_embeds
.
device
))
*
self
.
logit_scale
.
exp
()
+
self
.
logit_bias
)
logits_per_image
=
logits_per_text
.
t
()
logits_per_image
=
logits_per_text
.
t
()
loss
=
None
loss
=
None
...
...
tests/models/siglip/test_modeling_siglip.py
View file @
7e86cb6c
...
@@ -443,6 +443,12 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
...
@@ -443,6 +443,12 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning
=
False
test_pruning
=
False
test_resize_embeddings
=
False
test_resize_embeddings
=
False
test_attention_outputs
=
False
test_attention_outputs
=
False
# MP works but offload doesn't work when the MultiheadAttention is offloaded
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
# in the dispatch_model function
test_cpu_offload
=
False
test_disk_offload_safetensors
=
False
test_disk_offload_bin
=
False
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -618,6 +624,12 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixi
...
@@ -618,6 +624,12 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixi
test_pruning
=
False
test_pruning
=
False
test_resize_embeddings
=
False
test_resize_embeddings
=
False
test_attention_outputs
=
False
test_attention_outputs
=
False
# MP works but offload doesn't work when the MultiheadAttention is offloaded
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
# in the dispatch_model function
test_cpu_offload
=
False
test_disk_offload_safetensors
=
False
test_disk_offload_bin
=
False
def
setUp
(
self
):
def
setUp
(
self
):
self
.
model_tester
=
SiglipForImageClassificationModelTester
(
self
)
self
.
model_tester
=
SiglipForImageClassificationModelTester
(
self
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment