Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59a85c36
Unverified
Commit
59a85c36
authored
Oct 05, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 05, 2025
Browse files
[Model] Use `merge_by_field_config` for MM models (H-L) (#26230)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
119f0063
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
29 additions
and
161 deletions
+29
-161
examples/offline_inference/vision_language_multi_image.py
examples/offline_inference/vision_language_multi_image.py
+1
-1
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+9
-23
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+7
-52
vllm/model_executor/models/keye_vl1_5.py
vllm/model_executor/models/keye_vl1_5.py
+5
-51
vllm/model_executor/models/kimi_vl.py
vllm/model_executor/models/kimi_vl.py
+1
-32
vllm/utils/tensor_schema.py
vllm/utils/tensor_schema.py
+6
-2
No files found.
examples/offline_inference/vision_language_multi_image.py
View file @
59a85c36
...
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
...
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model_name
,
model
=
model_name
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
max_model_len
=
8192
,
max_model_len
=
32768
,
max_num_seqs
=
5
,
max_num_seqs
=
5
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
)
)
...
...
vllm/model_executor/models/idefics3.py
View file @
59a85c36
...
@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
...
@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
# yapf: enable
# yapf: enable
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
from
.llama
import
LlamaModel
from
.llama
import
LlamaModel
from
.utils
import
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
class
Idefics3ImagePixelInputs
(
TensorSchema
):
class
Idefics3ImagePixelInputs
(
TensorSchema
):
...
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
...
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
"""
"""
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
3
,
"h"
,
"w"
)]
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
3
,
"h"
,
"w"
)]
pixel_attention_mask
:
torch
.
Tensor
pixel_attention_mask
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
"h"
,
"w"
)]
num_patches
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
)]
num_patches
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
)]
...
@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
...
@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
dummy_inputs
=
Idefics3DummyInputsBuilder
)
dummy_inputs
=
Idefics3DummyInputsBuilder
)
class
Idefics3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
class
Idefics3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsLoRA
):
SupportsLoRA
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
return
None
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
Idefics3ImageEmbeddingInputs
(
return
Idefics3ImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
data
=
image_embeds
,
)
)
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_attention_mask
=
kwargs
.
pop
(
"pixel_attention_mask"
)
pixel_attention_mask
=
kwargs
.
pop
(
"pixel_attention_mask"
)
if
not
isinstance
(
pixel_attention_mask
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel_attention_mask. "
f
"Got type:
{
type
(
pixel_attention_mask
)
}
"
)
num_patches
=
kwargs
.
pop
(
"num_patches"
)
num_patches
=
kwargs
.
pop
(
"num_patches"
)
if
not
isinstance
(
num_patches
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_patches. "
f
"Got type:
{
type
(
num_patches
)
}
"
)
expected_h
=
expected_w
=
self
.
config
.
vision_config
.
image_size
expected_h
=
expected_w
=
self
.
config
.
vision_config
.
image_size
return
Idefics3ImagePixelInputs
(
return
Idefics3ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
pixel_values
=
pixel_values
,
pixel_attention_mask
=
flatten_bn
(
pixel_attention_mask
,
pixel_attention_mask
=
pixel_attention_mask
,
concat
=
True
),
num_patches
=
num_patches
,
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
),
resolve_bindings
=
{
resolve_bindings
=
{
"h"
:
expected_h
,
"h"
:
expected_h
,
"w"
:
expected_w
"w"
:
expected_w
...
...
vllm/model_executor/models/keye.py
View file @
59a85c36
...
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
VideoItem
)
MultiModalKwargsItems
,
VideoItem
)
...
@@ -42,7 +42,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -42,7 +42,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
@@ -100,8 +99,7 @@ def smart_resize(
...
@@ -100,8 +99,7 @@ def smart_resize(
class
KeyeImagePixelInputs
(
TensorSchema
):
class
KeyeImagePixelInputs
(
TensorSchema
):
"""
"""
Dimensions:
Dimensions:
- b: Batch size
- bnp: Batch size * Number of patches
- np: Number of patches
- c: Number of channels
- c: Number of channels
- ps: Patch size
- ps: Patch size
- ni: Number of images
- ni: Number of images
...
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
...
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
pixel_values
:
Annotated
[
torch
.
Tensor
,
torch
.
Tensor
,
TensorShape
(
"b
"
,
"
np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"np"
})]
TensorShape
(
"bnp"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"
b
np"
})]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
...
@@ -134,8 +132,7 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
...
@@ -134,8 +132,7 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
class
KeyeVideoPixelInputs
(
TensorSchema
):
class
KeyeVideoPixelInputs
(
TensorSchema
):
"""
"""
Dimensions:
Dimensions:
- b: Batch size
- bnp: Batch size * Number of patches
- np: Number of patches
- c: Number of channels
- c: Number of channels
- ps: Patch size
- ps: Patch size
- ni: Number of images
- ni: Number of images
...
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
...
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
type
:
Literal
[
"pixel_values_videos"
]
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
Annotated
[
pixel_values_videos
:
Annotated
[
torch
.
Tensor
,
torch
.
Tensor
,
TensorShape
(
"b
"
,
"
np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"np"
})]
TensorShape
(
"bnp"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"
b
np"
})]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nv"
,
3
)]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nv"
,
3
)]
...
@@ -1258,6 +1255,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
...
@@ -1258,6 +1255,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
class
BaseKeyeModule
(
nn
.
Module
):
class
BaseKeyeModule
(
nn
.
Module
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -1524,28 +1523,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -1524,28 +1523,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
prefix
:
str
=
""
)
->
nn
.
Module
:
prefix
:
str
=
""
)
->
nn
.
Module
:
return
Projector
(
text_config
,
vision_config
,
quant_config
,
prefix
)
return
Projector
(
text_config
,
vision_config
,
quant_config
,
prefix
)
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
NestedTensors
,
name
:
str
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
==
5
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
mm_input
.
reshape
(
-
1
,
mm_input
.
shape
[
-
1
])
elif
is_list_of
(
mm_input
,
torch
.
Tensor
):
if
all
(
p
.
dim
()
==
4
for
p
in
mm_input
)
or
all
(
p
.
dim
()
==
2
for
p
in
mm_input
):
return
mm_input
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
KeyeImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
KeyeImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -1556,11 +1533,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -1556,11 +1533,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
return
None
return
None
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
return
KeyeImagePixelInputs
(
return
KeyeImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
...
@@ -1568,11 +1540,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -1568,11 +1540,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
)
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
image_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
image_embeds
,
"image embeds"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
return
KeyeImageEmbeddingInputs
(
return
KeyeImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
image_embeds
=
image_embeds
,
image_embeds
=
image_embeds
,
...
@@ -1589,13 +1556,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -1589,13 +1556,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
return
None
return
None
if
pixel_values_videos
is
not
None
:
if
pixel_values_videos
is
not
None
:
pixel_values_videos
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values_videos
,
"video pixel values"
,
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
"video grid_thw"
)
return
KeyeVideoPixelInputs
(
return
KeyeVideoPixelInputs
(
type
=
"pixel_values_videos"
,
type
=
"pixel_values_videos"
,
pixel_values_videos
=
pixel_values_videos
,
pixel_values_videos
=
pixel_values_videos
,
...
@@ -1603,11 +1563,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -1603,11 +1563,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
)
)
if
video_embeds
is
not
None
:
if
video_embeds
is
not
None
:
video_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
video_embeds
,
"video embeds"
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
"video grid_thw"
)
return
KeyeVideoEmbeddingInputs
(
return
KeyeVideoEmbeddingInputs
(
type
=
"video_embeds"
,
type
=
"video_embeds"
,
video_embeds
=
video_embeds
,
video_embeds
=
video_embeds
,
...
...
vllm/model_executor/models/keye_vl1_5.py
View file @
59a85c36
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
...
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalFieldConfig
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
VideoItem
)
MultiModalKwargsItems
,
VideoItem
)
...
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
...
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
class
KeyeVL1_5ImagePixelInputs
(
TensorSchema
):
class
KeyeVL1_5ImagePixelInputs
(
TensorSchema
):
"""
"""
Dimensions:
Dimensions:
- b: Batch size
- bnp: Batch size * Number of patches
- np: Number of patches
- c: Number of channels
- c: Number of channels
- ps: Patch size
- ps: Patch size
- ni: Number of images
- ni: Number of images
...
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):
...
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):
pixel_values
:
Annotated
[
pixel_values
:
Annotated
[
torch
.
Tensor
,
torch
.
Tensor
,
TensorShape
(
"np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"np"
})]
TensorShape
(
"
b
np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"
b
np"
})]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
...
@@ -137,8 +136,7 @@ KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
...
@@ -137,8 +136,7 @@ KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
class
KeyeVL1_5VideoPixelInputs
(
TensorSchema
):
class
KeyeVL1_5VideoPixelInputs
(
TensorSchema
):
"""
"""
Dimensions:
Dimensions:
- b: Batch size
- bnp: Batch size * Number of patches
- np: Number of patches
- c: Number of channels
- c: Number of channels
- ps: Patch size
- ps: Patch size
- ni: Number of images
- ni: Number of images
...
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
...
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
type
:
Literal
[
"pixel_values_videos"
]
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
Annotated
[
pixel_values_videos
:
Annotated
[
torch
.
Tensor
,
torch
.
Tensor
,
TensorShape
(
"np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"np"
})]
TensorShape
(
"
b
np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"
b
np"
})]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nv"
,
3
)]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nv"
,
3
)]
num_frames
:
torch
.
Tensor
num_frames
:
torch
.
Tensor
...
@@ -483,24 +481,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -483,24 +481,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
self
.
merge_size
=
config
.
vision_config
.
spatial_merge_size
self
.
merge_size
=
config
.
vision_config
.
spatial_merge_size
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
NestedTensors
,
expected_dim
:
int
,
name
:
str
):
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
expected_dim
:
return
mm_input
elif
mm_input
.
ndim
==
expected_dim
+
1
:
return
mm_input
.
reshape
(
-
1
,
*
mm_input
.
shape
[
2
:])
else
:
raise
ValueError
(
f
"
{
name
}
should be
{
expected_dim
}
D or "
f
"batched
{
expected_dim
}
D tensor."
f
"Got ndim:
{
mm_input
.
ndim
}
(shape=
{
mm_input
.
shape
}
)"
)
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
KeyeVL1_5ImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
KeyeVL1_5ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -511,11 +491,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -511,11 +491,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
return
None
return
None
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
expected_dim
=
4
,
name
=
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
expected_dim
=
2
,
name
=
"image grid_thw"
)
return
KeyeVL1_5ImagePixelInputs
(
return
KeyeVL1_5ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
...
@@ -523,11 +498,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -523,11 +498,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
)
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
image_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
image_embeds
,
expected_dim
=
2
,
name
=
"image embeds"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
expected_dim
=
2
,
name
=
"image grid_thw"
)
return
KeyeVL1_5ImageEmbeddingInputs
(
return
KeyeVL1_5ImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
image_embeds
=
image_embeds
,
image_embeds
=
image_embeds
,
...
@@ -545,17 +515,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -545,17 +515,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
return
None
return
None
if
pixel_values_videos
is
not
None
:
if
pixel_values_videos
is
not
None
:
pixel_values_videos
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values_videos
,
expected_dim
=
4
,
name
=
"video pixel values"
,
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
expected_dim
=
2
,
name
=
"video grid_thw"
)
num_frames
=
self
.
_validate_and_reshape_mm_tensor
(
num_frames
,
expected_dim
=
1
,
name
=
"video num frames"
)
return
KeyeVL1_5VideoPixelInputs
(
return
KeyeVL1_5VideoPixelInputs
(
type
=
"pixel_values_videos"
,
type
=
"pixel_values_videos"
,
pixel_values_videos
=
pixel_values_videos
,
pixel_values_videos
=
pixel_values_videos
,
...
@@ -563,11 +522,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
...
@@ -563,11 +522,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
num_frames
=
num_frames
)
num_frames
=
num_frames
)
if
video_embeds
is
not
None
:
if
video_embeds
is
not
None
:
video_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
video_embeds
,
expected_dim
=
2
,
name
=
"video embeds"
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
expected_dim
=
2
,
name
=
"video grid_thw"
)
return
KeyeVL1_5VideoEmbeddingInputs
(
type
=
"video_embeds"
,
return
KeyeVL1_5VideoEmbeddingInputs
(
type
=
"video_embeds"
,
video_embeds
=
video_embeds
,
video_embeds
=
video_embeds
,
video_grid_thw
=
video_grid_thw
,
video_grid_thw
=
video_grid_thw
,
...
...
vllm/model_executor/models/kimi_vl.py
View file @
59a85c36
...
@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
...
@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
dummy_inputs
=
KimiVLDummyInputsBuilder
)
dummy_inputs
=
KimiVLDummyInputsBuilder
)
class
KimiVLForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
class
KimiVLForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
SupportsPP
):
merge_by_field_config
=
True
supports_encoder_tp_data
=
True
supports_encoder_tp_data
=
True
...
@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
config
.
vocab_size
,
logit_scale
)
config
.
vocab_size
,
logit_scale
)
self
.
media_placeholder
:
int
=
self
.
config
.
media_placeholder_token_id
self
.
media_placeholder
:
int
=
self
.
config
.
media_placeholder_token_id
# ref: qwen2_vl.py
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
f
"Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
mm_input
.
reshape
(
-
1
,
mm_input
.
shape
[
-
1
])
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
KimiVLImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
KimiVLImageInputs
]:
# image input type must be pixel values now
# image input type must be pixel values now
...
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
if
pixel_values
is
None
:
if
pixel_values
is
None
:
return
None
return
None
image_grid_hws
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_hws
,
"image grid hws"
)
# pixel_values may have complex shapes
num_channels
=
3
patch_size
=
self
.
config
.
vision_config
.
patch_size
if
isinstance
(
pixel_values
,
list
):
pixel_values
=
torch
.
cat
([
x
.
reshape
(
-
1
,
num_channels
,
patch_size
,
patch_size
)
for
x
in
pixel_values
])
else
:
pixel_values
=
pixel_values
.
reshape
(
-
1
,
num_channels
,
patch_size
,
patch_size
)
pixel_values
=
pixel_values
.
to
(
self
.
vision_tower
.
dtype
)
return
KimiVLImagePixelInputs
(
return
KimiVLImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
...
...
vllm/utils/tensor_schema.py
View file @
59a85c36
...
@@ -164,7 +164,9 @@ class TensorSchema:
...
@@ -164,7 +164,9 @@ class TensorSchema:
if
len
(
actual_shape
)
!=
len
(
expected_shape
):
if
len
(
actual_shape
)
!=
len
(
expected_shape
):
raise
ValueError
(
f
"
{
field_name
}
has rank
{
len
(
actual_shape
)
}
"
raise
ValueError
(
f
"
{
field_name
}
has rank
{
len
(
actual_shape
)
}
"
f
"but expected
{
len
(
expected_shape
)
}
"
)
f
"but expected
{
len
(
expected_shape
)
}
. "
f
"Expected shape:
{
expected_shape
}
, "
f
"but got
{
actual_shape
}
"
)
for
i
,
dim
in
enumerate
(
expected_shape
):
for
i
,
dim
in
enumerate
(
expected_shape
):
if
dim
in
dynamic_dims
:
if
dim
in
dynamic_dims
:
...
@@ -172,7 +174,9 @@ class TensorSchema:
...
@@ -172,7 +174,9 @@ class TensorSchema:
elif
isinstance
(
dim
,
int
):
elif
isinstance
(
dim
,
int
):
if
actual_shape
[
i
]
!=
dim
:
if
actual_shape
[
i
]
!=
dim
:
raise
ValueError
(
f
"
{
field_name
}
dim[
{
i
}
] expected "
raise
ValueError
(
f
"
{
field_name
}
dim[
{
i
}
] expected "
f
"
{
dim
}
, got
{
actual_shape
[
i
]
}
"
)
f
"
{
dim
}
, got
{
actual_shape
[
i
]
}
. "
f
"Expected shape:
{
expected_shape
}
, "
f
"but got
{
actual_shape
}
"
)
elif
isinstance
(
dim
,
str
):
elif
isinstance
(
dim
,
str
):
if
dim
in
shape_env
:
if
dim
in
shape_env
:
if
actual_shape
[
i
]
!=
shape_env
[
dim
]:
if
actual_shape
[
i
]
!=
shape_env
[
dim
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment