Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
372b2e76
Unverified
Commit
372b2e76
authored
Feb 13, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 12, 2026
Browse files
[Bugfix] Standardize getting number of image patches/tokens (#34358)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
6afa587d
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
63 additions
and
66 deletions
+63
-66
vllm/model_executor/models/paddleocr_vl.py
vllm/model_executor/models/paddleocr_vl.py
+12
-8
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+1
-4
vllm/model_executor/models/phi4mm.py
vllm/model_executor/models/phi4mm.py
+9
-9
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+1
-16
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+22
-11
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+11
-9
vllm/model_executor/models/skyworkr1v.py
vllm/model_executor/models/skyworkr1v.py
+1
-4
vllm/model_executor/models/smolvlm.py
vllm/model_executor/models/smolvlm.py
+1
-3
vllm/multimodal/processing/context.py
vllm/multimodal/processing/context.py
+5
-2
No files found.
vllm/model_executor/models/paddleocr_vl.py
View file @
372b2e76
...
...
@@ -23,7 +23,7 @@ import numpy as np
import
torch
import
torch.nn
as
nn
from
einops
import
rearrange
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BaseImageProcessor
,
BatchFeature
,
PretrainedConfig
from
transformers.activations
import
GELUActivation
from
transformers.modeling_outputs
import
(
BaseModelOutputWithPooling
,
...
...
@@ -147,21 +147,23 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
image_processor
,
image_processor
:
BaseImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
if
image_processor
is
None
:
image_processor
=
self
.
get_image_processor
()
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
patch_size
=
vision_config
.
patch_size
merge_size
=
vision_config
.
spatial_merge_size
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
size
=
mm_kwargs
.
get
(
"size"
,
image_processor
.
size
)
resized_height
,
resized_width
=
smart_resize
(
height
=
image_height
,
width
=
image_width
,
factor
=
patch_size
*
merge_size
,
min_pixels
=
image_processor
.
min_pixels
,
max_pixels
=
image_processor
.
max_pixels
,
min_pixels
=
size
[
"
min_pixels
"
]
,
max_pixels
=
size
[
"
max_pixels
"
]
,
)
preprocessed_size
=
ImageSize
(
width
=
resized_width
,
height
=
resized_height
)
...
...
@@ -176,12 +178,13 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
hf_config
=
self
.
get_hf_config
()
image_processor
=
self
.
get_image_processor
()
# See `smart_resize` for the calculation of the image size.
merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
patch_size
=
hf_config
.
vision_config
.
patch_size
factor
=
merge_size
*
patch_size
max_num_tokens
=
self
.
get_
image_processor
()
.
max_pixels
//
(
factor
**
2
)
max_num_tokens
=
image_processor
.
max_pixels
//
(
factor
**
2
)
# Find factors of max_num_tokens close to its square root
# to create a dummy image with a reasonable aspect ratio.
h_patches
=
int
(
math
.
sqrt
(
max_num_tokens
))
...
...
@@ -276,6 +279,7 @@ class PaddleOCRVLMultiModalProcessor(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
image_processor
=
image_processor
,
mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
[
image_token_id
]
*
num_image_tokens
...
...
vllm/model_executor/models/phi3v.py
View file @
372b2e76
...
...
@@ -351,11 +351,8 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
ProcessorMixin
|
None
=
None
,
processor
:
ProcessorMixin
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
processor
.
calc_num_image_tokens_from_image_size
(
# type: ignore
width
=
image_width
,
height
=
image_height
,
...
...
vllm/model_executor/models/phi4mm.py
View file @
372b2e76
...
...
@@ -558,10 +558,8 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def
get_dynamic_hd
(
self
,
processor
:
ProcessorMixin
|
None
=
None
,
processor
:
ProcessorMixin
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
image_processor
=
processor
.
image_processor
return
image_processor
.
dynamic_hd
...
...
@@ -715,7 +713,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
ProcessorMixin
|
None
=
None
,
processor
:
ProcessorMixin
,
)
->
int
:
hf_config
=
self
.
get_hf_config
()
vision_encoder_name
=
hf_config
.
img_processor
...
...
@@ -739,10 +737,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
return
image_num_tokens
def
get_image_size_with_most_features
(
self
,
processor
:
ProcessorMixin
|
None
=
None
,
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
hf_config
=
self
.
get_hf_config
()
vision_encoder_name
=
hf_config
.
img_processor
if
vision_encoder_name
is
None
:
...
...
@@ -874,9 +871,12 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt
,
mm_data
,
mm_kwargs
,
tok_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
num_img_tokens
=
[
self
.
info
.
get_num_image_tokens
(
image_width
=
img_size
[
0
],
image_height
=
img_size
[
1
]
image_width
=
img_size
[
0
],
image_height
=
img_size
[
1
],
processor
=
hf_processor
,
)
for
img_size
in
processed_outputs
[
"image_sizes"
]
]
...
...
vllm/model_executor/models/pixtral.py
View file @
372b2e76
...
...
@@ -217,28 +217,13 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
None
}
def
get_vision_config
(
self
,
processor
:
PixtralProcessorAdapter
|
None
=
None
,
):
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
PixtralVisionConfig
(
image_size
=
processor
.
image_size
,
patch_size
=
processor
.
patch_size
,
)
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
PixtralProcessorAdapter
|
None
=
None
,
processor
:
PixtralProcessorAdapter
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
ncols
,
nrows
=
processor
.
image_processor
.
_image_to_num_tokens
(
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
))
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
372b2e76
...
...
@@ -832,24 +832,25 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
num_frames
:
int
=
1
,
do_resize
:
bool
=
True
,
image_processor
:
Qwen2VLImageProcessor
|
None
,
image_processor
:
Qwen2VLImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
ImageSize
,
int
]:
if
image_processor
is
None
:
image_processor
=
self
.
get_image_processor
()
hf_config
=
self
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
patch_size
=
vision_config
.
patch_size
merge_size
=
vision_config
.
spatial_merge_size
temporal_patch_size
=
vision_config
.
temporal_patch_size
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
size
=
mm_kwargs
.
get
(
"size"
,
image_processor
.
size
)
if
do_resize
:
resized_height
,
resized_width
=
smart_resize
(
height
=
image_height
,
width
=
image_width
,
factor
=
patch_size
*
merge_size
,
min_pixels
=
image_processor
.
size
[
"shortest_edge"
],
max_pixels
=
image_processor
.
size
[
"longest_edge"
],
min_pixels
=
size
[
"shortest_edge"
],
max_pixels
=
size
[
"longest_edge"
],
)
preprocessed_size
=
ImageSize
(
width
=
resized_width
,
height
=
resized_height
)
else
:
...
...
@@ -873,13 +874,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
image_processor
:
Qwen2VLImageProcessor
|
None
,
image_processor
:
Qwen2VLImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_image_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
num_frames
=
1
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_image_tokens
...
...
@@ -889,13 +892,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_height
:
int
,
num_frames
:
int
,
image_processor
:
Qwen2VLImageProcessor
|
None
,
image_processor
:
Qwen2VLImageProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
int
:
_
,
num_video_tokens
=
self
.
_get_vision_info
(
image_width
=
image_width
,
image_height
=
image_height
,
num_frames
=
num_frames
,
image_processor
=
image_processor
,
mm_kwargs
=
mm_kwargs
,
)
return
num_video_tokens
...
...
@@ -941,15 +946,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return
ImageSize
(
width
=
unit
*
width_factor
,
height
=
unit
*
height_factor
)
def
get_max_image_tokens
(
self
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
def
_get_max_video_frames
(
self
,
max_tokens
:
int
,
start_num_frames
:
int
=
1
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
num_frames
=
start_num_frames
...
...
@@ -960,7 +968,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
next_num_frames
,
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
if
next_max_tokens
>
max_tokens
:
...
...
@@ -990,13 +999,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
int
:
image_processor
=
self
.
get_image_processor
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_video_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
self
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
),
image_processor
=
None
,
image_processor
=
image_processor
,
mm_kwargs
=
{},
)
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
372b2e76
...
...
@@ -642,13 +642,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_height
:
int
,
num_frames
:
int
=
2
,
do_resize
:
bool
=
True
,
image_processor
:
Qwen2VLImageProcessorFast
|
Qwen3VLVideoProcessor
|
None
,
image_processor
:
Qwen2VLImageProcessorFast
|
Qwen3VLVideoProcessor
,
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
ImageSize
,
int
]:
if
image_processor
is
None
and
num_frames
>
1
:
image_processor
=
self
.
get_video_processor
()
elif
image_processor
is
None
:
image_processor
=
self
.
get_image_processor
()
is_video
=
isinstance
(
image_processor
,
Qwen3VLVideoProcessor
)
hf_config
=
self
.
get_hf_config
()
...
...
@@ -657,6 +653,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
merge_size
=
vision_config
.
spatial_merge_size
temporal_patch_size
=
vision_config
.
temporal_patch_size
mm_kwargs
=
self
.
ctx
.
get_merged_mm_kwargs
(
mm_kwargs
)
size
=
mm_kwargs
.
get
(
"size"
,
image_processor
.
size
)
if
do_resize
:
if
is_video
:
smart_resize
=
video_smart_resize
...
...
@@ -667,12 +666,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
else
:
smart_resize
=
image_smart_resize
extra_kwargs
=
{}
resized_height
,
resized_width
=
smart_resize
(
height
=
image_height
,
width
=
image_width
,
factor
=
patch_size
*
merge_size
,
min_pixels
=
image_processor
.
size
[
"shortest_edge"
],
max_pixels
=
image_processor
.
size
[
"longest_edge"
],
min_pixels
=
size
[
"shortest_edge"
],
max_pixels
=
size
[
"longest_edge"
],
**
extra_kwargs
,
)
preprocessed_size
=
ImageSize
(
width
=
resized_width
,
height
=
resized_height
)
...
...
@@ -720,7 +720,8 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
2
,
image_processor
=
None
,
image_processor
=
video_processor
,
mm_kwargs
=
{},
)
return
num_video_soft_tokens
...
...
@@ -846,6 +847,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
image_height
=
target_video_height
,
num_frames
=
target_num_frames
,
image_processor
=
video_processor
,
mm_kwargs
=
{},
)
# NOTE: we need to do this check here since Qwen3-VL resizes video
# frames depending on how many frames there are.
...
...
vllm/model_executor/models/skyworkr1v.py
View file @
372b2e76
...
...
@@ -487,11 +487,8 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
SkyworkR1VProcessor
|
None
,
processor
:
SkyworkR1VProcessor
,
)
->
int
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
return
processor
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
...
...
vllm/model_executor/models/smolvlm.py
View file @
372b2e76
...
...
@@ -16,9 +16,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
SmolVLMProcessor
:
return
self
.
ctx
.
get_hf_processor
(
SmolVLMProcessor
,
**
kwargs
)
def
_get_image_token
(
self
,
processor
:
SmolVLMProcessor
|
None
)
->
tuple
[
str
,
str
]:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
def
_get_image_token
(
self
,
processor
:
SmolVLMProcessor
)
->
tuple
[
str
,
str
,
str
]:
image_token
=
processor
.
image_token
fake_image_token
=
processor
.
fake_image_token
global_image_token
=
processor
.
global_image_token
...
...
vllm/multimodal/processing/context.py
View file @
372b2e76
...
...
@@ -409,6 +409,10 @@ class InputProcessingContext:
return
json_map_leaves
(
_postprocess_one
,
output
)
def
get_merged_mm_kwargs
(
self
,
kwargs
:
Mapping
[
str
,
object
]):
mm_config
=
self
.
model_config
.
get_multimodal_config
()
return
mm_config
.
merge_mm_processor_kwargs
(
kwargs
)
def
call_hf_processor
(
self
,
hf_processor
:
ProcessorMixin
,
...
...
@@ -424,8 +428,7 @@ class InputProcessingContext:
"""
assert
callable
(
hf_processor
)
mm_config
=
self
.
model_config
.
get_multimodal_config
()
merged_kwargs
=
mm_config
.
merge_mm_processor_kwargs
(
kwargs
)
merged_kwargs
=
self
.
get_merged_mm_kwargs
(
kwargs
)
allowed_kwargs
=
get_allowed_kwarg_only_overrides
(
hf_processor
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment