Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8452946c
Unverified
Commit
8452946c
authored
Jul 02, 2025
by
Kwai-Keye
Committed by
GitHub
Jul 01, 2025
Browse files
[Model][VLM] Support Keye-VL-8B-Preview (#20126)
Signed-off-by:
Kwai-Keye
<
Keye@kuaishou.com
>
parent
2e7cbf2d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1801 additions
and
2 deletions
+1801
-2
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+32
-0
examples/offline_inference/vision_language_multi_image.py
examples/offline_inference/vision_language_multi_image.py
+38
-0
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+2
-2
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+1725
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/models/supported_models.md
View file @
8452946c
...
@@ -559,6 +559,7 @@ Specified using `--task generate`.
...
@@ -559,6 +559,7 @@ Specified using `--task generate`.
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎
\*
|
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎
\*
|
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
etc. | ✅︎ | | ✅︎ |
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
etc. | ✅︎ | | ✅︎ |
|
`InternVLChatModel`
| InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I
<sup>
E+
</sup>
+ (V
<sup>
E+
</sup>
) |
`OpenGVLab/InternVL3-9B`
,
`OpenGVLab/InternVideo2_5_Chat_8B`
,
`OpenGVLab/InternVL2_5-4B`
,
`OpenGVLab/Mono-InternVL-2B`
,
`OpenGVLab/InternVL2-4B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`InternVLChatModel`
| InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I
<sup>
E+
</sup>
+ (V
<sup>
E+
</sup>
) |
`OpenGVLab/InternVL3-9B`
,
`OpenGVLab/InternVideo2_5_Chat_8B`
,
`OpenGVLab/InternVL2_5-4B`
,
`OpenGVLab/Mono-InternVL-2B`
,
`OpenGVLab/InternVL2-4B`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`KeyeForConditionalGeneration`
| Keye-VL-8B-Preview | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Kwai-Keye/Keye-VL-8B-Preview`
| | | ✅︎ |
|
`KimiVLForConditionalGeneration`
| Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I
<sup>
+
</sup>
|
`moonshotai/Kimi-VL-A3B-Instruct`
,
`moonshotai/Kimi-VL-A3B-Thinking`
| | | ✅︎ |
|
`KimiVLForConditionalGeneration`
| Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I
<sup>
+
</sup>
|
`moonshotai/Kimi-VL-A3B-Instruct`
,
`moonshotai/Kimi-VL-A3B-Thinking`
| | | ✅︎ |
|
`Llama4ForConditionalGeneration`
| Llama 4 | T + I
<sup>
+
</sup>
|
`meta-llama/Llama-4-Scout-17B-16E-Instruct`
,
`meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`
,
`meta-llama/Llama-4-Maverick-17B-128E-Instruct`
, etc. | | ✅︎ | ✅︎ |
|
`Llama4ForConditionalGeneration`
| Llama 4 | T + I
<sup>
+
</sup>
|
`meta-llama/Llama-4-Scout-17B-16E-Instruct`
,
`meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`
,
`meta-llama/Llama-4-Maverick-17B-128E-Instruct`
, etc. | | ✅︎ | ✅︎ |
|
`LlavaForConditionalGeneration`
| LLaVA-1.5 | T + I
<sup>
E+
</sup>
|
`llava-hf/llava-1.5-7b-hf`
,
`TIGER-Lab/Mantis-8B-siglip-llama3`
(see note), etc. | | ✅︎ | ✅︎ |
|
`LlavaForConditionalGeneration`
| LLaVA-1.5 | T + I
<sup>
E+
</sup>
|
`llava-hf/llava-1.5-7b-hf`
,
`TIGER-Lab/Mantis-8B-siglip-llama3`
(see note), etc. | | ✅︎ | ✅︎ |
...
...
examples/offline_inference/vision_language.py
View file @
8452946c
...
@@ -429,6 +429,37 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
...
@@ -429,6 +429,37 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
)
)
# Keye-VL
def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
    """Build the engine arguments and prompts for Keye-VL-8B-Preview.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: ``"image"`` or ``"video"``; selects the placeholder token
            placed between the vision start/end markers of each prompt.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and the prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Validate up front: previously an unsupported modality left
    # ``placeholder`` unbound and failed later with UnboundLocalError.
    placeholder_by_modality = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    if modality not in placeholder_by_modality:
        raise ValueError(
            f"Unsupported modality for Keye-VL: {modality!r}. "
            "Expected 'image' or 'video'."
        )
    placeholder = placeholder_by_modality[modality]

    model_name = "Kwai-Keye/Keye-VL-8B-Preview"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        trust_remote_code=True,
        limit_mm_per_prompt={modality: 1},
    )

    prompts = [
        (
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
# Kimi-VL
# Kimi-VL
def
run_kimi_vl
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
def
run_kimi_vl
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
assert
modality
==
"image"
...
@@ -1154,6 +1185,7 @@ model_example_map = {
...
@@ -1154,6 +1185,7 @@ model_example_map = {
"h2ovl_chat"
:
run_h2ovl
,
"h2ovl_chat"
:
run_h2ovl
,
"idefics3"
:
run_idefics3
,
"idefics3"
:
run_idefics3
,
"internvl_chat"
:
run_internvl
,
"internvl_chat"
:
run_internvl
,
"keye_vl"
:
run_keye_vl
,
"kimi_vl"
:
run_kimi_vl
,
"kimi_vl"
:
run_kimi_vl
,
"llava"
:
run_llava
,
"llava"
:
run_llava
,
"llava-next"
:
run_llava_next
,
"llava-next"
:
run_llava_next
,
...
...
examples/offline_inference/vision_language_multi_image.py
View file @
8452946c
...
@@ -423,6 +423,43 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
...
@@ -423,6 +423,43 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
)
)
def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    """Prepare a multi-image request for Keye-VL-8B-Preview.

    The prompt is rendered with the model's own chat template: every URL
    becomes one image entry of a single user turn, followed by the question.

    Args:
        question: Text question placed after the image entries.
        image_urls: URLs of the images; their count also sets the per-prompt
            image limit passed to the engine.

    Returns:
        A ``ModelRequestData`` with engine arguments, the rendered prompt and
        the image data obtained from ``fetch_image`` for each URL.
    """
    model_name = "Kwai-Keye/Keye-VL-8B-Preview"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # One user turn: all images first, then the question text.
    user_content = [{"type": "image", "image": url} for url in image_urls]
    user_content.append({"type": "text", "text": question})
    conversation = [{"role": "user", "content": user_content}]

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    prompt = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )
def
load_kimi_vl
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
def
load_kimi_vl
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
model_name
=
"moonshotai/Kimi-VL-A3B-Instruct"
model_name
=
"moonshotai/Kimi-VL-A3B-Instruct"
...
@@ -862,6 +899,7 @@ model_example_map = {
...
@@ -862,6 +899,7 @@ model_example_map = {
"h2ovl_chat"
:
load_h2ovl
,
"h2ovl_chat"
:
load_h2ovl
,
"idefics3"
:
load_idefics3
,
"idefics3"
:
load_idefics3
,
"internvl_chat"
:
load_internvl
,
"internvl_chat"
:
load_internvl
,
"keye_vl"
:
load_keye_vl
,
"kimi_vl"
:
load_kimi_vl
,
"kimi_vl"
:
load_kimi_vl
,
"llava"
:
load_llava
,
"llava"
:
load_llava
,
"llava-next"
:
load_llava_next
,
"llava-next"
:
load_llava_next
,
...
...
tests/models/registry.py
View file @
8452946c
...
@@ -351,6 +351,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -351,6 +351,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Idefics3ForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceM4/Idefics3-8B-Llama3"
,
# noqa: E501
"Idefics3ForConditionalGeneration"
:
_HfExamplesInfo
(
"HuggingFaceM4/Idefics3-8B-Llama3"
,
# noqa: E501
{
"tiny"
:
"HuggingFaceTB/SmolVLM-256M-Instruct"
}),
# noqa: E501
{
"tiny"
:
"HuggingFaceTB/SmolVLM-256M-Instruct"
}),
# noqa: E501
"KeyeForConditionalGeneration"
:
_HfExamplesInfo
(
"Kwai-Keye/Keye-VL-8B-Preview"
,
# noqa: E501
trust_remote_code
=
True
),
"KimiVLForConditionalGeneration"
:
_HfExamplesInfo
(
"moonshotai/Kimi-VL-A3B-Instruct"
,
# noqa: E501
"KimiVLForConditionalGeneration"
:
_HfExamplesInfo
(
"moonshotai/Kimi-VL-A3B-Instruct"
,
# noqa: E501
extras
=
{
"thinking"
:
"moonshotai/Kimi-VL-A3B-Thinking"
},
# noqa: E501
extras
=
{
"thinking"
:
"moonshotai/Kimi-VL-A3B-Thinking"
},
# noqa: E501
trust_remote_code
=
True
,
trust_remote_code
=
True
,
...
...
vllm/entrypoints/chat_utils.py
View file @
8452946c
...
@@ -540,7 +540,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -540,7 +540,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
"<image>"
return
"<image>"
if
model_type
in
(
"mllama"
,
"llama4"
):
if
model_type
in
(
"mllama"
,
"llama4"
):
return
"<|image|>"
return
"<|image|>"
if
model_type
in
(
"qwen2_vl"
,
"qwen2_5_vl"
):
if
model_type
in
(
"qwen2_vl"
,
"qwen2_5_vl"
,
"keye"
,
"Keye"
):
return
"<|vision_start|><|image_pad|><|vision_end|>"
return
"<|vision_start|><|image_pad|><|vision_end|>"
if
model_type
==
"qwen2_5_omni"
:
if
model_type
==
"qwen2_5_omni"
:
return
"<|vision_start|><|IMAGE|><|vision_end|>"
return
"<|vision_start|><|IMAGE|><|vision_end|>"
...
@@ -570,7 +570,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -570,7 +570,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return
"<video>"
return
"<video>"
if
model_type
==
"glm4v"
:
if
model_type
==
"glm4v"
:
return
"<|begin_of_video|><|video|><|end_of_video|>"
return
"<|begin_of_video|><|video|><|end_of_video|>"
if
model_type
in
(
"qwen2_vl"
,
"qwen2_5_vl"
):
if
model_type
in
(
"qwen2_vl"
,
"qwen2_5_vl"
,
"keye"
,
"Keye"
):
return
"<|vision_start|><|video_pad|><|vision_end|>"
return
"<|vision_start|><|video_pad|><|vision_end|>"
if
model_type
==
"qwen2_5_omni"
:
if
model_type
==
"qwen2_5_omni"
:
return
"<|vision_start|><|VIDEO|><|vision_end|>"
return
"<|vision_start|><|VIDEO|><|vision_end|>"
...
...
vllm/model_executor/models/keye.py
0 → 100644
View file @
8452946c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
einops
import
rearrange
from
transformers
import
PretrainedConfig
from
transformers.activations
import
GELUActivation
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.modeling_outputs
import
(
BaseModelOutput
,
BaseModelOutputWithPooling
)
from
transformers.utils
import
torch_int
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
VideoItem
)
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ImageSize
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
(
cached_image_processor_from_config
)
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
from
.siglip
import
SiglipMLP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
is_pp_missing_parameter
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
get_vit_attn_backend
# Module-level logger; smart_resize uses it to warn when it clamps a side.
logger = init_logger(__name__)

# NOTE(review): neither constant is referenced in the part of the file visible
# here; presumably they bound the dummy/profiling inputs - confirm at the use
# sites.
_MAX_FRAMES_PER_VIDEO = 16
_MAX_IMAGE_SIZE = 9999999
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Compute a resize target whose sides are multiples of ``factor``.

    A side shorter than ``factor`` is first raised to ``factor`` (the other
    side is scaled by the same ratio and a warning is logged). Both sides are
    then rounded to the nearest multiple of ``factor``. If the resulting area
    falls outside ``[min_pixels, max_pixels]`` the sides are rescaled by a
    common factor and rounded down (too large) or up (too small) to a
    multiple of ``factor``.

    Args:
        height: Input height in pixels.
        width: Input width in pixels.
        factor: Both output sides are multiples of this value.
        min_pixels: Lower bound used for the output area.
        max_pixels: Upper bound used for the output area.

    Returns:
        ``(h_bar, w_bar)``, the resized height and width.

    Raises:
        ValueError: If the aspect ratio (after the minimum-side adjustment)
            exceeds 200.
    """
    if height < factor:
        logger.warning(
            "smart_resize: height=%s < factor=%s, reset height=factor",
            height,
            factor,
        )
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        logger.warning(
            "smart_resize: width=%s < factor=%s, reset width=factor",
            width,
            factor,
        )
        height = round((height * factor) / width)
        width = factor

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        # The second literal used to lack the ``f`` prefix, so the message
        # showed the raw "{max(height, width) / ...}" text, not the value.
        raise ValueError(
            "absolute aspect ratio must be smaller than 200, got "
            f"{aspect_ratio}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio, rounding down.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too small: grow both sides by the same ratio, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
class KeyeImagePixelInputs(TypedDict):
    """Image input given as raw pixel patches."""

    # Discriminator identifying this variant of the image-input union.
    type: Literal["pixel_values"]

    # Shape: `(num_patches, num_channels * patch_size * patch_size)`
    pixel_values: torch.Tensor

    # Shape: `(num_images, 3)`
    # This should be in `(grid_t, grid_h, grid_w)` format.
    image_grid_thw: torch.Tensor
class KeyeImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings."""

    # Discriminator identifying this variant of the image-input union.
    type: Literal["image_embeds"]

    # Supported types:
    # - list[`torch.Tensor`]: one tensor per image, each holding that
    #   image's features.
    # - `torch.Tensor`: a single tensor holding the features of all images
    #   (concatenation of the per-image feature tensors).
    #
    # Tensor shape: `(num_image_features, hidden_size)`
    # - `num_image_features` varies with the number and resolution of the
    #   images.
    # - `hidden_size` must match the hidden size of the language model
    #   backbone.
    image_embeds: torch.Tensor

    # Shape: `(num_images, 3)`
    # This should be in `(grid_t, grid_h, grid_w)` format.
    image_grid_thw: torch.Tensor
# Either form of image input: raw pixel patches or precomputed embeddings.
# The two variants are told apart by their "type" key.
KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
class KeyeVideoPixelInputs(TypedDict):
    """Video input given as raw pixel patches."""

    # Discriminator identifying this variant of the video-input union.
    type: Literal["pixel_values_videos"]

    # Shape:
    # `(num_patches,
    #   num_channels * temporal_patch_size * patch_size * patch_size)`
    pixel_values_videos: torch.Tensor

    # Shape: `(num_videos, 3)`
    # This should be in `(grid_t, grid_h, grid_w)` format.
    video_grid_thw: torch.Tensor
class KeyeVideoEmbeddingInputs(TypedDict):
    """Video input given as precomputed embeddings."""

    # Discriminator identifying this variant of the video-input union.
    type: Literal["video_embeds"]

    # Supported types:
    # - list[`torch.Tensor`]: one tensor per video, each holding that
    #   video's features.
    # - `torch.Tensor`: a single tensor holding the features of all videos
    #   (concatenation of the per-video feature tensors).
    #
    # Tensor shape: `(num_video_features, hidden_size)`
    # - the first dimension varies with the number and resolution of the
    #   videos.
    # - `hidden_size` must match the hidden size of the language model
    #   backbone.
    video_embeds: torch.Tensor

    # Shape: `(num_videos, 3)`
    # This should be in `(grid_t, grid_h, grid_w)` format.
    video_grid_thw: torch.Tensor
# Either form of video input: raw pixel patches or precomputed embeddings.
# The two variants are told apart by their "type" key.
KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]
class KeyeVisionEmbeddings(nn.Module):
    """Patch and position embeddings of the Keye vision tower.

    Patches are embedded with a strided convolution. Positions come either
    from the learned square grid ``position_embedding`` (bilinearly
    interpolated to each image's patch grid) or from the flat
    ``packing_position_embedding`` table looked up with ``position_ids``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Kernel size equals stride, so every patch is embedded independently.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches

        # State of fetch_position_embedding_lfu_cache: interpolated grids
        # and their hit counters, both keyed by (h, w).
        self.cache_position_embedding = {}
        self.cache_position_count = {}

        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)

        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        """Bilinearly resize the learned position grid to a target grid.

        Args:
            embeddings: Only its last dimension (the embedding width) is read.
            height: Target height; in patches when ``is_after_patchify`` is
                true, otherwise in pixels (divided by ``patch_size`` here).
            width: Target width, same unit convention as ``height``.
            is_after_patchify: Whether ``height``/``width`` already count
                patches.

        Returns:
            Tensor of shape ``(1, new_height * new_width, dim)``.
        """
        table = self.position_embedding.weight
        num_positions = table.shape[0]
        dim = embeddings.shape[-1]

        if is_after_patchify:
            target_size = (height, width)
        else:
            target_size = (height // self.patch_size,
                           width // self.patch_size)

        # The learned table is treated as a square grid of positions.
        side = torch_int(num_positions**0.5)
        grid = table.unsqueeze(0).reshape(1, side, side, dim)
        grid = grid.permute(0, 3, 1, 2)

        grid = nn.functional.interpolate(
            grid,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def fetch_position_embedding_lfu_cache(self,
                                           embeddings,
                                           h,
                                           w,
                                           max_cache: int = 20):
        """Return the interpolated grid for ``(h, w)``, memoised per grid.

        When the cache already holds ``max_cache`` entries, the entry with
        the smallest hit count is dropped before the new one is stored.
        """
        key = (h, w)
        if key in self.cache_position_embedding:
            self.cache_position_count[key] += 1
            return self.cache_position_embedding[key]

        if len(self.cache_position_embedding) >= max_cache:
            least_used = min(
                self.cache_position_count,
                key=self.cache_position_count.get,
            )
            self.cache_position_count.pop(least_used)
            self.cache_position_embedding.pop(least_used)

        interpolated = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[key] = 1
        self.cache_position_embedding[key] = interpolated
        return interpolated

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        """Embed patches and add position information.

        Args:
            pixel_values: 4-D input (a leading dimension is added) or 5-D
                input laid out as ``b l c h w``.
            position_ids: Required; indexes ``packing_position_embedding``
                when interpolation is not used.
            image_grid_thw: ``(t, h, w)`` per image; needed for the
                interpolated path.
            interpolate_pos_encoding: Use the interpolated learned grid
                instead of ``packing_position_embedding``.

        Returns:
            The patch embeddings with positions added. The interpolated path
            returns them with a leading dimension of size 1.

        Raises:
            ValueError: If ``pixel_values`` is not 4-D/5-D, or if
                ``position_ids`` is ``None``.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(0)

        if pixel_values.dim() != 5:
            raise ValueError("Unsupported pixel_values dimension:"
                             f" {pixel_values.dim()}. Expected 4 or 5.")
        if position_ids is None:
            raise ValueError(
                "position_ids cannot be None when pixel_values.dim() is 5.")

        conv_dtype = self.patch_embedding.weight.dtype
        flat_pixels = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(flat_pixels.to(dtype=conv_dtype))
        embeddings = patch_embeds.flatten(-2).squeeze(-1)

        if not interpolate_pos_encoding or image_grid_thw is None:
            return embeddings + self.packing_position_embedding(position_ids)

        # Walk the packed patch sequence image by image; each image owns
        # t * h * w consecutive rows and reuses its (h, w) grid for every t.
        pieces = []
        offset = 0
        for t, h, w in image_grid_thw:
            stop = offset + t * h * w
            chunk = embeddings[offset:stop, :]
            grid_pos = self.interpolate_pos_encoding(chunk, h, w, True)
            pieces.append(chunk + grid_pos.squeeze(0).repeat(t, 1))
            offset = stop
        return torch.concat(pieces, dim=0).unsqueeze(0)
def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k``.

    Only the first half of the last dimension of ``cos``/``sin`` is handed
    to the kernel. The rotation runs in float32 and each result is cast back
    to the dtype of its input.

    Returns:
        ``(q_embed, k_embed)``.
    """
    half_cos = cos.chunk(2, dim=-1)[0].contiguous()
    half_sin = sin.chunk(2, dim=-1)[0].contiguous()

    # Imported lazily, as in the rest of this module's attention code.
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    cos_f32 = half_cos.float()
    sin_f32 = half_sin.float()
    q_embed = apply_rotary_emb(q.float(), cos_f32, sin_f32).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos_f32, sin_f32).type_as(k)
    return q_embed, k_embed
class KeyeSiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Query/key/value come from one fused, tensor-parallel projection.
    Attention is evaluated over packed variable-length sequences whose
    boundaries are given by ``cu_seqlens``, using either the FlashAttention
    or the xFormers backend.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        self.hidden_size = config.hidden_size

        # Heads are split evenly across tensor-parallel ranks. The number of
        # KV heads equals the number of attention heads in this encoder.
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now.")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run self-attention over packed sequences.

        Args:
            hidden_states: Input activations; the leading dimension is
                treated as the batch dimension.
            attention_mask: Accepted for interface compatibility; not used.
            output_attentions: Accepted for interface compatibility; not used.
            cu_seqlens: Cumulative sequence boundaries. Required. It is
                sliced and subtracted element-wise, so it must behave like a
                1-D tensor despite the annotation.
            rope_emb: Optional ``(cos, sin)`` pair; when given, rotary
                embeddings are applied to the queries and keys.

        Returns:
            The attention output after the output projection.

        Raises:
            ValueError: If ``cu_seqlens`` is ``None``.
        """
        # cu_seqlens is needed on every path. The former check lived only in
        # the rope branch and after cu_seqlens had already been sliced, so a
        # missing value surfaced as a TypeError instead.
        if cu_seqlens is None:
            raise ValueError(
                "cu_seqlens cannot be None: it defines the boundaries of "
                "the packed sequences.")

        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split(
            [self.q_size, self.kv_size, self.kv_size],
            dim=-1,
        )
        batch_size = q.shape[0]

        # Split the projection width into (heads, head_dim).
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)

        if rope_emb is not None:
            cos, sin = rope_emb
            q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)

        if self.attn_backend == _Backend.FLASH_ATTN:
            from flash_attn import flash_attn_varlen_func

            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=False,
                softmax_scale=self.scale,
            )

            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            # NOTE(review): scale=None leaves the scaling to xFormers while
            # the FlashAttention branch passes self.scale explicitly;
            # confirm the two agree.
            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        context_layer = rearrange(context_layer,
                                  "b s h d -> b s (h d)").contiguous()

        output, _ = self.out_proj(context_layer)
        return output
class SigLIPRotaryEmbedding(nn.Module):
    """Rotary frequency table: outer product of positions and inverse
    frequencies."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()

    def rope_init(self):
        """Register ``inv_freq`` = ``1 / theta ** (i / dim)`` for even ``i``
        in ``[0, dim)`` as a non-persistent buffer."""
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim
        inv_freq = 1.0 / (self.theta**exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return ``outer(arange(seqlen), inv_freq)`` on the buffer's device
        and dtype."""
        positions = torch.arange(
            seqlen,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        return torch.outer(positions, self.inv_freq)
class KeyeSiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a
    residual connection."""

    def __init__(
        self,
        config: Union[PretrainedConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.self_attn = KeyeSiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.FloatTensor]:
        """Apply the block to ``hidden_states``.

        ``attention_mask``, ``output_attentions``, ``cu_seqlens`` and
        ``rope_emb`` are forwarded unchanged to the attention module.

        Returns:
            The updated hidden states. Note that this is a single tensor,
            not the tuple suggested by the annotation.
        """
        # Attention sub-block.
        attn_input = self.layer_norm1(hidden_states)
        attn_output = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )
        after_attn = hidden_states + attn_output

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.layer_norm2(after_attn))
        return after_attn + mlp_output
class KeyeSiglipEncoder(nn.Module):
    """Stack of ``KeyeSiglipEncoderLayer`` blocks with optional 2-D rotary
    position embeddings built from per-patch (row, column) ids."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList([
            KeyeSiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            ) for layer_idx in range(config.num_hidden_layers)
        ])
        # Rows and columns each receive a table of width head_dim // 2.
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)

    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one nesting level: list entries are spliced in, any other
        entry is kept as a single item."""
        flat = []
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                flat.extend(image_grid)
            else:
                flat.append(image_grid)
        return flat

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
        vision_or_text: str = "vision",
    ) -> BaseModelOutput:
        """Run all encoder layers.

        Args:
            inputs_embeds: Embedded patches fed to the first layer.
            attention_mask: Must be ``None`` (asserted).
            output_attentions: Forwarded to every layer.
            output_hidden_states: Accepted for interface compatibility; not
                used.
            cu_seqlens: Packed-sequence boundaries forwarded to every layer.
            image_grid_thw: ``(t, h, w)`` grids, possibly nested one level.
                Only read when ``use_rope`` is true and the position ids
                have to be derived.
            height_position_ids: Optional per-patch row ids.
            width_position_ids: Optional per-patch column ids.
            use_rope: Rotary embeddings are built only when this is exactly
                ``True``.
            window_size: Accepted for interface compatibility; not used.
            vision_or_text: Accepted for interface compatibility; not used.

        Returns:
            The hidden states after the last layer. Note that this is a
            plain tensor, not the ``BaseModelOutput`` of the annotation.
        """
        device = inputs_embeds.device
        hidden_states = inputs_embeds

        if use_rope is True:
            if width_position_ids is None or height_position_ids is None:
                # Derive (row, column) ids from the grids. The grids are
                # flattened only here, so callers that pass explicit ids may
                # leave image_grid_thw as None.
                split_hids = []
                split_wids = []
                for t, h, w in self.flatten_list(image_grid_thw):
                    # Position inside one frame; every frame of the same
                    # item repeats the same (h, w) layout.
                    image_pids = torch.arange(t * h * w,
                                              device=device) % (h * w)
                    split_hids.append(image_pids // w)
                    split_wids.append(image_pids % w)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)

            pids = torch.stack(
                [height_position_ids, width_position_ids],
                dim=-1,
            )
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            # Gather the row and column frequencies of every patch, join
            # them, then tile twice along the last dimension.
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None

        attn_cu_seqlens = cu_seqlens
        assert attention_mask is None

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
                cu_seqlens=attn_cu_seqlens,
                rope_emb=rope_emb,
            )
        return hidden_states
class KeyeSiglipVisionTransformer(nn.Module):
    """Vision tower: embeddings, encoder and a final layer norm. The packed
    output is split back into one tensor per sequence."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = KeyeVisionEmbeddings(config)
        self.encoder = KeyeSiglipEncoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask: Optional[torch.Tensor] = None,
        sample_indices: Optional[torch.Tensor] = None,
        image_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[list[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
    ) -> BaseModelOutputWithPooling:
        """Encode ``pixel_values`` and split the result per sequence.

        ``sample_indices``, ``image_indices``, ``padding_mask``,
        ``vision_return_embed_list`` and ``return_pooler_output`` are
        accepted for interface compatibility and not used here. The other
        arguments are forwarded to the embeddings and/or the encoder.

        Args:
            cu_seqlens: Cumulative sequence boundaries. Required. It is
                read through ``.shape`` and indexing, so it must behave
                like a 1-D tensor despite the annotation.

        Returns:
            A list with one tensor per sequence: the slice
            ``[cu_seqlens[i], cu_seqlens[i + 1])`` of the normalised encoder
            output with its leading dimension squeezed. Note that this is a
            list, not the ``BaseModelOutputWithPooling`` of the annotation.

        Raises:
            ValueError: If ``cu_seqlens`` is ``None``.
        """
        # Validate before any compute: the check used to sit after the
        # embeddings and the encoder had already run.
        if cu_seqlens is None:
            raise ValueError("cu_seqlens cannot be None for "
                             "SiglipVisionTransformer output processing.")

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Undo the packing: one slice per consecutive pair of boundaries.
        return [
            last_hidden_state[:, cu_seqlens[i]:cu_seqlens[i + 1], :].squeeze(0)
            for i in range(cu_seqlens.shape[0] - 1)
        ]
class KeyeSiglipVisionModel(nn.Module):
    """Wrapper module that owns a ``KeyeSiglipVisionTransformer``.

    The transformer is stored as ``self.vision_model``; this class forwards
    calls to it, exposes the dtype/device of its patch-embedding weight and
    implements checkpoint loading for the vision parameters.
    """

    config_class = PretrainedConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the wrapped vision transformer.

        Args:
            config: Vision configuration handed to the transformer.
            quant_config: Optional quantization config; also kept on the
                instance because ``load_weights`` consults it.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        self.vision_model = KeyeSiglipVisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
        )
        self.quant_config = quant_config

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight."""
        return self.vision_model.embeddings.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight."""
        return self.vision_model.embeddings.patch_embedding.weight.device

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the vision transformer."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        sample_indices: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[list[Union[
            tuple[int, int, int],
            list[tuple[int, int, int]],
        ]]] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: int = -1,
    ) -> list[torch.Tensor]:
        """Delegate to ``self.vision_model``, passing every argument by keyword.

        Returns:
            Whatever the wrapped transformer returns: a list holding one
            hidden-state tensor per ``cu_seqlens`` segment.
        """
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names recorded as loaded.
        """
        # (fused parameter, checkpoint weight name, shard id): separate
        # q/k/v projections are loaded as shards of one ``qkv_proj``.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Entries under ``head.*`` are deliberately not loaded.
            if "head.attention" in name or "head.layernorm" in name:
                continue
            if "head.mlp" in name or "head.probe" in name:
                continue

            if self.quant_config is not None and (
                    scale_name := self.quant_config.get_cache_scale(name)):
                param = params_dict[scale_name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                # Cache scales are loaded as scalars; take the first element
                # when the checkpoint stores them with extra dimensions.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0
                                 else loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for (
                    param_name,
                    weight_name,
                    shard_id,
            ) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not handled as a stacked shard: load as a regular parameter.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param,
                    "weight_loader",
                    default_weight_loader,
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class Projector(nn.Module):
    """Projects vision features into the text model's hidden size.

    Features are layer-normalised, grouped by a 2x2 merge kernel (so the
    projected width is ``vision_hidden_size * 4``), then passed through
    ``linear_1`` -> GELU -> ``linear_2``.
    """

    def __init__(
        self,
        text_config: PretrainedConfig,
        vision_config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the normalisation, activation and the two linear layers.

        Args:
            text_config: Config providing the output ``hidden_size``.
            vision_config: Config providing the vision ``hidden_size``.
            quant_config: Optional quantization config for the linear layers.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)

        self.hidden_size = (self.vision_config.hidden_size *
                            self.merge_kernel_size[0] *
                            self.merge_kernel_size[1])

        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size,
                                           eps=1e-05)
        self.act = GELUActivation()

        self.linear_1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.linear_2 = RowParallelLinear(
            self.hidden_size,
            self.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(
        self,
        image_features: Union[torch.Tensor, list[torch.Tensor],
                              tuple[torch.Tensor, ...]],
        image_grid_thw: list[tuple[int, int, int]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Project vision features.

        Args:
            image_features: Either a list/tuple with one feature tensor per
                item, or a single tensor.
            image_grid_thw: One ``(t, h, w)`` grid per item; only read on the
                list/tuple path.

        Returns:
            A list of projected tensors (list/tuple input) or one projected
            tensor (tensor input).
        """
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            processed_features = []
            for image_feature, image_grid in zip(image_features,
                                                 image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid

                # Gather every m1 x m2 patch neighbourhood into one row.
                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states, _ = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states, _ = self.linear_2(hidden_states)
                processed_features.append(hidden_states)

            return processed_features

        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.view(np.prod(dims), dim)
        hidden_states = self.pre_norm(image_features).view(
            -1, self.hidden_size)
        # The parallel linear layers return a pair whose first element is
        # the output (see the list path above), so unpack here as well
        # instead of handing the pair to the activation.
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)

        # NOTE(review): the row count shrank by m1 * m2 in the view above, so
        # restoring ``dims`` here changes the trailing width accordingly -
        # confirm this is the intended output layout.
        return hidden_states.view(*dims, -1)
def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ):
    """Describe how Keye multimodal tensors are split per item.

    Pixel values and embeddings are flat tensors partitioned by the product
    of each item's ``(t, h, w)`` grid row; the grid tensors themselves are
    batched one row per item. A missing grid is treated as an empty
    ``(0, 3)`` tensor.
    """
    patches_per_image = hf_inputs.get("image_grid_thw",
                                      torch.empty((0, 3))).prod(-1)
    patches_per_video = hf_inputs.get("video_grid_thw",
                                      torch.empty((0, 3))).prod(-1)

    flat = MultiModalFieldConfig.flat_from_sizes
    batched = MultiModalFieldConfig.batched

    return {
        "pixel_values": flat("image", patches_per_image),
        "image_embeds": flat("image", patches_per_image),
        "image_grid_thw": batched("image"),
        "pixel_values_videos": flat("video", patches_per_video),
        "video_embeds": flat("video", patches_per_video),
        "video_grid_thw": batched("video"),
    }
class KeyeMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts dicts of precomputed embeddings.

    A ``dict`` input is wrapped in ``DictEmbeddingItems``; every other input
    is handled by the base parser.
    """

    def _parse_image_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
    ) -> ModalityDataItems[Any, Any]:
        """Parse image data; dicts must carry embeddings and their grid."""
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields={
                "image_embeds",
                "image_grid_thw",
            },
            fields_factory=_keye_field_config,
        )

    def _parse_video_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
    ) -> ModalityDataItems[Any, Any]:
        """Parse video data; dicts must carry embeddings and their grid."""
        if not isinstance(data, dict):
            return super()._parse_video_data(data)

        return DictEmbeddingItems(
            data,
            modality="video",
            required_fields={
                "video_embeds",
                "video_grid_thw",
            },
            fields_factory=_keye_field_config,
        )
class KeyeProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Keye: processors, limits and token counts."""

    def get_hf_config(self):
        """Return the HF config from the processing context."""
        return self.ctx.get_hf_config(PretrainedConfig)

    def get_hf_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        size: Optional[dict[str, int]] = None,
        **kwargs: object,
    ):
        """Return the HF processor using an image processor built from the
        given pixel/size overrides."""
        return self.ctx.get_hf_processor(
            image_processor=self.get_image_processor(
                min_pixels=min_pixels,
                max_pixels=max_pixels,
                size=size,
            ),
            **kwargs,
        )

    def _get_image_processor_kwargs(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        size: Optional[dict[str, int]] = None,
        **kwargs: object,
    ):
        """Assemble keyword arguments for the image processor.

        Model-level ``mm_processor_kwargs`` are merged into ``kwargs`` first;
        ``min_pixels``/``max_pixels`` are then recorded both under their own
        key and as ``size["shortest_edge"]``/``size["longest_edge"]``.

        Returns:
            The merged keyword-argument dict.
        """
        if self.ctx.model_config.mm_processor_kwargs:
            kwargs.update(self.ctx.model_config.mm_processor_kwargs)

        # Work on a copy so the caller's ``size`` mapping is never mutated.
        if size is not None:
            size = dict(size)

        if min_pixels is not None:
            kwargs["min_pixels"] = min_pixels

            if size is None:
                size = {"shortest_edge": min_pixels}
            else:
                size["shortest_edge"] = min_pixels

        if max_pixels is not None:
            kwargs["max_pixels"] = max_pixels

            if size is None:
                size = {"longest_edge": max_pixels}
            else:
                size["longest_edge"] = max_pixels

        if size is not None:
            kwargs["size"] = size

        return kwargs

    def get_image_processor(
        self,
        *,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        size: Optional[dict[str, int]] = None,
        **kwargs: object,
    ):
        """Return the cached image processor for the merged kwargs."""
        return cached_image_processor_from_config(
            self.ctx.model_config,
            **self._get_image_processor_kwargs(
                min_pixels=min_pixels,
                max_pixels=max_pixels,
                size=size,
                **kwargs,
            ),
        )

    def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
        """Images and videos are supported; ``None`` is reported as the
        limit for both."""
        return {"image": None, "video": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count of one image and of one video."""
        return {
            "image": self.get_max_image_tokens(),
            "video": self.get_max_video_tokens(seq_len),
        }

    def _get_vision_info(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
        do_resize: bool = True,
        image_processor,
    ) -> tuple[ImageSize, int]:
        """Compute the preprocessed size and vision-token count of an item.

        Args:
            image_width: Input width in pixels.
            image_height: Input height in pixels.
            num_frames: Number of frames (1 for a still image).
            do_resize: Whether to apply ``smart_resize`` first.
            image_processor: Image processor supplying ``min_pixels`` and
                ``max_pixels``; ``None`` selects the default one.

        Returns:
            ``(preprocessed_size, num_vision_tokens)``.
        """
        if image_processor is None:
            image_processor = self.get_image_processor()

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        temporal_patch_size = 1

        if do_resize:
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.min_pixels,
                max_pixels=image_processor.max_pixels,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
        else:
            preprocessed_size = ImageSize(width=image_width,
                                          height=image_height)

        # With temporal_patch_size fixed at 1 the modulo is always 0, so no
        # frames are actually padded.
        padded_num_frames = num_frames + num_frames % temporal_patch_size

        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = preprocessed_size.height // patch_size
        grid_w = preprocessed_size.width // patch_size

        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches // (merge_size**2)

        return preprocessed_size, num_vision_tokens

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor,
    ) -> int:
        """Return the number of vision tokens for a single image."""
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            image_processor=image_processor,
        )
        return num_image_tokens

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor,
    ) -> int:
        """Return the number of vision tokens for a ``num_frames`` video."""
        _, num_video_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=num_frames,
            image_processor=image_processor,
        )
        return num_video_tokens

    def get_image_size_with_most_features(self, ) -> ImageSize:
        """Return the preprocessed size of a square ``_MAX_IMAGE_SIZE``
        image."""
        max_image_size, _ = self._get_vision_info(
            image_width=_MAX_IMAGE_SIZE,
            image_height=_MAX_IMAGE_SIZE,
            image_processor=None,
        )
        return max_image_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest supported image."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Return the largest frame count whose token count at the maximum
        image size does not exceed ``max_tokens`` (0 if even one frame is
        too many)."""
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0

        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )

            if next_max_tokens > max_tokens:
                break

            num_frames = next_num_frames

        return num_frames

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        """Return the per-video frame budget for a ``seq_len`` prompt.

        Tokens for the maximum number of images are reserved first; the
        remaining frames are shared between the allowed videos, capped at
        ``_MAX_FRAMES_PER_VIDEO`` and never below 1.
        """
        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.get_limit_per_prompt("image")
        max_videos = mm_config.get_limit_per_prompt("video")

        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1),
            _MAX_FRAMES_PER_VIDEO,
        )

        return max(max_frames_per_video, 1)

    def get_max_video_tokens(self, seq_len: int) -> int:
        """Return the token count of the largest video for ``seq_len``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(seq_len),
            image_processor=None,
        )
class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
    """Builds placeholder prompts and media for the requested item counts."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated per image followed by the video
        token repeated per video."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        processor = self.info.get_hf_processor()
        image_token: str = processor.image_token
        video_token: str = processor.video_token

        image_part = image_token * image_count
        video_part = video_token * video_count
        return image_part + video_part

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images and videos at the size (and frame count)
        reported by the processing info as yielding the most features."""
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        width, height = self.info.get_image_size_with_most_features()
        frame_count = self.info.get_num_frames_with_most_features(seq_len)

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
        )
        dummy_videos = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=frame_count,
            num_videos=video_count,
        )
        return {"image": dummy_images, "video": dummy_videos}
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
    """Multimodal processor for Keye: parsing, HF processing and placeholder
    expansion."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Use the Keye-specific data parser."""
        return KeyeMultiModalDataParser()

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on the prompt together with ``mm_data``."""
        info = self.info
        processor = info.get_hf_processor(**mm_kwargs)
        processor_inputs = dict(text=prompt, **mm_data)
        image_processor_kwargs = info._get_image_processor_kwargs(**mm_kwargs)
        return info.ctx.call_hf_processor(
            processor,
            processor_inputs,
            image_processor_kwargs,
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Expand each image/video placeholder token to one token per merged
        patch group of the corresponding item."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(
            **hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder = {
            "image": vocab[hf_processor.image_token],
            "video": vocab[hf_processor.video_token],
        }

        # Number of patches folded into a single placeholder token.
        merge_length = image_processor.merge_size**2

        def get_replacement_keye(item_idx: int, modality: str):
            item_grid = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
            assert isinstance(item_grid, torch.Tensor)

            token_count = int(item_grid.prod()) // merge_length
            return [placeholder[modality]] * token_count

        updates = []
        for modality in ("image", "video"):
            updates.append(
                PromptReplacement(
                    modality=modality,
                    target=[placeholder[modality]],
                    replacement=partial(get_replacement_keye,
                                        modality=modality),
                ))
        return updates

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the shared Keye field configuration for ``hf_inputs``."""
        return _keye_field_config(hf_inputs)
@MULTIMODAL_REGISTRY.register_processor(
    KeyeMultiModalProcessor,
    info=KeyeProcessingInfo,
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
                                   SupportsPP):
    """Keye vision-language model.

    Combines a vision encoder (``visual``), a projector (``mlp_AR``) and a
    language model (``language_model``, instantiated with the
    ``Qwen3ForCausalLM`` architecture).
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Checkpoint prefixes are relocated under ``language_model``.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision encoder, projector and language model."""
        super().__init__()
        config: PretrainedConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = KeyeSiglipVisionModel(
            config.vision_config,
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )
        self.mlp_AR = Projector(
            config,
            config.vision_config,
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "mlp_AR"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen3ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _maybe_ignore_quant_config(
            self, quant_config: Optional[QuantizationConfig]):
        """Return ``None`` for GPTQ / GPTQ-Marlin configs, otherwise the
        config unchanged; used for the vision encoder and projector."""
        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
            return None
        return quant_config

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Validate a multimodal input and flatten its batch dimension.

        2-D and 5-D tensors are returned unchanged, 3-D tensors are
        concatenated along the first dimension and lists are concatenated.

        Raises:
            ValueError: If the input is neither a tensor nor a list, or is a
                tensor with an unsupported number of dimensions.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim == 5:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[KeyeImageInputs]:
        """Extract image inputs from ``kwargs``.

        Returns ``None`` when neither pixel values nor embeddings are given;
        pixel values take precedence when both are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return KeyeImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds")
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw")

            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return KeyeImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[KeyeVideoInputs]:
        """Extract video inputs from ``kwargs``.

        Returns ``None`` when neither pixel values nor embeddings are given;
        pixel values take precedence when both are present.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos,
                "video pixel values",
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            return KeyeVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds")
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw")

            if not isinstance(video_embeds, torch.Tensor):
                raise ValueError("Incorrect type of video embeddings. "
                                 f"Got type: {type(video_embeds)}")
            return KeyeVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _encode_vision_input(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        *,
        vision_return_embed_list: bool,
    ) -> tuple[torch.Tensor, ...]:
        """Run pixel values through the vision encoder and the projector.

        Args:
            pixel_values: Pixel values of all items.
            grid_thw: 2-D tensor with one ``(t, h, w)`` row per item.
            vision_return_embed_list: Forwarded unchanged to the encoder.

        Returns:
            The projector output as a tuple.
        """
        assert grid_thw.ndim == 2

        siglip_position_ids = []
        grid_thw_tuples = []
        sample_indices = []
        cu_seqlens = [0]

        for idx, thw in enumerate(grid_thw):
            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
            numel = np.prod(thw_tuple)
            grid_thw_tuples.append(thw_tuple)
            # Position ids restart at 0 for every frame (modulo h * w).
            siglip_position_ids.append(
                torch.arange(numel) % np.prod(thw_tuple[1:]))
            sample_indices.append(
                torch.full((numel, ), idx, dtype=torch.int64))
            cu_seqlens.append(cu_seqlens[-1] + numel)

        pixel_values = pixel_values.type(self.visual.dtype)
        device = pixel_values.device
        siglip_position_ids = torch.concat(siglip_position_ids,
                                           dim=0).to(device)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(device)
        sample_indices = torch.concat(sample_indices, dim=0).to(device)

        embeds = self.visual(
            pixel_values=pixel_values,
            image_grid_thw=grid_thw_tuples,
            position_ids=siglip_position_ids,
            vision_return_embed_list=vision_return_embed_list,
            interpolate_pos_encoding=True,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            use_rope=True,
            window_size=-1,
        )
        return tuple(self.mlp_AR(embeds, grid_thw))

    def _process_image_input(
            self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]:
        """Encode image pixel inputs into projected embeddings.

        Raises:
            ValueError: If ``image_input`` holds precomputed embeddings.
        """
        if image_input["type"] == "image_embeds":
            raise ValueError(
                "Image embeddings are not supported for this processing path."
            )
        return self._encode_vision_input(
            image_input["pixel_values"],
            image_input["image_grid_thw"],
            vision_return_embed_list=False,
        )

    def _process_video_input(
            self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
        """Encode video pixel inputs into projected embeddings.

        Raises:
            ValueError: If ``video_input`` holds precomputed embeddings.
        """
        if video_input["type"] == "video_embeds":
            raise ValueError(
                "Video embeddings are not supported for this processing path."
            )
        return self._encode_vision_input(
            video_input["pixel_values_videos"],
            video_input["video_grid_thw"],
            vision_return_embed_list=True,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed ``"images"``/``"videos"`` in the
        order their keys first appear in ``kwargs``."""
        modalities = {}

        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "images" not in modalities):
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if (input_key in ("pixel_values_videos", "video_embeds")
                    and "videos" not in modalities):
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return the embeddings of all multimodal items in ``kwargs``, or
        ``None`` when no multimodal input key is present."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # Keep the order in which the modalities appeared in kwargs.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge multimodal embeddings at the image
        and video placeholder token positions."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                [
                    self.config.image_token_id,
                    self.config.video_token_id,
                ],
            )
        return inputs_embeds

    def get_input_embeddings_v0(
        self,
        input_ids: torch.Tensor,
        image_input: Optional[KeyeImagePixelInputs] = None,
        video_input: Optional[KeyeVideoPixelInputs] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge image and video embeddings computed
        from the already-parsed inputs."""
        inputs_embeds = self.get_input_embeddings(input_ids)
        if image_input is not None:
            image_embeds = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                image_embeds,
                placeholder_token_id=self.config.image_token_id,
            )

        if video_input is not None:
            video_embeds = self._process_video_input(video_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                video_embeds,
                placeholder_token_id=self.config.video_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the forward pass for Keye-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding
                to a batch. When the config uses mrope and multimodal inputs
                are processed here, the shape must be `(3, seq_len)`.
            intermediate_tensors: Tensors from a previous pipeline stage;
                when given, ``inputs_embeds`` is ignored.
            inputs_embeds: Precomputed input embeddings. When ``None``,
                image/video inputs are read from ``kwargs``
                (``pixel_values``, ``image_grid_thw``,
                ``pixel_values_videos``, ``video_grid_thw``, ...).

        Returns:
            The output of the language model's inner ``model``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        elif inputs_embeds is None:
            image_input = self._parse_and_validate_image_input(**kwargs)
            video_input = self._parse_and_validate_video_input(**kwargs)
            if image_input is None and video_input is None:
                inputs_embeds = None
            else:
                if uses_mrope(self.config):
                    assert positions.ndim == 2 and positions.size(0) == 3, (
                        "multimodal section rotary embedding requires "
                        f"(3, seq_len) positions, but got {positions.size()}")
                inputs_embeds = self.get_input_embeddings_v0(
                    input_ids,
                    image_input=image_input,
                    video_input=video_input,
                )
                # The embeddings now carry the inputs; drop the token ids.
                input_ids = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming prefixes via
        ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        # ``visual`` is the vision encoder (tower) and ``mlp_AR`` is the
        # projector that connects it to the language model.
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp_AR.",
            tower_model="visual.",
        )
vllm/model_executor/models/registry.py
View file @
8452946c
...
@@ -197,6 +197,7 @@ _MULTIMODAL_MODELS = {
...
@@ -197,6 +197,7 @@ _MULTIMODAL_MODELS = {
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"Idefics3ForConditionalGeneration"
:(
"idefics3"
,
"Idefics3ForConditionalGeneration"
),
"Idefics3ForConditionalGeneration"
:(
"idefics3"
,
"Idefics3ForConditionalGeneration"
),
"SmolVLMForConditionalGeneration"
:
(
"smolvlm"
,
"SmolVLMForConditionalGeneration"
),
# noqa: E501
"SmolVLMForConditionalGeneration"
:
(
"smolvlm"
,
"SmolVLMForConditionalGeneration"
),
# noqa: E501
"KeyeForConditionalGeneration"
:
(
"keye"
,
"KeyeForConditionalGeneration"
),
"KimiVLForConditionalGeneration"
:
(
"kimi_vl"
,
"KimiVLForConditionalGeneration"
),
# noqa: E501
"KimiVLForConditionalGeneration"
:
(
"kimi_vl"
,
"KimiVLForConditionalGeneration"
),
# noqa: E501
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment