Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f8e9521
Unverified
Commit
3f8e9521
authored
Aug 02, 2025
by
Isotr0py
Committed by
GitHub
Aug 01, 2025
Browse files
[Bugfix] Fix glm4.1v video inference issue (#22067)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
326a1b00
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
6 deletions
+53
-6
tests/models/multimodal/processing/test_glm4_1v.py
tests/models/multimodal/processing/test_glm4_1v.py
+51
-0
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+2
-6
No files found.
tests/models/multimodal/processing/test_glm4_1v.py
0 → 100644
View file @
3f8e9521
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
vllm.assets.video
import
VideoAsset
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
...utils
import
build_model_context
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"THUDM/GLM-4.1V-9B-Thinking"
])
@
pytest
.
mark
.
parametrize
(
"expected_toks_per_frame"
,
[
299
])
@
pytest
.
mark
.
parametrize
(
"num_frames"
,
[
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"fps, expected_grid_t"
,
[(
1
,
5
),
(
2
,
10
)])
def
test_processor_override
(
model_id
:
str
,
expected_toks_per_frame
:
int
,
expected_grid_t
:
int
,
fps
:
int
,
num_frames
:
int
,
):
"""Ensure GLM4vMultiModalProcessor can handle video frames properly."""
ctx
=
build_model_context
(
model_id
,
mm_processor_kwargs
=
None
,
limit_mm_per_prompt
=
{
"video"
:
1
},
)
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
tokenizer
=
processor
.
info
.
get_tokenizer
()
hf_processor_mm_kwargs
=
{
"fps"
:
fps
}
# Build the image str / prompt based on the number of images we pass
video_assets
=
VideoAsset
(
name
=
"baby_reading"
,
num_frames
=
num_frames
)
prompt
=
"<|begin_of_video|><|video|><|end_of_video|>"
video
,
metadata
=
video_assets
.
np_ndarrays
,
video_assets
.
metadata
metadata
[
"fps"
]
=
fps
mm_data
=
{
"video"
:
[(
video
,
metadata
)]}
processed_inputs
=
processor
.
apply
(
prompt
,
mm_data
,
hf_processor_mm_kwargs
)
# Ensure we have the right number of placeholders per num_crops size
hf_processor
=
processor
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
video_token_id
=
tokenizer
.
convert_tokens_to_ids
(
hf_processor
.
video_token
)
video_tok_count
=
processed_inputs
[
"prompt_token_ids"
].
count
(
video_token_id
)
grid_t
,
_
,
_
=
processed_inputs
[
"mm_kwargs"
][
"video_grid_thw"
][
0
]
assert
grid_t
==
expected_grid_t
assert
video_tok_count
==
expected_toks_per_frame
*
grid_t
vllm/model_executor/models/glm4_1v.py
View file @
3f8e9521
...
...
@@ -937,7 +937,7 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
total_frames
:
int
)
->
list
[
int
]:
video_processor
=
self
.
get_video_processor
()
video_fps
=
metadata
.
get
(
"fps"
,
2.0
)
video_fps
=
metadata
.
get
(
"fps"
,
video_processor
.
fps
)
meta_frames
=
metadata
.
get
(
"total_num_frames"
,
total_frames
)
max_frame_idx
=
meta_frames
-
1
duration
=
metadata
.
get
(
"duration"
,
...
...
@@ -1120,11 +1120,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
video_placeholder
,
)
grid_t
=
len
(
video_outputs
[
"video_grid_thw"
])
_
,
grid_h
,
grid_w
=
video_outputs
[
"video_grid_thw"
][
0
]
grid_thw
=
torch
.
tensor
([[
grid_t
,
grid_h
,
grid_w
]])
video_grid_thw_lst
.
append
(
grid_thw
)
video_grid_thw_lst
.
append
(
video_outputs
[
"video_grid_thw"
])
pixel_values_videos_lst
.
append
(
video_outputs
[
"pixel_values_videos"
])
video_outputs
=
dict
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment