"vscode:/vscode.git/clone" did not exist on "24cde76a152fbffde30fa2be0d08dcbad490530e"
Unverified Commit 2f707fcb, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Multi-input support for LLaVA (#8238)

parent 41e95c52
...@@ -219,7 +219,7 @@ Multimodal Language Models ...@@ -219,7 +219,7 @@ Multimodal Language Models
- -
* - :code:`LlavaForConditionalGeneration` * - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5 - LLaVA-1.5
- Image\ :sup:`E` - Image\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
- -
* - :code:`LlavaNextForConditionalGeneration` * - :code:`LlavaNextForConditionalGeneration`
...@@ -227,6 +227,11 @@ Multimodal Language Models ...@@ -227,6 +227,11 @@ Multimodal Language Models
- Image\ :sup:`E+` - Image\ :sup:`E+`
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
- -
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration` * - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma - PaliGemma
- Image\ :sup:`E` - Image\ :sup:`E`
...@@ -237,14 +242,9 @@ Multimodal Language Models ...@@ -237,14 +242,9 @@ Multimodal Language Models
- Image\ :sup:`E+` - Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
- -
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`QWenLMHeadModel` * - :code:`QWenLMHeadModel`
- Qwen - Qwen-VL
- Image - Image\ :sup:`E`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
- -
* - :code:`UltravoxModel` * - :code:`UltravoxModel`
......
...@@ -278,7 +278,7 @@ class HfRunner: ...@@ -278,7 +278,7 @@ class HfRunner:
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
images: Optional[List[Image.Image]] = None, images: Optional[PromptImageInput] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]: ) -> List[Tuple[List[List[int]], List[str]]]:
if images: if images:
...@@ -314,7 +314,7 @@ class HfRunner: ...@@ -314,7 +314,7 @@ class HfRunner:
self, self,
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
images: Optional[List[Image.Image]] = None, images: Optional[PromptImageInput] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts, outputs = self.generate(prompts,
...@@ -351,7 +351,7 @@ class HfRunner: ...@@ -351,7 +351,7 @@ class HfRunner:
self, self,
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
images: Optional[List[Image.Image]] = None, images: Optional[PromptImageInput] = None,
**kwargs: Any, **kwargs: Any,
) -> List[List[torch.Tensor]]: ) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = [] all_logprobs: List[List[torch.Tensor]] = []
...@@ -433,8 +433,8 @@ class HfRunner: ...@@ -433,8 +433,8 @@ class HfRunner:
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
images: Optional[List[Image.Image]] = None, images: Optional[PromptImageInput] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None, audios: Optional[PromptAudioInput] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = [] all_logprobs: List[List[Dict[int, float]]] = []
...@@ -671,7 +671,7 @@ class VllmRunner: ...@@ -671,7 +671,7 @@ class VllmRunner:
self, self,
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
images: Optional[List[Image.Image]] = None, images: Optional[PromptImageInput] = None,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images) outputs = self.generate(prompts, greedy_params, images=images)
......
...@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str, ...@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
if model.startswith("llava-hf/llava-1.5"): if model.startswith("llava-hf/llava-1.5"):
from ..models.test_llava import models, run_test from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"): elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test from ..models.test_llava_next import run_test # type: ignore[no-redef]
from ..models.test_llava_next import models
elif model.startswith("facebook/chameleon"): elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test from ..models.test_chameleon import run_test # type: ignore[no-redef]
from ..models.test_chameleon import models
else: else:
raise NotImplementedError(f"Unsupported model: {model}") raise NotImplementedError(f"Unsupported model: {model}")
......
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type, overload
import pytest import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
...@@ -8,11 +8,14 @@ from vllm.multimodal.utils import rescale_image_size ...@@ -8,11 +8,14 @@ from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from .utils import check_logprobs_close from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm pytestmark = pytest.mark.vlm
_LIMIT_IMAGE_PER_PROMPT = 4
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:", "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
...@@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, ...@@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_test( def run_test(
hf_runner: Type[HfRunner], hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner], vllm_runner: Type[VllmRunner],
...@@ -64,6 +68,78 @@ def run_test( ...@@ -64,6 +68,78 @@ def run_test(
num_logprobs: int, num_logprobs: int,
tensor_parallel_size: int, tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None, distributed_executor_backend: Optional[str] = None,
):
...
# Typing-only overload of `run_test`: the call form that takes explicit
# `sizes` (a list of 2-int tuples) and no `size_factors`.
# All arguments after `*` are keyword-only; `distributed_executor_backend`
# is the only one with a default.
# NOTE(review): the tuples are presumably (width, height) — confirm against
# how the implementation resizes images.
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
# Stub body: an `@overload` declaration carries a signature only.
...
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build single-image inputs at several resolutions and run `_run_test`.

    Each asset image is paired (via ``zip``) with its entry in
    ``HF_IMAGE_PROMPTS``. For every pair, one ``(prompts, images)`` tuple is
    produced in which the prompt is repeated once per requested resolution
    and the image is resized once per requested resolution.

    Args:
        hf_runner: Forwarded unchanged to `_run_test`.
        vllm_runner: Forwarded unchanged to `_run_test`.
        image_assets: Iterable of assets exposing a ``pil_image`` attribute.
        model: Forwarded unchanged to `_run_test`.
        size_factors: Factors passed to ``rescale_image_size`` for each image.
        sizes: Target sizes passed to ``image.resize`` for each image.
        dtype: Forwarded unchanged to `_run_test`.
        max_tokens: Forwarded unchanged to `_run_test`.
        num_logprobs: Forwarded unchanged to `_run_test`.
        tensor_parallel_size: Forwarded unchanged to `_run_test`.
        distributed_executor_backend: Forwarded unchanged to `_run_test`.

    Raises:
        ValueError: If neither or both of ``size_factors`` and ``sizes``
            are provided.
    """
    # The overloads declare the two arguments as mutually exclusive; enforce
    # that at runtime instead of silently ignoring `sizes` when both are set.
    if size_factors is not None and sizes is not None:
        raise ValueError(
            "You must provide only one of `size_factors` or `sizes`, not both")

    images = [asset.pil_image for asset in image_assets]

    if size_factors is not None:
        inputs_per_image = [(
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    elif sizes is not None:
        inputs_per_image = [(
            [prompt for _ in sizes],
            [image.resize(size) for size in sizes],
        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    _run_test(hf_runner,
              vllm_runner,
              inputs_per_image,
              model,
              dtype=dtype,
              max_tokens=max_tokens,
              num_logprobs=num_logprobs,
              tensor_parallel_size=tensor_parallel_size,
              distributed_executor_backend=distributed_executor_backend)
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
): ):
"""Inference result should be the same between hf and vllm. """Inference result should be the same between hf and vllm.
...@@ -85,13 +161,6 @@ def run_test( ...@@ -85,13 +161,6 @@ def run_test(
else: else:
mantis_processor = None mantis_processor = None
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF. # NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization. # vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
...@@ -100,15 +169,18 @@ def run_test( ...@@ -100,15 +169,18 @@ def run_test(
# max_model_len should be greater than image_feature_size # max_model_len should be greater than image_feature_size
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model: enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [ vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts, vllm_model.generate_greedy_logprobs(prompts,
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=images) images=images)
for prompts, images in inputs_per_image for prompts, images in inputs
] ]
if mantis_processor is not None: if mantis_processor is not None:
...@@ -131,7 +203,7 @@ def run_test( ...@@ -131,7 +203,7 @@ def run_test(
max_tokens, max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
images=images) images=images)
for prompts, images in inputs_per_image for prompts, images in inputs
] ]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
...@@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ...@@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
) )
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
                                      model, dtype, max_tokens,
                                      num_logprobs) -> None:
    """Run `_run_test` on one batch of prompts carrying one to four images.

    The batch has four prompts. The matching image entries are a pair of
    different assets, a pair made of a rescaled and an original copy of the
    same asset, four images of differing sizes and aspect ratios, and a
    single bare image (not wrapped in a list).
    """
    sign_img, blossom_img = (image_assets[idx].pil_image for idx in (0, 1))

    two_image_prompt = "USER: <image><image>\nDescribe 2 images.\nASSISTANT:"
    prompts = [
        two_image_prompt,
        two_image_prompt,
        "USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:",  # noqa: E501
        "USER: <image>\nWhat is the season?\nASSISTANT:",
    ]

    # One entry per prompt, in the same order as `prompts`.
    images_per_prompt = [
        [sign_img, blossom_img],
        # Images with different sizes and aspect-ratios
        [rescale_image_size(sign_img, 0.1), sign_img],
        [
            sign_img,
            rescale_image_size(sign_img, 0.25),
            blossom_img.resize((183, 488)),
            blossom_img.resize((488, 183)),
        ],
        blossom_img,
    ]

    _run_test(
        hf_runner,
        vllm_runner,
        [(prompts, images_per_prompt)],
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
def test_context_length_too_short(vllm_runner, image_assets, model): def test_context_length_too_short(vllm_runner, image_assets, model):
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
......
...@@ -105,7 +105,7 @@ def input_processor_for_clip( ...@@ -105,7 +105,7 @@ def input_processor_for_clip(
if isinstance(image_data, Image.Image): if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config) image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
else: else:
......
...@@ -209,7 +209,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -209,7 +209,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
image_feature_size = num_blocks * num_patches image_feature_size = num_blocks * num_patches
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
......
...@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, ...@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens, dummy_seq_data_for_clip, get_max_clip_image_tokens,
...@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal ...@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(LlavaConfig) hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_max_llava_image_tokens(ctx) image_feature_size = get_max_llava_image_tokens(ctx)
elif is_list_of(image_data, Image.Image):
image_feature_size = [get_max_llava_image_tokens(ctx)
] * len(image_data)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip( return input_processor_for_clip(
...@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
) )
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}") f"Got type: {type(image_embeds)}")
# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)
return LlavaImageEmbeddingInputs( return LlavaImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=image_embeds, data=flatten_bn(image_embeds, concat=True),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
...@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
for img in image_data for img in image_data
] ]
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
......
...@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width=w, input_width=w,
input_height=h)) input_height=h))
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
......
...@@ -110,7 +110,7 @@ def input_processor_for_siglip( ...@@ -110,7 +110,7 @@ def input_processor_for_siglip(
if isinstance(image_data, Image.Image): if isinstance(image_data, Image.Image):
image_feature_size = get_siglip_image_feature_size(hf_config) image_feature_size = get_siglip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor): elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0] num_images, image_feature_size, hidden_size = image_data.shape
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment