Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1477ffc3
Unverified
Commit
1477ffc3
authored
Mar 11, 2025
by
Isotr0py
Committed by
GitHub
Mar 11, 2025
Browse files
[VLM] Cleanup siglip legacy code and fix broken paligemma multimodal processor (#14602)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
70b808fe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
76 deletions
+14
-76
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+10
-5
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+4
-71
No files found.
vllm/model_executor/models/paligemma.py
View file @
1477ffc3
...
@@ -24,9 +24,10 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...
@@ -24,9 +24,10 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
,
get_max_siglip_image_tokens
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
get_vision_encoder_info
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -67,6 +68,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
...
@@ -67,6 +68,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def
get_hf_config
(
self
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
PaliGemmaConfig
)
return
self
.
ctx
.
get_hf_config
(
PaliGemmaConfig
)
def
get_vision_encoder_info
(
self
):
return
get_vision_encoder_info
(
self
.
get_hf_config
())
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
...
@@ -78,9 +82,8 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
...
@@ -78,9 +82,8 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
return
{
"image"
:
self
.
get_num_image_tokens
()}
return
{
"image"
:
self
.
get_num_image_tokens
()}
def
get_num_image_tokens
(
self
)
->
int
:
def
get_num_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
get_hf_config
()
vision_encoder_info
=
self
.
get_vision_encoder_info
()
vision_config
=
hf_config
.
vision_config
return
vision_encoder_info
.
get_max_image_tokens
()
return
get_max_siglip_image_tokens
(
vision_config
)
class
PaliGemmaDummyInputsBuilder
(
class
PaliGemmaDummyInputsBuilder
(
...
@@ -173,8 +176,10 @@ class PaliGemmaMultiModalProcessor(
...
@@ -173,8 +176,10 @@ class PaliGemmaMultiModalProcessor(
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
mm_inputs
=
super
().
apply
(
prompt
,
mm_data
,
hf_processor_mm_kwargs
)
mm_inputs
=
super
().
apply
(
prompt
,
mm_data
,
hf_processor_mm_kwargs
,
return_mm_hashes
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
...
...
vllm/model_executor/models/siglip.py
View file @
1477ffc3
...
@@ -6,7 +6,6 @@ import math
...
@@ -6,7 +6,6 @@ import math
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
transformers
import
SiglipVisionConfig
from
transformers
import
SiglipVisionConfig
...
@@ -20,74 +19,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -20,74 +19,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
consecutive_placeholder_ranges
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
# Since interpolation is applied, the image size need not be divisible
# assert image_size % patch_size == 0
return
image_size
//
patch_size
def
get_siglip_num_patches
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
grid_length
=
get_siglip_patch_grid_length
(
image_size
=
image_size
,
patch_size
=
patch_size
)
return
grid_length
*
grid_length
def
get_siglip_image_feature_size
(
hf_config
:
SiglipVisionConfig
)
->
int
:
return
get_siglip_num_patches
(
image_size
=
hf_config
.
image_size
,
patch_size
=
hf_config
.
patch_size
)
def
get_max_siglip_image_tokens
(
hf_config
:
SiglipVisionConfig
)
->
int
:
return
get_siglip_image_feature_size
(
hf_config
)
def
dummy_seq_data_for_siglip
(
hf_config
:
SiglipVisionConfig
,
seq_len
:
int
,
num_images
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
mm_key
:
str
=
"image"
,
):
if
image_feature_size_override
is
None
:
image_feature_size
=
get_siglip_image_feature_size
(
hf_config
)
else
:
image_feature_size
=
image_feature_size_override
return
SequenceData
.
from_prompt_token_counts
(
(
image_token_id
,
image_feature_size
*
num_images
),
(
0
,
seq_len
-
image_feature_size
*
num_images
),
),
{
mm_key
:
consecutive_placeholder_ranges
(
num_items
=
num_images
,
item_size
=
image_feature_size
)
}
def
dummy_image_for_siglip
(
hf_config
:
SiglipVisionConfig
,
num_images
:
int
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
width
=
height
=
hf_config
.
image_size
if
image_width_override
is
not
None
:
width
=
image_width_override
if
image_height_override
is
not
None
:
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
class
SiglipEncoderInfo
(
VisionEncoderInfo
[
SiglipVisionConfig
]):
class
SiglipEncoderInfo
(
VisionEncoderInfo
[
SiglipVisionConfig
]):
def
get_num_image_tokens
(
def
get_num_image_tokens
(
...
@@ -96,10 +31,10 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
...
@@ -96,10 +31,10 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
)
->
int
:
)
->
int
:
return
get_siglip_image_feature_size
(
self
.
vision_config
)
return
self
.
get_patch_grid_length
()
**
2
def
get_max_image_tokens
(
self
)
->
int
:
def
get_max_image_tokens
(
self
)
->
int
:
return
get_max_siglip_image_tokens
(
self
.
vision_config
)
return
self
.
get_patch_grid_length
()
**
2
def
get_image_size
(
self
)
->
int
:
def
get_image_size
(
self
)
->
int
:
return
self
.
vision_config
.
image_size
return
self
.
vision_config
.
image_size
...
@@ -108,10 +43,8 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
...
@@ -108,10 +43,8 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
return
self
.
vision_config
.
patch_size
return
self
.
vision_config
.
patch_size
def
get_patch_grid_length
(
self
)
->
int
:
def
get_patch_grid_length
(
self
)
->
int
:
return
get_siglip_patch_grid_length
(
image_size
,
patch_size
=
self
.
get_image_size
(),
self
.
get_patch_size
()
image_size
=
self
.
vision_config
.
image_size
,
return
image_size
//
patch_size
patch_size
=
self
.
vision_config
.
patch_size
,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment