Unverified Commit 7c8271cd authored by Kwai-Keye's avatar Kwai-Keye Committed by GitHub
Browse files

[Model]: support KeyeVL-1_5-8B (#23838)


Signed-off-by: default avatarwangruitao <wangruitao@kuaishou.com>
Co-authored-by: default avatarwangruitao <wangruitao@kuaishou.com>
parent 3e330fcb
...@@ -634,7 +634,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -634,7 +634,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
......
...@@ -683,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -683,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
) )
# Keye-VL-1.5
def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1.5-8B"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
trust_remote_code=True,
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"
prompts = [
(
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Kimi-VL # Kimi-VL
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
...@@ -1648,6 +1679,7 @@ model_example_map = { ...@@ -1648,6 +1679,7 @@ model_example_map = {
"interns1": run_interns1, "interns1": run_interns1,
"internvl_chat": run_internvl, "internvl_chat": run_internvl,
"keye_vl": run_keye_vl, "keye_vl": run_keye_vl,
"keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl, "kimi_vl": run_kimi_vl,
"llama4": run_llama4, "llama4": run_llama4,
"llava": run_llava, "llava": run_llava,
......
...@@ -542,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -542,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct" model_name = "moonshotai/Kimi-VL-A3B-Instruct"
...@@ -1209,6 +1246,7 @@ model_example_map = { ...@@ -1209,6 +1246,7 @@ model_example_map = {
"interns1": load_interns1, "interns1": load_interns1,
"internvl_chat": load_internvl, "internvl_chat": load_internvl,
"keye_vl": load_keye_vl, "keye_vl": load_keye_vl,
"keye_vl1_5": load_keye_vl1_5,
"kimi_vl": load_kimi_vl, "kimi_vl": load_kimi_vl,
"llama4": load_llama4, "llama4": load_llama4,
"llava": load_llava, "llava": load_llava,
......
...@@ -293,6 +293,7 @@ def _test_processing_correctness_one( ...@@ -293,6 +293,7 @@ def _test_processing_correctness_one(
"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview",
"OpenGVLab/InternVL3_5-30B-A3B", "OpenGVLab/InternVL3_5-30B-A3B",
"Kwai-Keye/Keye-VL-8B-Preview", "Kwai-Keye/Keye-VL-8B-Preview",
"Kwai-Keye/Keye-VL-1_5-8B",
"moonshotai/Kimi-VL-A3B-Instruct", "moonshotai/Kimi-VL-A3B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf", "llava-hf/llava-1.5-7b-hf",
......
...@@ -438,6 +438,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -438,6 +438,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501 "InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501 "KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
trust_remote_code=True), trust_remote_code=True),
"KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
......
...@@ -402,6 +402,15 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -402,6 +402,15 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len=context_len, context_len=context_len,
seq_len=seq_len, seq_len=seq_len,
) )
elif "KeyeVL1_5" in hf_config.model_type:
return cls._keye_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else: else:
return cls._vl_get_input_positions_tensor( return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens, input_tokens=input_tokens,
...@@ -636,6 +645,126 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -636,6 +645,126 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item() len(input_tokens)).item()
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
@classmethod
def _keye_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""
def split_thw(
grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w)).long().flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
@classmethod @classmethod
def _vl_get_input_positions_tensor( def _vl_get_input_positions_tensor(
cls, cls,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, Optional, Union from typing import Annotated, Any, Literal, Optional, TypeVar, Union
import numpy as np import numpy as np
import torch import torch
...@@ -57,16 +58,13 @@ from .vision import get_vit_attn_backend ...@@ -57,16 +58,13 @@ from .vision import get_vit_attn_backend
logger = init_logger(__name__) logger = init_logger(__name__)
_MAX_FRAMES_PER_VIDEO = 16
_MAX_IMAGE_SIZE = 9999999
def smart_resize( def smart_resize(
height: int, height: int,
width: int, width: int,
factor: int = 28, factor: int,
min_pixels: int = 28 * 28 * 130, min_pixels: int,
max_pixels: int = 28 * 28 * 1280, max_pixels: int,
): ):
if height < factor: if height < factor:
logger.warning( logger.warning(
...@@ -887,9 +885,9 @@ class Projector(nn.Module): ...@@ -887,9 +885,9 @@ class Projector(nn.Module):
def forward( def forward(
self, self,
image_features: torch.Tensor, image_features: Union[torch.Tensor, list[torch.Tensor]],
image_grid_thw: list[tuple[int, int, int]], image_grid_thw: list[tuple[int, int, int]],
) -> torch.Tensor: ) -> Union[torch.Tensor, list[torch.Tensor]]:
m1, m2 = self.merge_kernel_size m1, m2 = self.merge_kernel_size
if isinstance(image_features, (list, tuple)): if isinstance(image_features, (list, tuple)):
processed_features = list() processed_features = list()
...@@ -986,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser): ...@@ -986,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
class KeyeProcessingInfo(BaseProcessingInfo): class KeyeProcessingInfo(BaseProcessingInfo):
def get_max_image_size(self) -> int:
return 9999999 #_MAX_IMAGE_SIZE
def get_max_frame_per_video(self) -> int:
return 16 #_MAX_FRAMES_PER_VIDEO
def get_image_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
...@@ -1077,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1077,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self, ) -> ImageSize: def get_image_size_with_most_features(self, ) -> ImageSize:
max_image_size, _ = self._get_vision_info( max_image_size, _ = self._get_vision_info(
image_width=_MAX_IMAGE_SIZE, image_width=self.get_max_image_size(),
image_height=_MAX_IMAGE_SIZE, image_height=self.get_max_image_size(),
image_processor=None, image_processor=None,
) )
return max_image_size return max_image_size
...@@ -1123,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1123,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo):
max_image_tokens) max_image_tokens)
max_frames_per_video = min( max_frames_per_video = min(
max_total_frames // max(max_videos, 1), max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO, self.get_max_frame_per_video(),
) )
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
...@@ -1139,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1139,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo):
) )
class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): _I = TypeVar("_I", bound=KeyeProcessingInfo)
class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
...@@ -1183,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): ...@@ -1183,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
return mm_data return mm_data
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
...
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
...@@ -1231,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): ...@@ -1231,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
return _keye_field_config(hf_inputs) return _keye_field_config(hf_inputs)
@MULTIMODAL_REGISTRY.register_processor( class BaseKeyeModule(nn.Module):
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -1264,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1264,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
raise ValueError("Only image or video modality is supported") raise ValueError("Only image or video modality is supported")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config config: PretrainedConfig = vllm_config.model_config.hf_config
...@@ -1278,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1278,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
quant_config=self._maybe_ignore_quant_config(quant_config), quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"), prefix=maybe_prefix(prefix, "visual"),
) )
self.mlp_AR = Projector(
self.mlp_AR = self._build_projector(
config, config,
config.vision_config, config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config), quant_config=self._maybe_ignore_quant_config(quant_config),
...@@ -1294,102 +1305,16 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1294,102 +1305,16 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): @abstractmethod
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): def _build_projector(self,
return None text_config: PretrainedConfig,
return quant_config vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, prefix: str = "") -> nn.Module:
name: str) -> torch.Tensor: raise ValueError("Need projector")
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim == 5:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
elif is_list_of(mm_input, torch.Tensor):
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
for p in mm_input):
return mm_input
return torch.concat(list(mm_input))
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[KeyeVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
"video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
)
def _process_image_input( def _process_image_input(self,
self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]: image_input: Any) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list() siglip_position_ids = list()
image_grid_hws = list() image_grid_hws = list()
sample_indices = list() sample_indices = list()
...@@ -1434,18 +1359,20 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1434,18 +1359,20 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
return image_embeds return image_embeds
def _process_video_input( def _process_video_embeds(
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]: self,
video_type: Literal["video_embeds", "pixel_values_videos"],
video_grid_thw: list[torch.Tensor],
pixel_values_videos: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
siglip_position_ids = list() siglip_position_ids = list()
video_grid_hws = list() video_grid_hws = list()
sample_indices = list() sample_indices = list()
cu_seqlens = [0] cu_seqlens = [0]
video_grid_thw = video_input["video_grid_thw"]
assert video_grid_thw.ndim == 2 assert video_grid_thw.ndim == 2
for idx, sub_thw in enumerate(video_grid_thw):
for idx, thaw in enumerate(video_grid_thw): thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple) numel = np.prod(thw_tuple)
video_grid_hws.append(thw_tuple) video_grid_hws.append(thw_tuple)
...@@ -1455,12 +1382,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1455,12 +1382,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
dtype=torch.int64)) dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel) cu_seqlens.append(cu_seqlens[-1] + numel)
if video_input["type"] == "video_embeds": if video_type == "video_embeds":
raise ValueError( raise ValueError(
"Video embeddings are not supported for this processing path.") "Video embeddings are not supported for this processing path.")
else: else:
pixel_values_videos = video_input["pixel_values_videos"].type( pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
pixel_values_videos.device) pixel_values_videos.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
...@@ -1479,7 +1405,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1479,7 +1405,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
use_rope=True, use_rope=True,
window_size=-1, window_size=-1,
) )
video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw)) video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
return video_embeds return video_embeds
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
...@@ -1541,8 +1467,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1541,8 +1467,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_input: Optional[KeyeImagePixelInputs] = None, image_input: Optional[Any] = None,
video_input: Optional[KeyeVideoPixelInputs] = None, video_input: Optional[Any] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None: if image_input is not None:
...@@ -1572,7 +1498,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1572,7 +1498,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen2-VL. """Run forward pass for Keye-VL.
Args: Args:
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
...@@ -1591,14 +1517,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1591,14 +1517,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed. `None` if no videos are passed.
""" """
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None: elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None: if image_input is None and video_input is None:
inputs_embeds = None inputs_embeds = None
else: else:
...@@ -1619,6 +1543,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1619,6 +1543,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -1631,7 +1556,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1631,7 +1556,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
...@@ -1639,6 +1563,122 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, ...@@ -1639,6 +1563,122 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
"""Get the module prefix in multimodal models.""" """Get the module prefix in multimodal models."""
return MultiModelKeys.from_string_field( return MultiModelKeys.from_string_field(
language_model="language_model", language_model="language_model",
connector="visual.", connector="mlp_AR.",
tower_model="mlp_AR.", tower_model="visual.",
) )
@MULTIMODAL_REGISTRY.register_processor(
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def _build_projector(self,
text_config: PretrainedConfig,
vision_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
return Projector(text_config, vision_config, quant_config, prefix)
def _validate_and_reshape_mm_tensor(
self, mm_input: NestedTensors,
name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim == 5:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
elif is_list_of(mm_input, torch.Tensor):
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
for p in mm_input):
return mm_input
return torch.concat(list(mm_input))
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[KeyeVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
"video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
)
def _process_video_input(
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
video_type = video_input["type"]
video_grid_thw = video_input["video_grid_thw"]
pixel_values_videos = video_input.get("pixel_values_videos", None)
return tuple(
self._process_video_embeds(video_type, video_grid_thw,
pixel_values_videos))
This diff is collapsed.
...@@ -227,6 +227,7 @@ _MULTIMODAL_MODELS = { ...@@ -227,6 +227,7 @@ _MULTIMODAL_MODELS = {
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
"KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment