Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a70d0bd0
Unverified
Commit
a70d0bd0
authored
Aug 19, 2025
by
Benji Beck
Committed by
GitHub
Aug 19, 2025
Browse files
Migrate LlavaOnevisionMultiInputs to TensorSchema (#21844)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
24f4d1a2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
93 deletions
+56
-93
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+56
-93
No files found.
vllm/model_executor/models/llava_onevision.py
View file @
a70d0bd0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Final
,
Literal
,
Optional
,
Protocol
,
TypedDict
,
Union
from
typing
import
Annotated
,
Final
,
Literal
,
Optional
,
Protocol
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig,
...
@@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig,
LlavaOnevisionProcessor
)
LlavaOnevisionProcessor
)
from
transformers.models.llava_onevision.modeling_llava_onevision
import
(
from
transformers.models.llava_onevision.modeling_llava_onevision
import
(
get_anyres_image_grid_shape
,
unpad_image
)
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
...
@@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems
,
VideoProcessorItems
)
VideoEmbeddingItems
,
VideoProcessorItems
)
from
vllm.multimodal.processing
import
PromptReplacement
,
PromptUpdate
from
vllm.multimodal.processing
import
PromptReplacement
,
PromptUpdate
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.clip
import
CLIPVisionModel
from
.clip
import
CLIPVisionModel
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
...
@@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
_MAX_FRAMES_PER_VIDEO
=
16
_MAX_FRAMES_PER_VIDEO
=
16
class
LlavaOnevisionVideoPixelInputs
(
TypedDict
):
class
LlavaOnevisionVideoPixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
Dimensions:
- bn: Batch size * number of videos
- f: Number of frames
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_videos` may be different for each batch, and 'num_frames'
Note that `num_videos` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a
may be different for each video, in which case the data is passed as a
list instead of a batched tensor.
list instead of a batched tensor.
"""
"""
type
:
Literal
[
"pixel_values_videos"
]
=
"pixel_values_videos"
pixel_values_videos
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"f"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"f"
}),
]
class
LlavaOnevisionImagePixelInputs
(
TypedDict
):
class
LlavaOnevisionImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
Shape:
Dimensions:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
- bn: Batch size * number of images
- np: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image,
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
in which case the data is passed as a list instead of a batched tensor.
"""
"""
type
:
Literal
[
"pixel_values"
]
=
"pixel_values"
image_sizes
:
NotRequired
[
torch
.
Tensor
]
pixel_values
:
Annotated
[
"""
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
Shape: `(batch_size * num_images, 2)`
TensorShape
(
"bn"
,
"np"
,
3
,
"h"
,
"w"
),
]
This should be in `(height, width)` format.
"""
image_sizes
:
Annotated
[
Optional
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
2
)]
class
LlavaOnevisionImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class
LlavaOnevisionImageEmbeddingInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
"""
type
:
Literal
[
"image_embeds"
]
=
"image_embeds"
data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"ifs"
,
"hs"
),
]
LlavaOnevisionImageInputs
=
Union
[
LlavaOnevisionImagePixelInputs
,
LlavaOnevisionImageInputs
=
Union
[
LlavaOnevisionImagePixelInputs
,
...
@@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
model
.
make_empty_intermediate_tensors
)
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
expected_dims
=
(
2
,
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
f
"The expected shape of image sizes per image per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_validate_image_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
[
1
:])
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaOnevisionImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaOnevisionImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return
LlavaOnevisionImagePixelInputs
(
return
LlavaOnevisionImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_image_pixel_values
(
pixel_values
=
flatten_bn
(
pixel_values
),
flatten_bn
(
pixel_values
)),
image_sizes
=
flatten_bn
(
image_sizes
,
concat
=
True
),
image_sizes
=
self
.
_validate_image_sizes
(
resolve_bindings
=
{
flatten_bn
(
image_sizes
,
concat
=
True
)),
"h"
:
self
.
config
.
vision_config
.
image_size
,
)
"w"
:
self
.
config
.
vision_config
.
image_size
})
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
...
@@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_validate_video_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
[
2
:])
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_frames"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
"The expected shape of pixel values in each video frame "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_video_input
(
def
_parse_and_validate_video_input
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaOnevisionVideoPixelInputs
]:
**
kwargs
:
object
)
->
Optional
[
LlavaOnevisionVideoPixelInputs
]:
...
@@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return
LlavaOnevisionVideoPixelInputs
(
return
LlavaOnevisionVideoPixelInputs
(
type
=
"pixel_values_videos"
,
type
=
"pixel_values_videos"
,
pixel_values_videos
=
flatten_bn
(
pixel_values_videos
),
pixel_values_videos
=
flatten_bn
(
pixel_values_videos
),
)
resolve_bindings
=
{
"h"
:
self
.
config
.
vision_config
.
image_size
,
"w"
:
self
.
config
.
vision_config
.
image_size
})
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
mm_input_by_modality
=
{}
mm_input_by_modality
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment