Unverified Commit 5b595327 authored by litianjian's avatar litianjian Committed by GitHub
Browse files

[Model][VLM] Add LLaVA-Onevision model support (#8486)


Co-authored-by: default avatarlitianjian <litianjian@bytedance.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent ca2b628b
......@@ -244,6 +244,11 @@ Multimodal Language Models
- Video
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note)
-
* - :code:`LlavaOnevisionForConditionalGeneration`
- LLaVA-Onevision
- Image\ :sup:`+` / Video
- :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. (see note)
-
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
......@@ -288,7 +293,7 @@ Multimodal Language Models
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
.. note::
For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
For :code:`LLaVA-NeXT-Video`, :code:`LLaVA-Onevision` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
This can be installed by running the following command:
.. code-block:: bash
......
......@@ -14,7 +14,8 @@ from vllm.utils import FlexibleArgumentParser
# LLaVA-1.5
def run_llava(question):
def run_llava(question, modality):
assert modality == "image"
prompt = f"USER: <image>\n{question}\nASSISTANT:"
......@@ -24,7 +25,8 @@ def run_llava(question):
# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question):
def run_llava_next(question, modality):
assert modality == "image"
prompt = f"[INST] <image>\n{question} [/INST]"
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
......@@ -34,15 +36,35 @@ def run_llava_next(question):
# LlaVA-NeXT-Video
# Currently only support for video input
def run_llava_next_video(question):
def run_llava_next_video(question, modality):
assert modality == "video"
prompt = f"USER: <video>\n{question} ASSISTANT:"
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
stop_token_ids = None
return llm, prompt, stop_token_ids
# LLaVA-OneVision
def run_llava_onevision(question, modality):
if modality == "video":
prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n"
elif modality == "image":
prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n"
llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=32768)
stop_token_ids = None
return llm, prompt, stop_token_ids
# Fuyu
def run_fuyu(question):
def run_fuyu(question, modality):
assert modality == "image"
prompt = f"{question}\n"
llm = LLM(model="adept/fuyu-8b")
......@@ -51,7 +73,8 @@ def run_fuyu(question):
# Phi-3-Vision
def run_phi3v(question):
def run_phi3v(question, modality):
assert modality == "image"
prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
# Note: The default setting of max_num_seqs (256) and
......@@ -70,7 +93,8 @@ def run_phi3v(question):
# PaliGemma
def run_paligemma(question):
def run_paligemma(question, modality):
assert modality == "image"
# PaliGemma has special prompt format for VQA
prompt = "caption en"
......@@ -80,7 +104,8 @@ def run_paligemma(question):
# Chameleon
def run_chameleon(question):
def run_chameleon(question, modality):
assert modality == "image"
prompt = f"{question}<image>"
llm = LLM(model="facebook/chameleon-7b")
......@@ -89,7 +114,8 @@ def run_chameleon(question):
# MiniCPM-V
def run_minicpmv(question):
def run_minicpmv(question, modality):
assert modality == "image"
# 2.0
# The official repo doesn't work yet, so we need to use a fork for now
......@@ -129,7 +155,9 @@ def run_minicpmv(question):
# InternVL
def run_internvl(question):
def run_internvl(question, modality):
assert modality == "image"
model_name = "OpenGVLab/InternVL2-2B"
llm = LLM(
......@@ -155,7 +183,8 @@ def run_internvl(question):
# BLIP-2
def run_blip2(question):
def run_blip2(question, modality):
assert modality == "image"
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
......@@ -166,7 +195,8 @@ def run_blip2(question):
# Qwen
def run_qwen_vl(question):
def run_qwen_vl(question, modality):
assert modality == "image"
llm = LLM(
model="Qwen/Qwen-VL",
......@@ -180,7 +210,9 @@ def run_qwen_vl(question):
# Qwen2-VL
def run_qwen2_vl(question):
def run_qwen2_vl(question, modality):
assert modality == "image"
model_name = "Qwen/Qwen2-VL-7B-Instruct"
llm = LLM(
......@@ -200,6 +232,7 @@ model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
"llava-next-video": run_llava_next_video,
"llava-onevision": run_llava_onevision,
"fuyu": run_fuyu,
"phi3_v": run_phi3v,
"paligemma": run_paligemma,
......@@ -255,7 +288,7 @@ def main(args):
data = mm_input["data"]
question = mm_input["question"]
llm, prompt, stop_token_ids = model_example_map[model](question)
llm, prompt, stop_token_ids = model_example_map[model](question, modality)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
......@@ -306,6 +339,7 @@ if __name__ == "__main__":
parser.add_argument('--modality',
type=str,
default="image",
choices=['image', 'video'],
help='Modality of the input.')
parser.add_argument('--num-frames',
type=int,
......
......@@ -105,9 +105,6 @@ def run_test(
for asset in video_assets
]
for video in videos:
print(video.shape)
if size_factors is not None:
inputs_per_video = [(
[prompt for _ in size_factors],
......
from typing import List, Optional, Tuple, Type, overload
import pytest
import transformers
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
resize_video, sample_frames_from_video)
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_VideoAssets)
from ...utils import check_logprobs_close
# Video test
HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"sample_demo_1":
"<|im_start|>user <video>\nwhy is this video funny? \
<|im_end|><|im_start|>assistant\n"
})
models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
video_token_id = config.video_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != video_token_id or output_ids[idx - 1] != video_token_id
]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
videos = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]
if size_factors is not None:
inputs_per_video = [(
[prompt for _ in size_factors],
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
elif sizes is not None:
inputs_per_video = [(
[prompt for _ in sizes],
[resize_video(video, size) for size in sizes],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_video = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
]
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values_videos"] = hf_inputs["pixel_values_videos"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_video = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
vllm_outputs_per_video):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="Waiting for next transformers release")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No video
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
dtype, max_tokens, num_logprobs, num_frames) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/videos.
For huggingface runner, we provide the np.ndarray as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
run_video_test(
hf_runner,
vllm_runner,
video_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="Waiting for next transformers release")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
dtype, max_tokens, num_logprobs,
num_frames) -> None:
run_video_test(
hf_runner,
vllm_runner,
video_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)
# Image test
_LIMIT_IMAGE_PER_PROMPT = 4
def run_image_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=32768,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="Waiting for next transformers release")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
inputs = [(
[
"<|im_start|>user <image><image>\nDescribe 2 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image><image>\nDescribe 2 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image><image><image><image>\nDescribe 4 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image>\nWhat is the season? \
<|im_end|><|im_start|>assistant\n",
],
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]
run_image_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
......@@ -6,7 +6,8 @@ from vllm.model_executor.models import _MODELS, ModelRegistry
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
if (model_cls == "Qwen2VLForConditionalGeneration"
if (model_cls in ("LlavaOnevisionForConditionalGeneration",
"Qwen2VLForConditionalGeneration")
and transformers.__version__ < "4.45"):
pytest.skip("Waiting for next transformers release")
......
......@@ -79,7 +79,7 @@ class VideoAsset:
return ret
@property
def np_ndarrays(self) -> List[npt.NDArray]:
def np_ndarrays(self) -> npt.NDArray:
video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
......@@ -83,12 +83,14 @@ _MULTIMODAL_MODELS = {
("chameleon", "ChameleonForConditionalGeneration"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaForConditionalGeneration": ("llava",
"LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next",
"LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"LlavaOnevisionForConditionalGeneration":
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
......
......@@ -2,6 +2,7 @@
within a vision language model."""
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
......@@ -84,6 +85,24 @@ def dummy_image_for_clip(
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_clip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
mm_data = {"video": mm_data_per_video}
return mm_data
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
......
This diff is collapsed.
......@@ -4,6 +4,7 @@ within a vision language model."""
import math
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from torch import nn
......@@ -89,6 +90,24 @@ def dummy_image_for_siglip(
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_siglip(
hf_config: SiglipVisionConfig,
num_frames: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_siglip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
mm_data = {"video": mm_data_per_video}
return mm_data
def input_processor_for_siglip(
model_config: ModelConfig,
hf_config: SiglipVisionConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment