Unverified Commit 75f81750 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[VLM] Initialize video input support for InternVL models (#18499)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 6ab681bc
......@@ -527,7 +527,7 @@ Specified using `--task generate`.
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎\* | |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | ✅︎ | |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | ✅︎ | | |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | |
| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | ✅︎ | ✅︎ | |
......@@ -577,6 +577,9 @@ Specified using `--task generate`.
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
!!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
!!! note
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
......
......@@ -330,22 +330,26 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "OpenGVLab/InternVL2-2B"
model_name = "OpenGVLab/InternVL3-2B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<image>"
elif modality == "video":
placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
'content': f"{placeholder}\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
......@@ -357,6 +361,9 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
stop_token_ids = [
token_id for token_id in stop_token_ids if token_id is not None
]
return ModelRequestData(
engine_args=engine_args,
......
......@@ -349,6 +349,17 @@ VLM_TEST_SETTINGS = {
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
"intern_vl-video": VLMTestInfo(
models=[
"OpenGVLab/InternVL3-1B",
],
test_type=VLMTestType.VIDEO,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
video_idx_to_prompt=lambda idx: "<video>",
max_model_len=8192,
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
),
"kimi_vl": VLMTestInfo(
models=["moonshotai/Kimi-VL-A3B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
......
......@@ -7,6 +7,8 @@ import types
from pathlib import PosixPath
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import regex as re
import torch
from PIL.Image import Image
......@@ -495,30 +497,74 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Union[Image, list[Image]],
**kwargs):
def __call__(
self,
text: str,
images: Union[Image, list[Image]] = None,
videos: Union[npt.NDArray, list[npt.NDArray]] = None,
**kwargs,
):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_internvl)
image_to_pixel_values_internvl, video_to_pixel_values_internvl)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
videos = [videos] if isinstance(videos, np.ndarray) else videos
if images is not None:
pixel_values_images = [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_images = [
pixel_value.shape[0] for pixel_value in pixel_values_images
]
else:
pixel_values_images, num_patches_images = [], []
if videos is not None:
pixel_values_videos = [
video_to_pixel_values_internvl(
video,
input_size=self.image_size,
min_num=1,
max_num=1,
use_thumbnail=False,
) for video in videos
]
num_patches_videos = [
pixel_value.shape[0] for pixel_value in pixel_values_videos
]
else:
pixel_values_videos, num_patches_videos = [], []
pixel_values = []
while ("<image>" in text) or ("<video>" in text):
image_index = text.find("<image>")
video_index = text.find("<video>")
if image_index == -1 or (video_index > -1
and video_index < image_index):
num_patches = num_patches_videos.pop(0)
pixel_values.append(pixel_values_videos.pop(0))
context_tokens = IMG_START + \
IMG_CONTEXT * self.num_image_token + IMG_END
video_tokens = ''.join([
f'Frame{i+1}: {context_tokens}'
for i in range(num_patches)
])
text = text.replace('<video>', video_tokens, 1)
else:
num_patches = num_patches_images.pop(0)
pixel_values.append(pixel_values_images.pop(0))
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
......
......@@ -258,6 +258,7 @@ def _test_processing_correctness_mistral(
"ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"OpenGVLab/InternVL3-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"moonshotai/Kimi-VL-A3B-Instruct",
......
......@@ -334,7 +334,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
extras={"2B": "OpenGVLab/InternVL2-2B",
"3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
trust_remote_code=True),
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
......
......@@ -556,6 +556,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
......
......@@ -25,9 +25,10 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
BaseInternVLDummyInputsBuilder,
BaseInternVLMultiModalProcessor,
BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor, build_transform,
InternVLChatModel, build_transform,
find_closest_aspect_ratio, get_internvl_target_ratios)
......@@ -430,8 +431,8 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
)
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
):
class H2OVLMultiModalProcessor(
BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
def _get_prompt_updates(
self,
......@@ -514,7 +515,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
@MULTIMODAL_REGISTRY.register_processor(
H2OVLMultiModalProcessor,
info=H2OVLProcessingInfo,
dummy_inputs=InternVLDummyInputsBuilder)
dummy_inputs=BaseInternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
def _init_vision_model(
......
This diff is collapsed.
......@@ -22,9 +22,10 @@ from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from .intern_vit import InternVisionModel
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor)
from .internvl import (BaseInternVLDummyInputsBuilder,
BaseInternVLMultiModalProcessor,
BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel)
IMG_PAD = "<|vision_pad|>"
......@@ -84,7 +85,8 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
)
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]
):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
......@@ -110,7 +112,8 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
}
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(
BaseInternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_updates(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment