Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
304dcdf5
Unverified
Commit
304dcdf5
authored
Jul 27, 2025
by
Benji Beck
Committed by
GitHub
Jul 27, 2025
Browse files
Migrate GLMVImagePixelInputs to TensorSchema (#21679)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
88e46c7c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
23 deletions
+20
-23
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+20
-23
No files found.
vllm/model_executor/models/glm4v.py
View file @
304dcdf5
...
...
@@ -6,7 +6,7 @@
"""Inference-only CogAgent model compatible with THUDM weights."""
from
argparse
import
Namespace
from
collections.abc
import
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
from
torch
import
nn
...
...
@@ -38,6 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.chatglm
import
ChatGLMBaseModel
,
ChatGLMModel
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
...
@@ -45,10 +46,16 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from
.utils
import
flatten_bn
,
merge_multimodal_embeddings
class
GLMVImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class
GLMVImagePixelInputs
(
TensorSchema
):
"""
Dimensions:
- b: Batch size
- c: Number of channels (3)
- h: Height of image
- w: Width of image
"""
type
:
Literal
[
"pixel_values"
]
=
"pixel_values"
data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"b"
,
3
,
"h"
,
"w"
)]
class
EVA2CLIPPatchEmbedding
(
nn
.
Module
):
...
...
@@ -562,19 +569,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
self
.
transformer
:
GLM4VModel
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
[
"image_size"
]
expected_dims
=
(
3
,
h
,
w
)
actual_dims
=
tuple
(
data
.
shape
[
1
:])
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"batch_size"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
f
"The expected shape of pixel values is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
data
.
shape
)
}
."
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
GLMVImagePixelInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
...
@@ -584,11 +578,14 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
return
GLMVImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
)),
)
expected_h
=
expected_w
=
self
.
config
.
vision_config
[
"image_size"
]
return
GLMVImagePixelInputs
(
type
=
"pixel_values"
,
data
=
flatten_bn
(
pixel_values
,
concat
=
True
),
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
})
return
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment