Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
372b2e76
Unverified
Commit
372b2e76
authored
Feb 13, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 12, 2026
Browse files
[Bugfix] Standardize getting number of image patches/tokens (#34358)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
6afa587d
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
257 additions
and
266 deletions
+257
-266
tests/kernels/core/test_mrope.py
tests/kernels/core/test_mrope.py
+2
-22
tests/models/multimodal/generation/test_common.py
tests/models/multimodal/generation/test_common.py
+0
-6
tests/models/multimodal/processing/test_gemma3.py
tests/models/multimodal/processing/test_gemma3.py
+1
-0
tests/models/multimodal/processing/test_idefics3.py
tests/models/multimodal/processing/test_idefics3.py
+11
-1
tests/models/multimodal/processing/test_qwen2_vl.py
tests/models/multimodal/processing/test_qwen2_vl.py
+1
-0
tests/models/multimodal/processing/test_smolvlm.py
tests/models/multimodal/processing/test_smolvlm.py
+11
-1
vllm/model_executor/models/cohere2_vision.py
vllm/model_executor/models/cohere2_vision.py
+10
-31
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+27
-12
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+40
-52
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+2
-8
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+1
-4
vllm/model_executor/models/hunyuan_vision.py
vllm/model_executor/models/hunyuan_vision.py
+23
-11
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+22
-42
vllm/model_executor/models/interns1.py
vllm/model_executor/models/interns1.py
+13
-13
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+1
-4
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+28
-16
vllm/model_executor/models/lfm2_vl.py
vllm/model_executor/models/lfm2_vl.py
+44
-21
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+1
-4
vllm/model_executor/models/molmo2.py
vllm/model_executor/models/molmo2.py
+18
-14
vllm/model_executor/models/ovis2_5.py
vllm/model_executor/models/ovis2_5.py
+1
-4
No files found.
tests/kernels/core/test_mrope.py
View file @
372b2e76
...
...
@@ -4,8 +4,6 @@ from typing import NamedTuple
import
pytest
import
torch
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
...
...
@@ -46,31 +44,13 @@ class MRoPETestInfo(NamedTuple):
marks
:
list
[
pytest
.
MarkDecorator
]
=
[]
TRANSFORMERS_BASE_VERSION
=
Version
(
TRANSFORMERS_VERSION
).
base_version
MODELS_TO_TEST
=
[
MRoPETestInfo
(
model_name
=
"zai-org/GLM-4.1V-9B-Thinking"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2-VL-7B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2-VL-72B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2.5-VL-72B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-4B-Instruct"
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
],
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
],
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-4B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
),
]
num_tokens_list
=
[
11
,
8192
]
...
...
tests/models/multimodal/generation/test_common.py
View file @
372b2e76
...
...
@@ -961,12 +961,6 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt
=
{
"image"
:
4
},
)
],
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_VERSION
)
==
Version
(
"4.57.1"
),
reason
=
"This model is broken in Transformers v4.57.1"
,
)
],
),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention"
:
VLMTestInfo
(
...
...
tests/models/multimodal/processing/test_gemma3.py
View file @
372b2e76
...
...
@@ -168,6 +168,7 @@ def test_get_image_size_with_most_features(
image_width
=
max_image_size
.
width
,
image_height
=
max_image_size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
prompt
=
"<start_of_image>"
...
...
tests/models/multimodal/processing/test_idefics3.py
View file @
372b2e76
...
...
@@ -3,7 +3,9 @@
"""Tests for Idefics3's multimodal preprocessing kwargs."""
import
pytest
from
packaging.version
import
Version
from
transformers
import
Idefics3Config
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets
from
...utils
import
build_model_context
@
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"5.2.0"
),
reason
=
"See https://github.com/huggingface/transformers/pull/43948"
,
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"HuggingFaceM4/Idefics3-8B-Llama3"
])
@
pytest
.
mark
.
parametrize
(
(
"mm_processor_kwargs"
,
"expected_toks_per_img"
),
...
...
@@ -63,7 +69,11 @@ def test_processor_override(
# Ensure the placeholders format are correct
hf_processor
=
processor
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processed_inputs
=
hf_processor
(
text
=
prompt
,
images
=
mm_data
[
"image"
])
hf_processed_inputs
=
hf_processor
(
text
=
prompt
,
images
=
mm_data
[
"image"
],
**
processor
.
info
.
ctx
.
get_merged_mm_kwargs
(
hf_processor_mm_kwargs
),
)
assert
processed_inputs
[
"prompt_token_ids"
]
==
hf_processed_inputs
[
"input_ids"
][
0
]
# Ensure we have the right number of placeholders per num_crops size
...
...
tests/models/multimodal/processing/test_qwen2_vl.py
View file @
372b2e76
...
...
@@ -82,6 +82,7 @@ def test_get_image_size_with_most_features(
image_width
=
max_image_size
.
width
,
image_height
=
max_image_size
.
height
,
image_processor
=
hf_processor
.
image_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
prompt
=
"<|vision_start|><|image_pad|><|vision_end|>"
...
...
tests/models/multimodal/processing/test_smolvlm.py
View file @
372b2e76
...
...
@@ -3,7 +3,9 @@
"""Tests for smolvlm's multimodal preprocessing kwargs."""
import
pytest
from
packaging.version
import
Version
from
transformers
import
SmolVLMConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets
from
...utils
import
build_model_context
@
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"5.2.0"
),
reason
=
"See https://github.com/huggingface/transformers/pull/43948"
,
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
])
@
pytest
.
mark
.
parametrize
(
(
"mm_processor_kwargs"
,
"expected_toks_per_img"
),
...
...
@@ -63,7 +69,11 @@ def test_processor_override(
# Ensure the placeholders format are correct
hf_processor
=
processor
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processed_inputs
=
hf_processor
(
text
=
prompt
,
images
=
mm_data
[
"image"
])
hf_processed_inputs
=
hf_processor
(
text
=
prompt
,
images
=
mm_data
[
"image"
],
**
processor
.
info
.
ctx
.
get_merged_mm_kwargs
(
hf_processor_mm_kwargs
),
)
assert
processed_inputs
[
"prompt_token_ids"
]
==
hf_processed_inputs
[
"input_ids"
][
0
]
# Ensure we have the right number of placeholders per num_crops size
...
...
vllm/model_executor/models/cohere2_vision.py
View file @
372b2e76
...
...
@@ -11,7 +11,7 @@ from torch import nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers.models.cohere2_vision
import
Cohere2VisionConfig
from
transformers.models.cohere2_vision.image_processing_cohere2_vision_fast
import
(
# noqa: E501
get_optimal_tiled_canv
as
,
Cohere2VisionImageProcessorF
as
t
,
)
from
transformers.models.cohere2_vision.processing_cohere2_vision
import
(
Cohere2VisionProcessor
,
...
...
@@ -166,43 +166,20 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Cohere2VisionProcessor
|
None
,
processor
:
Cohere2VisionProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
"""
Calculate the number of image patches for a given image.
Uses the HF processor to determine the actual number of patches.
"""
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
image_processor
=
processor
.
image_processor
image_processor
:
Cohere2VisionImageProcessorFast
=
processor
.
image_processor
# The current implementation of get_number_of_image_patches
# is incorrect, so we patch it here.
# TODO: Revert once
# https://github.com/huggingface/transformers/pull/40312 is released.
# return image_processor.get_number_of_image_patches(image_height,
# image_width, {})
min_patches
=
image_processor
.
min_patches
max_patches
=
image_processor
.
max_patches
patch_size
=
image_processor
.
size
crop_to_patches
=
image_processor
.
crop_to_patches
if
not
crop_to_patches
:
return
1
num_columns
,
num_rows
=
get_optimal_tiled_canvas
(
(
image_height
,
image_width
),
(
patch_size
[
"height"
],
patch_size
[
"width"
]),
min_patches
,
max_patches
,
return
image_processor
.
get_number_of_image_patches
(
image_height
,
image_width
,
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
),
)
num_patches
=
num_columns
*
num_rows
if
num_patches
>
1
:
num_patches
+=
1
# Thumbnail image
return
num_patches
class
Cohere2VisionDummyInputsBuilder
(
...
...
@@ -271,6 +248,7 @@ class Cohere2VisionMultiModalProcessor(
image_width
=
parsed_images
.
get_image_size
(
i
).
width
,
image_height
=
parsed_images
.
get_image_size
(
i
).
height
,
processor
=
hf_processor
,
mm_kwargs
=
mm_kwargs
,
)
for
i
in
range
(
len
(
parsed_images
))
]
...
...
@@ -311,6 +289,7 @@ class Cohere2VisionMultiModalProcessor(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
patch_tokens
=
image_token
*
img_tokens_per_tile
+
img_line_break_token
repl
=
f
"
{
boi_token
}{
patch_tokens
*
num_patches
}{
eoi_token
}
"
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
372b2e76
...
...
@@ -34,7 +34,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
BatchFeature
from
transformers
import
BaseImageProcessor
,
BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
,
VideoDummyOptions
...
...
@@ -818,10 +818,9 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
num_frames
:
int
=
1
,
do_resize
:
bool
=
True
,
image_processor
:
Any
|
None
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
ImageSize
,
int
]:
if
image_processor
is
None
:
image_processor
=
self
.
get_image_processor
()
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
...
...
@@ -829,13 +828,16 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
spatial_conv_size
=
hf_config
.
spatial_conv_size
temporal_conv_size
=
hf_config
.
temporal_conv_size
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
size
=
mm_kwargs
.
get
(
"size"
,
image_processor
.
size
)
if
do_resize
:
resized_height
,
resized_width
=
smart_resize
(
height
=
image_height
,
width
=
image_width
,
factor
=
patch_size
*
spatial_conv_size
,
min_pixels
=
image_processor
.
min_pixels
,
max_pixels
=
image_processor
.
max_pixels
,
min_pixels
=
size
[
"
min_pixels
"
]
,
max_pixels
=
size
[
"
max_pixels
"
]
,
)
preprocessed_size
=
ImageSize
(
width
=
resized_width
,
height
=
resized_height
)
else
:
...
...
@@ -855,12 +857,14 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
image_processor
:
Any
|
None
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_image_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_image_tokens
...
...
@@ -870,35 +874,43 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_height
:
int
,
num_frames
:
int
,
image_processor
:
Any
|
None
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_video_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
num_frames
=
num_frames
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_video_tokens
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_image_processor
()
max_image_size
,
_
=
self
.
_get_vision_info
(
image_width
=
9999999
,
image_height
=
9999999
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
return
max_image_size
def
get_max_image_tokens
(
self
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
num_image_tokens
=
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
return
num_image_tokens
def
_get_max_video_frames
(
self
,
max_tokens
:
int
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
num_frames
=
0
...
...
@@ -909,7 +921,8 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
next_num_frames
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
if
next_max_tokens
>
max_tokens
:
...
...
@@ -942,13 +955,15 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_video_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
self
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
),
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
372b2e76
...
...
@@ -7,6 +7,7 @@ from typing import Annotated, Any, Literal
import
torch
from
torch
import
nn
from
transformers
import
BatchFeature
,
Gemma3Config
,
Gemma3Processor
from
transformers.models.gemma3.image_processing_gemma3
import
Gemma3ImageProcessor
from
transformers.models.gemma3.processing_gemma3
import
Gemma3ProcessorKwargs
from
vllm.config
import
VllmConfig
...
...
@@ -84,54 +85,35 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
None
}
def
_resolve_image_kwargs
(
self
,
processor
:
Gemma3Processor
,
keys
:
set
[
str
],
)
->
dict
[
str
,
Any
]:
image_processor
=
processor
.
image_processor
kwargs
=
processor
.
_merge_kwargs
(
Gemma3ProcessorKwargs
,
tokenizer_init_kwargs
=
processor
.
tokenizer
.
init_kwargs
,
)
images_kwargs
=
kwargs
[
"images_kwargs"
]
def
_resolve_kw
(
key
:
str
):
val
=
getattr
(
image_processor
,
key
)
if
val
is
None
:
val
=
images_kwargs
[
key
]
return
val
return
{
k
:
_resolve_kw
(
k
)
for
k
in
keys
}
def
get_num_crops
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Gemma3Processor
|
None
,
processor
:
Gemma3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
images_kwargs
=
self
.
_resolve_image_kwargs
(
processor
,
{
"do_pan_and_scan"
,
"pan_and_scan_min_crop_size"
,
"pan_and_scan_max_num_crops"
,
"pan_and_scan_min_ratio_to_activate"
,
},
)
image_processor
:
Gemma3ImageProcessor
=
processor
.
image_processor
do_pan_and_scan
=
images_kwargs
[
"do_pan_and_scan"
]
pan_and_scan_min_crop_size
=
images_kwargs
[
"pan_and_scan_min_crop_size"
]
pan_and_scan_max_num_crops
=
images_kwargs
[
"pan_and_scan_max_num_crops"
]
pan_and_scan_min_ratio_to_activate
=
images_kwargs
[
"pan_and_scan_min_ratio_to_activate"
]
images_kwargs
=
processor
.
_merge_kwargs
(
Gemma3ProcessorKwargs
,
tokenizer_init_kwargs
=
processor
.
tokenizer
.
init_kwargs
,
**
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
),
)[
"images_kwargs"
]
do_pan_and_scan
=
images_kwargs
.
get
(
"do_pan_and_scan"
,
image_processor
.
do_pan_and_scan
)
pan_and_scan_min_crop_size
=
images_kwargs
.
get
(
"pan_and_scan_min_crop_size"
,
image_processor
.
pan_and_scan_min_crop_size
)
pan_and_scan_max_num_crops
=
images_kwargs
.
get
(
"pan_and_scan_max_num_crops"
,
image_processor
.
pan_and_scan_max_num_crops
)
pan_and_scan_min_ratio_to_activate
=
images_kwargs
.
get
(
"pan_and_scan_min_ratio_to_activate"
,
image_processor
.
pan_and_scan_min_ratio_to_activate
,
)
if
not
do_pan_and_scan
:
return
0
...
...
@@ -180,17 +162,16 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Gemma3Processor
|
None
,
processor
:
Gemma3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
PromptUpdateDetails
[
str
]:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
boi_token
=
processor
.
boi_token
num_crops
=
self
.
get_num_crops
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
if
num_crops
==
0
:
...
...
@@ -215,15 +196,14 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Gemma3Processor
|
None
,
processor
:
Gemma3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
num_crops
=
self
.
get_num_crops
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
image_seq_len
=
processor
.
image_seq_length
...
...
@@ -231,11 +211,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
image_processor
:
Gemma3ImageProcessor
=
processor
.
image_processor
images_kwargs
=
processor
.
_merge_kwargs
(
Gemma3ProcessorKwargs
,
tokenizer_init_kwargs
=
processor
.
tokenizer
.
init_kwargs
,
**
self
.
ctx
.
get_merged_mm_kwargs
({}),
)[
"images_kwargs"
]
i
ma
ges_kwargs
=
self
.
_resolve_
image_kwargs
(
processor
,
{
"
pan_and_scan_max_num_crops
"
}
ma
x_num_crops
=
image
s
_kwargs
.
get
(
"pan_and_scan_max_num_crops"
,
image_
processor
.
pan_and_scan_max_num_crops
)
max_num_crops
=
images_kwargs
[
"pan_and_scan_max_num_crops"
]
vision_config
=
self
.
get_hf_config
().
vision_config
native_size
=
vision_config
.
image_size
...
...
@@ -303,6 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
mm_kwargs
,
)
for
size
in
image_sizes
]
...
...
@@ -339,6 +326,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
[
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
372b2e76
...
...
@@ -131,7 +131,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Gemma3nProcessor
|
None
,
processor
:
Gemma3nProcessor
,
)
->
str
:
"""
Get the replacement text for image tokens.
...
...
@@ -139,9 +139,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
For Gemma3n, this should return the full_image_sequence which includes
BOI token, repeated image tokens, and EOI token.
"""
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
PromptUpdateDetails
.
select_token_id
(
processor
.
full_image_sequence
,
processor
.
image_token_id
)
...
...
@@ -149,7 +146,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def
get_audio_repl
(
self
,
*
,
processor
:
Gemma3nProcessor
|
None
,
processor
:
Gemma3nProcessor
,
)
->
str
:
"""
Get the replacement text for audio tokens.
...
...
@@ -157,9 +154,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
For Gemma3n, this should return the full_audio_sequence which includes
BOA token, repeated audio tokens, and EOA token.
"""
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
# Return the full audio sequence as defined by the processor
return
PromptUpdateDetails
.
select_token_id
(
processor
.
full_audio_sequence
,
processor
.
audio_token_id
...
...
vllm/model_executor/models/h2ovl.py
View file @
372b2e76
...
...
@@ -424,12 +424,9 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
H2OVLProcessor
|
None
,
processor
:
H2OVLProcessor
,
use_msac
:
bool
|
None
=
None
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
processor
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
...
...
vllm/model_executor/models/hunyuan_vision.py
View file @
372b2e76
...
...
@@ -78,7 +78,10 @@ from vllm.transformers_utils.configs.hunyuan_vl import (
HunYuanVLVisionConfig
,
)
from
vllm.transformers_utils.processors.hunyuan_vl
import
HunYuanVLProcessor
from
vllm.transformers_utils.processors.hunyuan_vl_image
import
smart_resize
from
vllm.transformers_utils.processors.hunyuan_vl_image
import
(
HunYuanVLImageProcessor
,
smart_resize
,
)
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
...
...
@@ -596,7 +599,7 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
def
get_image_processor
(
self
,
**
kwargs
:
object
,
)
->
HunYuanVLProcessor
:
)
->
HunYuanVL
Image
Processor
:
return
self
.
get_hf_processor
(
**
kwargs
).
image_processor
def
get_data_parser
(
self
):
...
...
@@ -624,23 +627,24 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
num_frames
:
int
=
1
,
do_resize
:
bool
=
True
,
image_processor
:
HunYuanVLProcessor
|
None
,
image_processor
:
HunYuanVLImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
ImageSize
,
int
]:
if
image_processor
is
None
:
image_processor
=
self
.
get_image_processor
()
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
patch_size
=
vision_config
.
patch_size
spatial_merge_size
=
vision_config
.
spatial_merge_size
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
size
=
mm_kwargs
.
get
(
"size"
,
image_processor
.
size
)
if
do_resize
:
resized_height
,
resized_width
=
smart_resize
(
height
=
image_height
,
width
=
image_width
,
factor
=
patch_size
*
spatial_merge_size
,
min_pixels
=
image_processor
.
min_pixels
,
max_pixels
=
image_processor
.
max_pixels
,
min_pixels
=
size
[
"shortest_edge"
]
,
max_pixels
=
size
[
"longest_edge"
]
,
)
preprocessed_size
=
ImageSize
(
width
=
resized_width
,
height
=
resized_height
)
else
:
...
...
@@ -662,29 +666,37 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
image_processor
:
HunYuanVLProcessor
|
None
,
image_processor
:
HunYuanVLImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_image_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_image_tokens
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_image_processor
()
max_image_size
,
_
=
self
.
_get_vision_info
(
image_width
=
512
,
image_height
=
8192
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
return
max_image_size
def
get_max_image_tokens
(
self
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
...
...
vllm/model_executor/models/idefics3.py
View file @
372b2e76
...
...
@@ -16,7 +16,6 @@
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Literal
,
TypeAlias
...
...
@@ -168,54 +167,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Idefics3Processor
|
None
,
)
->
tuple
[
int
,
int
]:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
processor
:
Idefics3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
int
,
int
,
int
]:
image_processor
:
Idefics3ImageProcessor
=
processor
.
image_processor
max_image_size
=
image_processor
.
max_image_size
[
"longest_edge"
]
size
=
image_processor
.
size
[
"longest_edge"
]
assert
size
%
max_image_size
==
0
,
(
"`longest_edge` in image_processor's `size` must be divisible by "
"`longest_edge` in `max_image_size`, this may be caused by "
"incorrect mm_kwargs override."
)
resized_height
,
resized_width
=
self
.
_get_resize_output_image_size
(
image_width
=
image_width
,
image_height
=
image_height
,
resolution_max_side
=
size
,
return
image_processor
.
get_number_of_image_patches
(
image_height
,
image_width
,
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
),
)
if
resized_height
>
max_image_size
or
resized_width
>
max_image_size
:
grid_h
=
math
.
ceil
(
resized_height
/
max_image_size
)
grid_w
=
math
.
ceil
(
resized_width
/
max_image_size
)
else
:
grid_h
=
grid_w
=
0
return
grid_w
,
grid_h
def
get_num_patches
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Idefics3Processor
|
None
,
processor
:
Idefics3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
grid_w
,
grid_h
=
self
.
_get_image_feature_grid_size
(
num_patches
,
_
,
_
=
self
.
_get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
return
grid_w
*
grid_h
+
1
def
_get_image_token
(
self
,
processor
:
Idefics3Processor
|
None
)
->
tuple
[
str
,
str
,
str
]:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
num_patches
def
_get_image_token
(
self
,
processor
:
Idefics3Processor
)
->
tuple
[
str
,
str
,
str
]:
image_token
=
processor
.
image_token
fake_image_token
=
processor
.
fake_image_token
global_image_token
=
processor
.
global_image_tag
...
...
@@ -226,11 +206,9 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Idefics3Processor
|
None
,
processor
:
Idefics3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
str
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
image_token
,
fake_image_token
,
global_img_token
=
self
.
_get_image_token
(
processor
)
...
...
@@ -241,10 +219,11 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
global_img_placeholder
=
fake_image_token
+
global_img_token
+
p_img
tile_img_placeholder
=
fake_image_token
+
grid_placeholder
+
p_img
grid_
w
,
grid_
h
=
self
.
_get_image_feature_grid_size
(
_
,
grid_
h
,
grid_
w
=
self
.
_get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
if
grid_w
==
0
and
grid_h
==
0
:
return
global_img_placeholder
+
fake_image_token
...
...
@@ -272,15 +251,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Idefics3Processor
|
None
,
processor
:
Idefics3Processor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
num_patches
=
self
.
get_num_patches
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_patches
*
processor
.
image_seq_len
...
...
@@ -353,6 +331,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
mm_kwargs
,
)
for
size
in
image_sizes
]
...
...
@@ -398,6 +377,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
PromptUpdateDetails
.
select_text
(
...
...
vllm/model_executor/models/interns1.py
View file @
372b2e76
...
...
@@ -197,20 +197,18 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
GotOcr2ImageProcessorFast
|
None
=
None
,
processor
:
InternVLProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
().
image_processor
image_processor
:
GotOcr2ImageProcessorFast
=
processor
.
image_processor
if
not
isinstance
(
processor
,
GotOcr2ImageProcessorFast
):
raise
ValueError
(
f
"GotOcr2ImageProcessorFast is expected but got
{
type
(
processor
)
}
"
)
num_image_patches
=
processor
.
get_number_of_image_patches
(
image_height
,
image_width
,
images_kwargs
=
dict
()
num_image_patches
=
image_processor
.
get_number_of_image_patches
(
image_height
,
image_width
,
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
),
)
num_image_tokens
=
self
.
get_hf_processor
().
image_seq_length
*
num_image_patches
return
num_image_
token
s
return
processor
.
image_seq_length
*
num_image_
patche
s
def
resolve_target_ratios
(
self
,
use_thumbnail
:
bool
|
None
=
None
):
image_processor
=
self
.
get_hf_processor
().
image_processor
...
...
@@ -243,7 +241,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
feat_size
=
self
.
get_num_image_tokens
(
image_width
=
width
,
image_height
=
height
,
processor
=
processor
.
image_processor
,
processor
=
processor
,
mm_kwargs
=
{},
)
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
...
...
@@ -262,7 +261,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
processor
.
image_processor
,
processor
=
processor
,
mm_kwargs
=
{},
)
def
get_num_frames_with_most_features
(
...
...
vllm/model_executor/models/internvl.py
View file @
372b2e76
...
...
@@ -705,11 +705,8 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
BaseInternVLProcessor
|
None
,
processor
:
BaseInternVLProcessor
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
processor
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
...
...
vllm/model_executor/models/keye.py
View file @
372b2e76
...
...
@@ -10,7 +10,7 @@ import numpy as np
import
torch
import
torch.nn
as
nn
from
einops
import
rearrange
from
transformers
import
PretrainedConfig
from
transformers
import
BaseImageProcessor
,
PretrainedConfig
from
transformers.activations
import
GELUActivation
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
...
...
@@ -1011,24 +1011,25 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
num_frames
:
int
=
1
,
do_resize
:
bool
=
True
,
image_processor
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
ImageSize
,
int
]:
if
image_processor
is
None
:
image_processor
=
self
.
get_image_processor
()
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
patch_size
=
vision_config
.
patch_size
merge_size
=
vision_config
.
spatial_merge_size
temporal_patch_size
=
1
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
size
=
mm_kwargs
.
get
(
"size"
,
image_processor
.
size
)
if
do_resize
:
resized_height
,
resized_width
=
smart_resize
(
height
=
image_height
,
width
=
image_width
,
factor
=
patch_size
*
merge_size
,
min_pixels
=
image_processor
.
min_pixels
,
max_pixels
=
image_processor
.
max_pixels
,
min_pixels
=
size
[
"
min_pixels
"
]
,
max_pixels
=
size
[
"
max_pixels
"
]
,
)
preprocessed_size
=
ImageSize
(
width
=
resized_width
,
height
=
resized_height
)
else
:
...
...
@@ -1050,12 +1051,14 @@ class KeyeProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
image_processor
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_image_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_image_tokens
...
...
@@ -1065,36 +1068,42 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_height
:
int
,
num_frames
:
int
,
image_processor
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_video_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
num_frames
=
num_frames
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_video_tokens
def
get_image_size_with_most_features
(
self
,
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_image_processor
()
max_image_size
,
_
=
self
.
_get_vision_info
(
image_width
=
self
.
get_max_image_size
(),
image_height
=
self
.
get_max_image_size
(),
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
return
max_image_size
def
get_max_image_tokens
(
self
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
def
_get_max_video_frames
(
self
,
max_tokens
:
int
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
num_frames
=
0
...
...
@@ -1105,7 +1114,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
next_num_frames
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
if
next_max_tokens
>
max_tokens
:
...
...
@@ -1130,13 +1140,15 @@ class KeyeProcessingInfo(BaseProcessingInfo):
return
max
(
max_frames_per_video
,
1
)
def
get_max_video_tokens
(
self
,
seq_len
:
int
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_video_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
self
.
get_num_frames_with_most_features
(
seq_len
),
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
...
...
vllm/model_executor/models/lfm2_vl.py
View file @
372b2e76
...
...
@@ -176,7 +176,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
min_tiles
:
int
,
max_tiles
:
int
,
tile_size
:
int
,
)
->
tuple
[
int
,
int
]:
)
->
tuple
[
int
,
int
,
int
]:
aspect_ratio
=
width
/
height
target_ratios
=
self
.
_target_ratios
(
min_tiles
,
max_tiles
)
# find best matching grid configuration
...
...
@@ -190,18 +190,27 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
self
,
image_width
:
int
,
image_height
:
int
,
processor
:
Lfm2VlProcessor
|
None
,
)
->
tuple
[
int
,
in
t
]
:
if
processor
is
None
:
processor
=
self
.
get_
image_processor
()
processor
:
Lfm2VlProcessor
,
mm_kwargs
:
Mapping
[
str
,
objec
t
]
,
)
->
tuple
[
int
,
int
,
int
]
:
image_processor
:
Lfm2VlImageProcessorFast
=
processor
.
image_processor
downsample_factor
=
processor
.
image_processor
.
downsample_factor
encoder_patch_size
=
processor
.
image_processor
.
encoder_patch_size
max_pixels_tolerance
=
processor
.
image_processor
.
max_pixels_tolerance
min_tiles
=
processor
.
image_processor
.
min_tiles
max_tiles
=
processor
.
image_processor
.
max_tiles
max_image_tokens
=
processor
.
image_processor
.
max_image_tokens
tile_size
=
processor
.
image_processor
.
tile_size
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
downsample_factor
=
mm_kwargs
.
get
(
"downsample_factor"
,
image_processor
.
downsample_factor
)
encoder_patch_size
=
mm_kwargs
.
get
(
"encoder_patch_size"
,
image_processor
.
encoder_patch_size
)
max_pixels_tolerance
=
mm_kwargs
.
get
(
"max_pixels_tolerance"
,
image_processor
.
max_pixels_tolerance
)
min_tiles
=
mm_kwargs
.
get
(
"min_tiles"
,
image_processor
.
min_tiles
)
max_tiles
=
mm_kwargs
.
get
(
"max_tiles"
,
image_processor
.
max_tiles
)
max_image_tokens
=
mm_kwargs
.
get
(
"max_image_tokens"
,
image_processor
.
max_image_tokens
)
tile_size
=
mm_kwargs
.
get
(
"tile_size"
,
image_processor
.
tile_size
)
do_image_splitting
=
not
min_tiles
==
max_tiles
==
1
is_image_large
=
self
.
_is_image_too_large
(
...
...
@@ -235,12 +244,14 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Lfm2VlProcessor
|
None
,
processor
:
Lfm2VlProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
_
,
total_patches
=
self
.
_get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
return
total_patches
...
...
@@ -249,11 +260,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_height
:
int
,
spatial_shapes
:
torch
.
Tensor
,
processor
:
Lfm2VlProcessor
|
None
,
processor
:
Lfm2VlProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
str
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
grid_placeholder
=
"<|img_row_{n_h}_col_{n_w}|>"
image_token
=
processor
.
image_token
image_start_token
=
processor
.
image_start_token
...
...
@@ -263,6 +272,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
num_thumbnail_tokens
,
num_tokens_per_tile
=
self
.
get_num_image_tokens
(
spatial_shapes
=
spatial_shapes
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
tile_img_placeholder
=
grid_placeholder
+
(
image_token
*
num_tokens_per_tile
)
...
...
@@ -270,6 +280,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
mm_kwargs
=
mm_kwargs
,
)
if
grid_w
>
1
or
grid_h
>
1
:
...
...
@@ -295,15 +306,25 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
self
,
*
,
spatial_shapes
:
torch
.
Tensor
,
processor
:
Lfm2VlProcessor
|
None
,
processor
:
Lfm2VlProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
int
,
int
]:
tile_size
=
processor
.
image_processor
.
tile_size
downsample_factor
=
processor
.
image_processor
.
downsample_factor
encoder_patch_size
=
processor
.
image_processor
.
encoder_patch_size
image_processor
:
Lfm2VlImageProcessorFast
=
processor
.
image_processor
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
downsample_factor
=
mm_kwargs
.
get
(
"downsample_factor"
,
image_processor
.
downsample_factor
)
encoder_patch_size
=
mm_kwargs
.
get
(
"encoder_patch_size"
,
image_processor
.
encoder_patch_size
)
tile_size
=
mm_kwargs
.
get
(
"tile_size"
,
image_processor
.
tile_size
)
num_thumbnail_tokens
=
spatial_shapes
[
-
1
].
prod
()
//
(
downsample_factor
**
2
)
num_patches_tile
=
tile_size
//
encoder_patch_size
dwn_num_patches_tile
=
math
.
ceil
(
num_patches_tile
/
downsample_factor
)
num_tiles_tokens
=
dwn_num_patches_tile
*
dwn_num_patches_tile
return
num_thumbnail_tokens
,
num_tiles_tokens
...
...
@@ -372,6 +393,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
,
mm_kwargs
=
mm_kwargs
,
)
for
size
in
image_sizes
]
...
...
@@ -414,6 +436,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
image_height
=
image_size
.
height
,
spatial_shapes
=
spatial_shapes
,
processor
=
hf_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
PromptUpdateDetails
.
select_text
(
image_repl
,
...
...
vllm/model_executor/models/molmo.py
View file @
372b2e76
...
...
@@ -1224,11 +1224,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
MolmoProcessorWrapper
|
None
,
processor
:
MolmoProcessorWrapper
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
ncols
,
nrows
=
processor
.
get_patches_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
...
...
vllm/model_executor/models/molmo2.py
View file @
372b2e76
...
...
@@ -1869,12 +1869,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
*
,
image_height
:
int
,
image_width
:
int
,
processor
:
Molmo2ProcessorWrapper
|
None
=
None
,
processor
:
Molmo2ProcessorWrapper
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
hf_processor
=
processor
.
processor
# type: ignore
hf_processor
=
processor
.
processor
resize_nrows
,
resize_cols
=
processor
.
get_base_grid_size
(
is_video
=
False
)
# start/end tokens + image patch token + col tokens
...
...
@@ -1897,11 +1894,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
self
,
*
,
num_frames
:
int
,
processor
:
Molmo2ProcessorWrapper
|
None
=
None
,
processor
:
Molmo2ProcessorWrapper
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
resize_nrows
,
resize_cols
=
processor
.
get_base_grid_size
(
is_video
=
True
)
# start/end tokens
extra
=
2
+
resize_nrows
*
(
...
...
@@ -1929,7 +1923,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
width
=
wr
*
crop_window_size
+
total_margin_pixels
feat_size
=
self
.
get_num_image_tokens
(
image_height
=
height
,
image_width
=
width
,
processor
=
processor
image_height
=
height
,
image_width
=
width
,
processor
=
processor
,
)
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
...
...
@@ -1940,8 +1936,15 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
return
largest_feature_pinpoint
def
_get_max_video_frames
(
self
,
max_tokens
:
int
)
->
int
:
num_tokens_per_frame
=
self
.
get_num_video_tokens
(
num_frames
=
1
)
def
_get_max_video_frames
(
self
,
max_tokens
:
int
,
processor
:
Molmo2ProcessorWrapper
,
)
->
int
:
num_tokens_per_frame
=
self
.
get_num_video_tokens
(
num_frames
=
1
,
processor
=
processor
,
)
max_frames
=
max_tokens
//
num_tokens_per_frame
return
max
(
max_frames
,
1
)
...
...
@@ -1950,10 +1953,11 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
int
:
video_processor
=
self
.
get_hf_processor
().
processor
.
video_processor
processor
=
self
.
get_hf_processor
()
video_processor
=
processor
.
processor
.
video_processor
num_frames
=
video_processor
.
num_frames
max_videos
=
mm_counts
.
get
(
"video"
,
0
)
max_total_frames
=
self
.
_get_max_video_frames
(
seq_len
)
max_total_frames
=
self
.
_get_max_video_frames
(
seq_len
,
processor
)
max_frames_per_video
=
min
(
max_total_frames
//
max
(
max_videos
,
1
),
num_frames
,
...
...
vllm/model_executor/models/ovis2_5.py
View file @
372b2e76
...
...
@@ -215,7 +215,7 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_height
:
int
,
num_frames
:
int
=
1
,
)
->
tuple
[
ImageSize
,
int
]
:
)
->
int
:
hf_config
=
self
.
get_hf_config
()
vit_config
=
hf_config
.
vit_config
patch_size
=
vit_config
.
patch_size
...
...
@@ -245,7 +245,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
next_num_frames
,
image_processor
=
None
,
)
if
next_max_tokens
>
max_tokens
:
break
...
...
@@ -270,7 +269,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_height
:
int
,
num_frames
:
int
,
image_processor
:
BaseImageProcessor
|
None
,
)
->
int
:
num_video_tokens
=
self
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
num_frames
=
num_frames
...
...
@@ -287,7 +285,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
self
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
),
image_processor
=
None
,
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment