chenpangpang / transformers · commit 7e86cb6c (unverified)
Authored by Raushan Turganbay on Jun 25, 2024, committed via GitHub on Jun 25, 2024
Parent: 74b92c62

Siglip: add `_no_split_module` (#31566)

* device-map siglip
* move split modules to PretrainedSigLip
2 changed files with 23 additions and 4 deletions (+23 −4):

  src/transformers/models/siglip/modeling_siglip.py   +11 −4
  tests/models/siglip/test_modeling_siglip.py         +12 −0
src/transformers/models/siglip/modeling_siglip.py
@@ -496,6 +496,13 @@ class SiglipPreTrainedModel(PreTrainedModel):

     config_class = SiglipConfig
     base_model_prefix = "siglip"
     supports_gradient_checkpointing = True
+    _no_split_modules = [
+        "SiglipTextEmbeddings",
+        "SiglipEncoderLayer",
+        "SiglipVisionEmbeddings",
+        "SiglipEncoderLayer",
+        "SiglipMultiheadAttentionPoolingHead",
+    ]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -816,8 +823,6 @@ class SiglipTextTransformer(nn.Module):

 class SiglipTextModel(SiglipPreTrainedModel):
     config_class = SiglipTextConfig

-    _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
-
     def __init__(self, config: SiglipTextConfig):
         super().__init__(config)
         self.text_model = SiglipTextTransformer(config)
@@ -959,7 +964,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):

 class SiglipVisionModel(SiglipPreTrainedModel):
     config_class = SiglipVisionConfig
     main_input_name = "pixel_values"
-    _no_split_modules = ["SiglipVisionEmbeddings", "SiglipEncoderLayer", "SiglipMultiheadAttentionPoolingHead"]

     def __init__(self, config: SiglipVisionConfig):
         super().__init__(config)
@@ -1222,7 +1226,10 @@ class SiglipModel(SiglipPreTrainedModel):

         text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

         # cosine similarity as logits
-        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
+            + self.logit_bias
+        )

         logits_per_image = logits_per_text.t()

         loss = None
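The reworked logits computation moves the image embeddings onto the text embeddings' device before the matmul; under a device map the two towers can land on different devices, and `torch.matmul` requires both operands on the same one. A standalone illustration of that failure mode, assuming two visible CUDA devices (not SigLIP code):

    # Standalone sketch, assuming two CUDA devices are visible.
    import torch

    text_embeds = torch.randn(4, 16, device="cuda:0")
    image_embeds = torch.randn(4, 16, device="cuda:1")

    # Without the .to(...), torch.matmul raises a device-mismatch RuntimeError.
    logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))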
tests/models/siglip/test_modeling_siglip.py
@@ -443,6 +443,12 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

     test_pruning = False
     test_resize_embeddings = False
     test_attention_outputs = False
+    # MP works but offload doesn't work when the MultiheadAttention is offloaded
+    # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+    # in the dispatch_model function
+    test_cpu_offload = False
+    test_disk_offload_safetensors = False
+    test_disk_offload_bin = False

     # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
     def setUp(self):
@@ -618,6 +624,12 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixi

     test_pruning = False
     test_resize_embeddings = False
     test_attention_outputs = False
+    # MP works but offload doesn't work when the MultiheadAttention is offloaded
+    # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+    # in the dispatch_model function
+    test_cpu_offload = False
+    test_disk_offload_safetensors = False
+    test_disk_offload_bin = False

     def setUp(self):
         self.model_tester = SiglipForImageClassificationModelTester(self)
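The TODO in both test hunks above points at Accelerate's `dispatch_model`. A hedged sketch of the workaround it suggests, with the checkpoint name and the inferred device map as assumptions rather than anything taken from this commit:

    # Hedged sketch of the TODO above; checkpoint name and device map are assumptions.
    from accelerate import dispatch_model, infer_auto_device_map
    from transformers import SiglipVisionModel

    model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
    device_map = infer_auto_device_map(model)  # assumes at least one accelerator is visible

    # preload_module_classes keeps these modules' weights loaded rather than offloaded,
    # which is the workaround the TODO comment hints at.
    model = dispatch_model(
        model,
        device_map=device_map,
        preload_module_classes=["SiglipMultiheadAttentionPoolingHead"],
    )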