Commit 7c8271cd (unverified), authored by Kwai-Keye, committed by GitHub
Browse files

[Model]: support KeyeVL-1_5-8B (#23838)


Signed-off-by: wangruitao <wangruitao@kuaishou.com>
Co-authored-by: wangruitao <wangruitao@kuaishou.com>
parent 3e330fcb
......@@ -634,7 +634,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
......
......@@ -683,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
)
# Keye-VL-1.5
def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData:
    """Build engine arguments and prompts for the Keye-VL-1.5 example.

    Args:
        questions: Questions to ask; one prompt is produced per question.
        modality: Either ``"image"`` or ``"video"``. It selects the
            placeholder token and is the only modality enabled (one item
            per prompt) in ``limit_mm_per_prompt``.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and prompts.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Fail early with a clear message instead of hitting an unbound
    # ``placeholder`` further down.
    if modality == "image":
        placeholder = "<|image_pad|>"
    elif modality == "video":
        placeholder = "<|video_pad|>"
    else:
        raise ValueError(
            f"Unsupported modality {modality!r}; expected 'image' or 'video'")

    # The checkpoint id is spelled with an underscore ("1_5"), the same
    # spelling used by the supported-models table and the registry tests.
    model_name = "Kwai-Keye/Keye-VL-1_5-8B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        trust_remote_code=True,
        limit_mm_per_prompt={modality: 1},
    )

    prompts = [
        (
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{question}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
# Kimi-VL
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -1648,6 +1679,7 @@ model_example_map = {
"interns1": run_interns1,
"internvl_chat": run_internvl,
"keye_vl": run_keye_vl,
"keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl,
"llama4": run_llama4,
"llava": run_llava,
......
......@@ -542,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
    """Prepare a multi-image request for Keye-VL-1.5.

    Args:
        question: Text question appended after all image entries.
        image_urls: URLs of the images; each one becomes an image entry in
            the chat message and is fetched for the request.

    Returns:
        A ``ModelRequestData`` with the engine arguments, the prompt rendered
        by the checkpoint's chat template, and the fetched images.
    """
    model_name = "Kwai-Keye/Keye-VL-1_5-8B"

    # Allow exactly as many images per prompt as URLs were supplied.
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # Single user turn: every image first, then the question text.
    user_content = [{"type": "image", "image": image_url}
                    for image_url in image_urls]
    user_content.append({"type": "text", "text": question})
    chat = [{"role": "user", "content": user_content}]

    hf_processor = AutoProcessor.from_pretrained(model_name,
                                                 trust_remote_code=True)
    rendered_prompt = hf_processor.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True)

    fetched_images = [fetch_image(image_url) for image_url in image_urls]

    return ModelRequestData(
        engine_args=engine_args,
        prompt=rendered_prompt,
        image_data=fetched_images,
    )
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
......@@ -1209,6 +1246,7 @@ model_example_map = {
"interns1": load_interns1,
"internvl_chat": load_internvl,
"keye_vl": load_keye_vl,
"keye_vl1_5": load_keye_vl1_5,
"kimi_vl": load_kimi_vl,
"llama4": load_llama4,
"llava": load_llava,
......
......@@ -293,6 +293,7 @@ def _test_processing_correctness_one(
"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview",
"OpenGVLab/InternVL3_5-30B-A3B",
"Kwai-Keye/Keye-VL-8B-Preview",
"Kwai-Keye/Keye-VL-1_5-8B",
"moonshotai/Kimi-VL-A3B-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"llava-hf/llava-1.5-7b-hf",
......
......@@ -438,6 +438,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
trust_remote_code=True),
"KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501
trust_remote_code=True),
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
......
......@@ -402,6 +402,15 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len=context_len,
seq_len=seq_len,
)
elif "KeyeVL1_5" in hf_config.model_type:
return cls._keye_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
......@@ -636,6 +645,126 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod
def _keye_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value (Keye series).

    Text tokens receive the same running index on all three rotary axes.
    Each image, and each individual video frame, receives (t, h, w) grid
    indices offset by the position reached so far.

    Args:
        input_tokens: Token ids of the full prompt.
        hf_config: Config providing ``image_token_id``, ``video_token_id``
            and ``vision_config.spatial_merge_size``.
        image_grid_thw: One ``[t, h, w]`` row per image.
        video_grid_thw: ``[t, h, w]`` rows for videos. A non-empty list is
            unwrapped to its first element, and every row is then expanded
            into ``t`` rows of ``[1, h, w]`` (one per frame).
        context_len: Start of the slice taken from the computed positions.
        seq_len: End of the slice taken from the computed positions
            (``None`` keeps everything up to the end).

    Returns:
        A tuple ``(llm_positions, mrope_position_delta)``: the positions
        tensor with three rows, sliced to ``[:, context_len:seq_len]``, and
        ``max(position) + 1 - len(input_tokens)`` computed before slicing.
    """
    if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
        video_grid_thw = video_grid_thw[0]

    def split_thw(
        grid_thw: Union[torch.Tensor, list[list[int]]]
    ) -> list[list[int]]:
        """
        Split grid_thw along the t dimension.

        Args:
            grid_thw: shape [N, 3] tensor or nested list of [t, h, w].

        Returns:
            List of [1, h, w] rows, repeated t times for each original row.
        """
        if isinstance(grid_thw, list):
            grid_thw = torch.tensor(grid_thw, dtype=torch.long)

        if grid_thw.numel() == 0:
            return []

        t, hw = grid_thw[:, 0], grid_thw[:, 1:]
        ones = torch.ones_like(hw[:, :1])  # [N,1]
        out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
        return out.tolist()

    # From here on every entry of video_grid_thw describes a single frame.
    video_grid_thw = split_thw(video_grid_thw)

    image_token_id = hf_config.image_token_id
    video_token_id = hf_config.video_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size

    image_nums = len(image_grid_thw)
    frame_nums = len(video_grid_thw)
    llm_pos_ids_list: list = []

    st = 0
    remain_images, remain_frames = image_nums, frame_nums

    image_index, video_index = 0, 0
    for _ in range(image_nums + frame_nums):
        # Locate the next image / video token at or after ``st``. An index
        # past the end of the prompt marks "none left".
        if remain_images > 0:
            try:
                ed_image = input_tokens.index(image_token_id, st)
            except ValueError:
                ed_image = len(input_tokens) + 1
        else:
            ed_image = len(input_tokens) + 1
        if remain_frames > 0:
            try:
                ed_video = input_tokens.index(video_token_id, st)
            except ValueError:
                ed_video = len(input_tokens) + 1
        else:
            ed_video = len(input_tokens) + 1

        # Consume whichever vision item comes first in the prompt.
        if ed_image < ed_video:
            t, h, w = (
                image_grid_thw[image_index][0],
                image_grid_thw[image_index][1],
                image_grid_thw[image_index][2],
            )
            image_index += 1
            remain_images -= 1
            ed = ed_image
        else:
            t, h, w = (
                video_grid_thw[video_index][0],
                video_grid_thw[video_index][1],
                video_grid_thw[video_index][2],
            )
            video_index += 1
            remain_frames -= 1
            ed = ed_video

        # The spatial grid is reduced by the merge size on both h and w.
        llm_grid_t, llm_grid_h, llm_grid_w = \
            t, h // spatial_merge_size, w // spatial_merge_size
        text_len = ed - st

        # Text preceding the vision item: identical index on all 3 axes.
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        # Vision item: per-axis grid indices, offset past the text.
        t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
            -1, llm_grid_h * llm_grid_w)).long().flatten()
        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
            llm_grid_t, -1, llm_grid_w).flatten()
        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
            llm_grid_t, llm_grid_h, -1).flatten()
        llm_pos_ids_list.append(
            torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
        st = ed + llm_grid_t * llm_grid_h * llm_grid_w

    # Trailing text after the last vision item.
    if st < len(input_tokens):
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        text_len = len(input_tokens) - st
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    # The delta is computed on the full sequence, before slicing.
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta
@classmethod
def _vl_get_input_positions_tensor(
cls,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
import numpy as np
import torch
......@@ -57,16 +58,13 @@ from .vision import get_vit_attn_backend
logger = init_logger(__name__)
_MAX_FRAMES_PER_VIDEO = 16
_MAX_IMAGE_SIZE = 9999999
def smart_resize(
height: int,
width: int,
factor: int = 28,
min_pixels: int = 28 * 28 * 130,
max_pixels: int = 28 * 28 * 1280,
factor: int,
min_pixels: int,
max_pixels: int,
):
if height < factor:
logger.warning(
......@@ -887,9 +885,9 @@ class Projector(nn.Module):
def forward(
self,
image_features: torch.Tensor,
image_features: Union[torch.Tensor, list[torch.Tensor]],
image_grid_thw: list[tuple[int, int, int]],
) -> torch.Tensor:
) -> Union[torch.Tensor, list[torch.Tensor]]:
m1, m2 = self.merge_kernel_size
if isinstance(image_features, (list, tuple)):
processed_features = list()
......@@ -986,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
class KeyeProcessingInfo(BaseProcessingInfo):
def get_max_image_size(self) -> int:
    """Return the size used as both width and height when
    ``get_image_size_with_most_features`` probes the vision info."""
    return 9999999  # same value as the former module constant _MAX_IMAGE_SIZE
def get_max_frame_per_video(self) -> int:
    """Return the cap applied to the number of frames budgeted per video."""
    return 16  # same value as the former module constant _MAX_FRAMES_PER_VIDEO
def get_image_processor(self, **kwargs: object):
    """Return the ``image_processor`` attribute of the HF processor.

    Args:
        **kwargs: Forwarded unchanged to ``get_hf_processor``.
    """
    return self.get_hf_processor(**kwargs).image_processor
......@@ -1077,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self, ) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=_MAX_IMAGE_SIZE,
image_height=_MAX_IMAGE_SIZE,
image_width=self.get_max_image_size(),
image_height=self.get_max_image_size(),
image_processor=None,
)
return max_image_size
......@@ -1123,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo):
max_image_tokens)
max_frames_per_video = min(
max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO,
self.get_max_frame_per_video(),
)
return max(max_frames_per_video, 1)
......@@ -1139,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo):
)
class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
_I = TypeVar("_I", bound=KeyeProcessingInfo)
class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
......@@ -1183,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
return mm_data
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
    """Dummy-inputs builder bound to ``KeyeProcessingInfo``.

    Adds nothing of its own; all behaviour comes from
    ``KeyeBaseDummyInputsBuilder``.
    """
    ...
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
......@@ -1231,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
return _keye_field_config(hf_inputs)
@MULTIMODAL_REGISTRY.register_processor(
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
SupportsPP):
class BaseKeyeModule(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -1264,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
raise ValueError("Only image or video modality is supported")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
    """Return ``None`` for GPTQ-style configs, otherwise the config itself.

    Args:
        quant_config: Quantization config to filter.

    Returns:
        ``None`` when ``quant_config`` is a ``GPTQConfig`` or
        ``GPTQMarlinConfig`` instance; ``quant_config`` unchanged otherwise.
    """
    gptq_family = (GPTQConfig, GPTQMarlinConfig)
    return None if isinstance(quant_config, gptq_family) else quant_config
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
......@@ -1278,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = Projector(
self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
......@@ -1294,102 +1305,16 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim == 5:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
elif is_list_of(mm_input, torch.Tensor):
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
for p in mm_input):
return mm_input
return torch.concat(list(mm_input))
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
return KeyeImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[KeyeVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None and video_embeds is None:
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos,
"video pixel values",
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return KeyeVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw,
)
@abstractmethod
def _build_projector(self,
                     text_config: PretrainedConfig,
                     vision_config: PretrainedConfig,
                     quant_config: Optional[QuantizationConfig] = None,
                     prefix: str = "") -> nn.Module:
    """Build the projector module; concrete models must override this.

    Args:
        text_config: Config passed as the first argument by ``__init__``.
        vision_config: Vision config of the model.
        quant_config: Optional quantization config for the projector.
        prefix: Module-name prefix for the projector.

    Raises:
        ValueError: Always, in this base implementation.
    """
    # NOTE(review): the enclosing class derives from nn.Module, so
    # @abstractmethod may not be enforced at instantiation; this raise is
    # the effective guard. Confirm against the class definition.
    raise ValueError("Need projector")
def _process_image_input(
self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]:
def _process_image_input(self,
image_input: Any) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list()
image_grid_hws = list()
sample_indices = list()
......@@ -1434,18 +1359,20 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
return image_embeds
def _process_video_input(
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
def _process_video_embeds(
self,
video_type: Literal["video_embeds", "pixel_values_videos"],
video_grid_thw: list[torch.Tensor],
pixel_values_videos: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
siglip_position_ids = list()
video_grid_hws = list()
sample_indices = list()
cu_seqlens = [0]
video_grid_thw = video_input["video_grid_thw"]
assert video_grid_thw.ndim == 2
for idx, thaw in enumerate(video_grid_thw):
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
for idx, sub_thw in enumerate(video_grid_thw):
thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
numel = np.prod(thw_tuple)
video_grid_hws.append(thw_tuple)
......@@ -1455,12 +1382,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
dtype=torch.int64))
cu_seqlens.append(cu_seqlens[-1] + numel)
if video_input["type"] == "video_embeds":
if video_type == "video_embeds":
raise ValueError(
"Video embeddings are not supported for this processing path.")
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
pixel_values_videos.device)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
......@@ -1479,7 +1405,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
use_rope=True,
window_size=-1,
)
video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw))
video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
return video_embeds
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
......@@ -1541,8 +1467,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[KeyeImagePixelInputs] = None,
video_input: Optional[KeyeVideoPixelInputs] = None,
image_input: Optional[Any] = None,
video_input: Optional[Any] = None,
) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
......@@ -1572,7 +1498,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen2-VL.
"""Run forward pass for Keye-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
......@@ -1591,14 +1517,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
......@@ -1619,6 +1543,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
......@@ -1631,7 +1556,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights through ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        The set of names returned by the loader, with names translated by
        ``self.hf_to_vllm_mapper``.
    """
    return AutoWeightsLoader(self).load_weights(
        weights, mapper=self.hf_to_vllm_mapper)
......@@ -1639,6 +1563,122 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
"""Get the module prefix in multimodal models."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.",
tower_model="mlp_AR.",
connector="mlp_AR.",
tower_model="visual.",
)
@MULTIMODAL_REGISTRY.register_processor(
KeyeMultiModalProcessor,
info=KeyeProcessingInfo,
dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
SupportsLoRA, SupportsPP):
def _build_projector(self,
                     text_config: PretrainedConfig,
                     vision_config: PretrainedConfig,
                     quant_config: Optional[QuantizationConfig] = None,
                     prefix: str = "") -> nn.Module:
    """Create the ``Projector`` used by this model.

    Args:
        text_config: First positional argument for ``Projector``.
        vision_config: Second positional argument for ``Projector``.
        quant_config: Optional quantization config, passed through.
        prefix: Module-name prefix, passed through.

    Returns:
        A new ``Projector`` instance.
    """
    # Arguments are forwarded positionally, in declaration order.
    return Projector(text_config, vision_config, quant_config, prefix)
def _validate_and_reshape_mm_tensor(
        self, mm_input: NestedTensors,
        name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Validate a multimodal input and normalise its batching.

    Args:
        mm_input: A tensor or a list of tensors.
        name: Human-readable name used in error messages.

    Returns:
        * a 2D or 5D tensor, unchanged;
        * a 3D tensor split along its first dimension and concatenated
          into one tensor;
        * a list whose tensors are all 4D or all 2D, unchanged;
        * any other list of tensors, concatenated into one tensor.

    Raises:
        ValueError: If ``mm_input`` is neither a tensor nor a list, if it
            is a tensor with an unsupported number of dimensions, or if
            it is a list that does not hold tensors.
    """
    if not isinstance(mm_input, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of {name}. "
                         f"Got type: {type(mm_input)}")

    if isinstance(mm_input, torch.Tensor):
        # 2D and 5D tensors pass through untouched.
        if mm_input.ndim in (2, 5):
            return mm_input
        if mm_input.ndim != 3:
            raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                             f"Got ndim: {mm_input.ndim} "
                             f"(shape={mm_input.shape})")
        return torch.concat(list(mm_input))

    # Previously a list that failed this check fell off the end of the
    # function and returned None, which only surfaced as an obscure error
    # further downstream.
    if not is_list_of(mm_input, torch.Tensor):
        raise ValueError(f"Incorrect type of {name}. "
                         "Expected a list of tensors.")

    if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
                                                  for p in mm_input):
        return mm_input
    return torch.concat(list(mm_input))
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[KeyeImageInputs]:
    """Extract the image inputs from the keyword arguments.

    Reads ``pixel_values``, ``image_embeds`` and ``image_grid_thw``. Pixel
    values take precedence over embeddings when both are given.

    Returns:
        ``KeyeImagePixelInputs`` or ``KeyeImageEmbeddingInputs`` built from
        the validated tensors, or ``None`` when neither pixel values nor
        embeddings were supplied.
    """
    raw_pixels = kwargs.pop("pixel_values", None)
    raw_embeds = kwargs.pop("image_embeds", None)
    raw_grid = kwargs.pop("image_grid_thw", None)

    if raw_pixels is not None:
        # Keyword arguments are evaluated left to right, so the pixel
        # values are validated before the grid.
        return KeyeImagePixelInputs(
            type="pixel_values",
            pixel_values=self._validate_and_reshape_mm_tensor(
                raw_pixels, "image pixel values"),
            image_grid_thw=self._validate_and_reshape_mm_tensor(
                raw_grid, "image grid_thw"),
        )

    if raw_embeds is not None:
        return KeyeImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=self._validate_and_reshape_mm_tensor(
                raw_embeds, "image embeds"),
            image_grid_thw=self._validate_and_reshape_mm_tensor(
                raw_grid, "image grid_thw"),
        )

    return None
def _parse_and_validate_video_input(
        self, **kwargs: object) -> Optional[KeyeVideoInputs]:
    """Extract the video inputs from the keyword arguments.

    Reads ``pixel_values_videos``, ``video_embeds`` and ``video_grid_thw``.
    Pixel values take precedence over embeddings when both are given.

    Returns:
        ``KeyeVideoPixelInputs`` or ``KeyeVideoEmbeddingInputs`` built from
        the validated tensors, or ``None`` when neither pixel values nor
        embeddings were supplied.
    """
    raw_pixels = kwargs.pop("pixel_values_videos", None)
    raw_embeds = kwargs.pop("video_embeds", None)
    raw_grid = kwargs.pop("video_grid_thw", None)

    if raw_pixels is not None:
        # Keyword arguments are evaluated left to right, so the pixel
        # values are validated before the grid.
        return KeyeVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=self._validate_and_reshape_mm_tensor(
                raw_pixels,
                "video pixel values",
            ),
            video_grid_thw=self._validate_and_reshape_mm_tensor(
                raw_grid, "video grid_thw"),
        )

    if raw_embeds is not None:
        return KeyeVideoEmbeddingInputs(
            type="video_embeds",
            video_embeds=self._validate_and_reshape_mm_tensor(
                raw_embeds, "video embeds"),
            video_grid_thw=self._validate_and_reshape_mm_tensor(
                raw_grid, "video grid_thw"),
        )

    return None
def _process_video_input(
        self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
    """Compute video embeddings for a parsed video input.

    Args:
        video_input: Mapping with ``type`` and ``video_grid_thw``, and
            optionally ``pixel_values_videos``.

    Returns:
        The result of ``_process_video_embeds`` converted to a tuple.
    """
    embeds = self._process_video_embeds(
        video_input["type"],
        video_input["video_grid_thw"],
        # Absent for embedding-type inputs, hence the lenient lookup.
        video_input.get("pixel_values_videos", None),
    )
    return tuple(embeds)
This diff is collapsed.
......@@ -227,6 +227,7 @@ _MULTIMODAL_MODELS = {
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
"KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment