Unverified Commit 9831aec4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Dynamic image size support for VLMs (#5276)


Signed-off-by: default avatarXiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: default avatarXiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: default avatarywang96 <ywang@roblox.com>
Co-authored-by: default avatarxwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: default avatarRoger Wang <136131678+ywang96@users.noreply.github.com>
parent 482045ee
......@@ -8,7 +8,7 @@ Input Processing
vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.
Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
Currently, this mechanism is only utilized in :ref:`multi-modal models <multi_modality>` for preprocessing multi-modal input
data in addition to input prompt, but it can be extended to text-only language models when needed.
Guides
......
.. _adding_a_new_multimodal_model:
Adding a New Multimodal Model
=============================
This document provides a high-level guide on integrating a :ref:`multi-modal model <multi_modality>` into vLLM.
.. note::
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
.. tip::
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
We will be happy to help you out!
1. Set up the base vLLM model
-----------------------------
As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model in vLLM, but note the following:
- You should additionally implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
.. code-block:: diff
+ from vllm.model_executor.models.interfaces import SupportsVision
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsVision):
.. note::
The model class does not have to be named :code:`*ForCausalLM`.
Check out `the HuggingFace Transformers documentation <https://huggingface.co/docs/transformers/model_doc/auto#multimodal>`__ for some examples.
- While implementing the :meth:`~torch.nn.Module.forward` method, reserve a keyword parameter
for each input tensor that corresponds to a multi-modal input, as shown in the following example:
.. code-block:: diff
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+ pixel_values: torch.Tensor,
) -> SamplerOutput:
2. Register input mappers
-------------------------
For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.
.. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+ @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
.. seealso::
:ref:`input_processing_pipeline`
3. (Optional) Register dummy data
---------------------------------
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
Here are some examples:
- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
.. seealso::
:ref:`input_processing_pipeline`
4. (Optional) Register input processor
--------------------------------------
Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
This is often due to the fact that unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside model's :meth:`~torch.nn.Module.forward` call.
You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:
- Insert static number of image tokens: `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Insert dynamic number of image tokens: `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
.. seealso::
:ref:`input_processing_pipeline`
.. _multi_modality:
Multi-Modality
==============
......@@ -8,12 +10,18 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data``
which allows you to pass in multi-modal input alongside text and token prompts.
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, please follow :ref:`the guide for adding a new multimodal model. <adding_a_new_multimodal_model>`.
# TODO: Add more instructions on how to do that once embeddings is in.
Guides
++++++
.. toctree::
:maxdepth: 1
adding_multimodal_model
Module Contents
+++++++++++++++
......@@ -35,6 +43,10 @@ Base Classes
:members:
:show-inheritance:
.. autoclass:: vllm.multimodal.MultiModalInputs
:members:
:show-inheritance:
.. autoclass:: vllm.multimodal.MultiModalPlugin
:members:
:show-inheritance:
......
......@@ -23,7 +23,6 @@ The following :ref:`engine arguments <engine_args>` are specific to VLMs:
Currently, the support for vision language models on vLLM has the following limitations:
* Only single image input is supported per text prompt.
* Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means our LLaVA-NeXT output may not exactly match the huggingface implementation.
We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
......@@ -42,12 +41,17 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
)
.. important::
Currently, you have to specify ``image_feature_size`` to support memory profiling.
To avoid OOM during runtime, you should set this to the maximum value supported by the model.
The calculation of feature size is specific to the model. For more details, please refer to
the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
.. note::
......@@ -57,8 +61,8 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
.. code-block:: python
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
# Load the image using PIL.Image
image = ...
......@@ -74,8 +78,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptS
A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
.. important::
We will remove the need to format image tokens in a future release. Afterwards, the input text will follow the same format as that for the original HuggingFace model.
Online OpenAI Vision API Compatible Inference
----------------------------------------------
......@@ -103,6 +105,11 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
--chat-template template_llava.jinja
.. important::
Currently, you have to specify ``image_feature_size`` to support memory profiling.
To avoid OOM during runtime, you should set this to the maximum value supported by the model.
The calculation of feature size is specific to the model. For more details, please refer to
the function :code:`get_<model_name>_image_feature_size` inside the corresponding model file.
We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration.
To consume the server, you can use the OpenAI client like in the example below:
......@@ -121,6 +128,8 @@ To consume the server, you can use the OpenAI client like in the example below:
messages=[{
"role": "user",
"content": [
# NOTE: The prompt formatting with the image token `<image>` is not needed
# since the prompt will be processed automatically by the API server.
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
......@@ -144,5 +153,4 @@ A full code example can be found in `examples/openai_vision_api_client.py <https
export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
.. note::
The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be
processed automatically by the server.
There is no need to format the prompt in the API request since it will be handled by the server.
......@@ -17,8 +17,7 @@ def run_llava():
image_feature_size=576,
)
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
image = Image.open("images/stop_sign.jpg")
......
......@@ -5,22 +5,17 @@ from PIL import Image
from vllm import LLM, SamplingParams
# Dynamic image input is currently not supported and therefore
# a fixed image input shape and its corresponding feature size is required.
# See https://github.com/vllm-project/vllm/pull/4199 for the complete
# configuration matrix.
def run_llava_next():
llm = LLM(
model="llava-hf/llava-v1.6-mistral-7b-hf",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=1176,
# Use the maximum possible value for memory profiling
image_feature_size=2928,
)
prompt = "[INST] " + "<image>" * 1176 + (
"\nWhat is shown in this image? [/INST]")
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8,
......
......@@ -5,6 +5,9 @@ from PIL import Image
from vllm import LLM, SamplingParams
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def run_phi3v():
model_path = "microsoft/Phi-3-vision-128k-instruct"
......@@ -18,7 +21,8 @@ def run_phi3v():
trust_remote_code=True,
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
# Use the maximum possible value for memory profiling
image_feature_size=2653,
max_num_seqs=5,
)
......@@ -26,8 +30,6 @@ def run_phi3v():
# single-image prompt
prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501
prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")
sampling_params = SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
......
import contextlib
import gc
import os
import sys
from collections import UserList
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,
TypedDict, TypeVar)
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
TypeVar)
import pytest
import torch
......@@ -22,13 +23,10 @@ from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.multimodal.utils import fetch_image
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
if TYPE_CHECKING:
# it will call torch.cuda.device_count()
from vllm.multimodal import MultiModalDataDict
logger = init_logger(__name__)
_TEST_DIR = os.path.dirname(__file__)
......@@ -47,30 +45,42 @@ def _read_prompts(filename: str) -> List[str]:
@dataclass(frozen=True)
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]
name: Literal["stop_sign", "cherry_blossom", "boardwalk"]
@cached_property
def pil_image(self) -> Image.Image:
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
def for_hf(self) -> Image.Image:
return self.pil_image
if self.name == "boardwalk":
return fetch_image(
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)
def for_vllm(self) -> Dict[str, Any]:
return {"image": self.pil_image}
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
class _ImageAssetPrompts(TypedDict):
stop_sign: str
cherry_blossom: str
boardwalk: str
if sys.version_info < (3, 9):
# UserList cannot be subscripted
class _ImageAssetsBase(UserList):
pass
else:
class _ImageAssetsBase(UserList[ImageAsset]):
pass
class _ImageAssets(UserList):
class _ImageAssets(_ImageAssetsBase):
def __init__(self) -> None:
super().__init__(
[ImageAsset("stop_sign"),
ImageAsset("cherry_blossom")])
super().__init__([
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
ImageAsset("boardwalk")
])
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
"""
......@@ -79,7 +89,10 @@ class _ImageAssets(UserList):
The order of the returned prompts matches the order of the
assets when iterating through this object.
"""
return [prompts["stop_sign"], prompts["cherry_blossom"]]
return [
prompts["stop_sign"], prompts["cherry_blossom"],
prompts["boardwalk"]
]
IMAGE_ASSETS = _ImageAssets()
......@@ -220,7 +233,7 @@ class HfRunner:
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
**kwargs,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
assert len(prompts) == len(images)
......@@ -255,7 +268,7 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
**kwargs,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
......@@ -291,19 +304,30 @@ class HfRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
all_logprobs: List[List[torch.Tensor]] = []
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
output = self.model.generate(
self.wrap_device(input_ids),
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs = []
seq_logprobs: List[torch.Tensor] = []
for hidden_states in output.hidden_states:
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
......@@ -323,20 +347,32 @@ class HfRunner:
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
for i, prompt in enumerate(prompts):
processor_kwargs: Dict[str, Any] = {
"text": prompt,
"return_tensors": "pt",
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids
output = self.model.generate(
self.wrap_device(input_ids),
**self.wrap_device(inputs),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
seq_logprobs: List[torch.Tensor] = []
......@@ -431,7 +467,7 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List["MultiModalDataDict"]] = None,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
......@@ -439,7 +475,7 @@ class VllmRunner:
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = image
inputs[i]["multi_modal_data"] = {"image": image}
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
......@@ -462,10 +498,19 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
req_outputs = self.model.generate(prompts,
if images is not None:
assert len(prompts) == len(images)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
for req_output in req_outputs:
......@@ -480,7 +525,7 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[List["MultiModalDataDict"]] = None,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
......@@ -492,11 +537,14 @@ class VllmRunner:
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs)
outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
......
......@@ -30,9 +30,10 @@ else:
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets,
tensor_parallel_size: int, dtype: str,
max_tokens: int) -> None:
tensor_parallel_size: int, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(
f"Need at least {tensor_parallel_size} GPUs to run the test.")
......@@ -44,8 +45,10 @@ def test_models(hf_runner, vllm_runner, image_assets,
vllm_runner,
image_assets,
model_and_config=model_and_vl_config[0],
size_factors=[1.0],
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)
......@@ -4,18 +4,21 @@ import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom":
"<image>\nUSER: What is the season?\nASSISTANT:",
"USER: <image>\nWhat is the season?\nASSISTANT:",
"boardwalk":
"USER: <image>\nWhat's in this image?\nASSISTANT:",
})
......@@ -37,27 +40,34 @@ model_and_vl_config = [
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
output_ids, output_str = vllm_output
output_ids, output_str, out_logprobs = vllm_output
image_token_id = vlm_config.image_token_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, "")
assert hf_output_str[0] == " "
hf_output_str = hf_output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str
return hf_output_ids, hf_output_str, out_logprobs
def run_test(
......@@ -66,8 +76,10 @@ def run_test(
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
......@@ -81,61 +93,85 @@ def run_test(
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model_id,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm() for asset in image_assets]
vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
for p in HF_IMAGE_PROMPTS
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images)
check_outputs_equal(
hf_outputs,
[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
from typing import List, Tuple
import re
from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS
from .utils import check_outputs_equal
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
......@@ -15,21 +18,20 @@ _PREFACE = (
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions.")
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
f"{_PREFACE} <image>\nUSER: What's the content of the image?\nASSISTANT:",
f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
"cherry_blossom":
f"{_PREFACE} <image>\nUSER: What is the season?\nASSISTANT:",
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
"boardwalk":
f"{_PREFACE} USER: <image>\nWhat's in this image? ASSISTANT:",
})
def iter_llava_next_configs(model_name: str):
# Need to use the max possible feature size for profile_run
image_hw_to_feature_size = {
(336, 336): 1176,
(672, 672): 2928,
(1344, 336): 1944,
(336, 1344): 1890,
(336, 336): 2928,
}
for (h, w), f in image_hw_to_feature_size.items():
......@@ -47,37 +49,55 @@ model_and_vl_config = [
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
output_ids, output_str = vllm_output
output_ids, output_str, out_logprobs = vllm_output
image_token_id = vlm_config.image_token_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, " ")
return hf_output_ids, hf_output_str
hf_output_str = re.sub(fr"({image_token_str})+", "", output_str)
assert hf_output_str[0] == " "
hf_output_str = hf_output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
......@@ -88,37 +108,46 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm() for asset in image_assets]
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# max_model_len should be greater than image_feature_size
with vllm_runner(model_id,
dtype=dtype,
max_model_len=4096,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images)
vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
for p in HF_IMAGE_PROMPTS
]
with vllm_runner(
model_id,
dtype=dtype,
# should be greater than image_feature_size
max_model_len=4096,
enforce_eager=True,
**vlm_config.as_cli_args_dict(),
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
check_outputs_equal(
hf_outputs,
[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
import re
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_outputs_equal
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
"cherry_blossom":
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n", # noqa: E501
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
"boardwalk":
"<|user|>\n<|image_1|>\nWhat's in this image?<|end|>\n<|assistant|>\n",
})
def iter_phi3v_configs(model_name: str):
# Need to use the max possible feature size for profile_run
image_hw_to_feature_size = {
(1008, 1344): 1921,
(2016, 2688): 1933,
(1008, 1344): 2653,
}
for (h, w), f in image_hw_to_feature_size.items():
......@@ -39,29 +43,29 @@ model_and_vl_config = [
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
output_ids, output_str = vllm_output
image_token_id = vlm_config.image_token_id
output_ids, output_str, out_logprobs = vllm_output
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)
hf_output_ids = [
token_id if token_id != image_token_id else 0
for idx, token_id in enumerate(output_ids)
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, "") \
.replace("<s>", " ").replace("<|user|>", "") \
output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
assert output_str_without_image[0] == " "
output_str_without_image = output_str_without_image[1:]
hf_output_str = output_str_without_image.replace("<|user|>", "") \
.replace("<|end|>\n<|assistant|>", " ")
return hf_output_ids, hf_output_str
tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_output_ids = tokenizer.encode(output_str_without_image)
assert hf_output_ids[0] == 1
hf_output_ids = hf_output_ids[1:]
return hf_output_ids, hf_output_str, out_logprobs
target_dtype = "half"
......@@ -75,8 +79,10 @@ def run_test(
image_assets: _ImageAssets,
model_and_config: Tuple[str, VisionLanguageConfig],
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
......@@ -90,73 +96,91 @@ def run_test(
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model_id,
max_model_len=2048,
max_model_len=4096,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=True,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
**vlm_config.as_cli_args_dict()) as vllm_model:
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm() for asset in image_assets]
vllm_image_prompts = [
p.replace("<|image_1|>",
"<|image|>" * vlm_config.image_feature_size + "<s>")
for p in HF_IMAGE_PROMPTS
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images)
for prompts, vllm_images in inputs_per_image
]
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
with hf_runner(model_id, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(
HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images,
eos_token_id=hf_model.processor.tokenizer.eos_token_id)
check_outputs_equal(
hf_outputs,
[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
eos_token_id = hf_model.processor.tokenizer.eos_token_id
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image
]
# Since we use _attn_implementation="eager" for hf_runner, here is
# numeric difference for longer context and test can't pass
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, vlm_config, model_id)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
# Since we use _attn_implementation="eager" for hf_runner, there is more
# significant numerical difference. The basic `logprobs=5` fails to pass.
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
dtype: str, max_tokens: int) -> None:
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model_and_config,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
from typing import Dict, List, Tuple
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs
TokensText = Tuple[List[int], str]
def check_outputs_equal(outputs_0_lst: List[TokensText],
outputs_1_lst: List[TokensText], name_0: str,
name_1: str):
def check_outputs_equal(
*,
outputs_0_lst: Sequence[TokensText],
outputs_1_lst: Sequence[TokensText],
name_0: str,
name_1: str,
):
"""
Compare the two sequences generated by different models,
which should be equal.
......@@ -18,20 +25,28 @@ def check_outputs_equal(outputs_0_lst: List[TokensText],
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
# The text and token outputs should exactly match
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_str_0 == output_str_1, fail_msg
assert output_ids_0 == output_ids_1, fail_msg
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
outputs_1_lst: List[TokensTextLogprobs], name_0: str,
name_1: str):
def check_logprobs_close(
*,
outputs_0_lst: Sequence[TokensTextLogprobs],
outputs_1_lst: Sequence[TokensTextLogprobs],
name_0: str,
name_1: str,
warn_on_mismatch: bool = True,
):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
......@@ -45,21 +60,52 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
if logprobs_1 is None:
logprobs_1 = [None] * len(output_ids_1)
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then
if output_id_0 != output_id_1:
logprobs_elem_0 = logprobs_0[idx]
logprobs_elem_1 = logprobs_1[idx]
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_1 in logprobs_0[idx], (
fail_msg = (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
assert logprobs_elem_0 is not None, fail_msg
assert logprobs_elem_1 is not None, fail_msg
assert output_id_0 in logprobs_elem_1, fail_msg
assert output_id_1 in logprobs_elem_0, fail_msg
if warn_on_mismatch:
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
# Break out since sequences will now diverge.
break
else:
if output_str_0 != output_str_1 and warn_on_mismatch:
# The token outputs exactly match,
# so the text outputs should exactly match as well
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
......@@ -4,12 +4,12 @@ from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from ..conftest import _STR_DTYPE_TO_TORCH_DTYPE
from vllm.multimodal.utils import rescale_image_size
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_clip_image_processor(image_assets, dtype):
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, dtype, size_factor):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
......@@ -26,13 +26,15 @@ def test_clip_image_processor(image_assets, dtype):
)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
hf_result = hf_processor.preprocess(
asset.pil_image,
image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
{"image": asset.pil_image},
{"image": image},
)
assert hf_result.keys() == vllm_result.keys()
......@@ -44,12 +46,10 @@ def test_clip_image_processor(image_assets, dtype):
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_llava_next_image_processor(image_assets, dtype):
MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, dtype, size_factor):
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, LlavaNextImageProcessor)
......@@ -65,13 +65,15 @@ def test_llava_next_image_processor(image_assets, dtype):
)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
hf_result = hf_processor.preprocess(
asset.pil_image,
image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
{"image": asset.pil_image},
{"image": image},
)
assert hf_result.keys() == vllm_result.keys()
......@@ -81,36 +83,3 @@ def test_llava_next_image_processor(image_assets, dtype):
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.xfail(
reason="Example image pixels were not processed using HuggingFace")
@pytest.mark.parametrize("dtype", ["float"])
def test_image_pixel_types(image_assets, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype=dtype,
revision=None,
)
for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
{"image": asset.pil_image},
)
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
{"image": asset.pil_image},
)
assert image_result.keys() == tensor_result.keys()
for key, image_arr in image_result.items():
tensor_arr: np.ndarray = tensor_result[key].numpy()
assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"
assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
......@@ -5,10 +5,9 @@ from typing import Dict, Tuple
import numpy as np
import pytest
import pytest_asyncio
from PIL import Image
from vllm.multimodal.utils import ImageFetchAiohttp
from vllm.multimodal.utils import ImageFetchAiohttp, fetch_image
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
......@@ -19,12 +18,9 @@ TEST_IMAGE_URLS = [
]
@pytest_asyncio.fixture(scope="session")
async def url_images() -> Dict[str, Image.Image]:
return {
image_url: await ImageFetchAiohttp.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
@pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
def get_supported_suffixes() -> Tuple[str, ...]:
......@@ -41,7 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
@pytest.mark.asyncio
@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await ImageFetchAiohttp.fetch_image(image_url)
assert _image_equals(image_sync, image_async)
@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
......@@ -68,8 +72,11 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
data_image = await ImageFetchAiohttp.fetch_image(data_url)
data_image_sync = fetch_image(data_url)
if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image)
assert _image_equals(url_image, data_image_sync)
else:
pass # Lossy format; only check that image can be opened
data_image_async = await ImageFetchAiohttp.fetch_image(data_url)
assert _image_equals(data_image_sync, data_image_async)
......@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -1303,16 +1303,6 @@ class VisionLanguageConfig:
image_input_shape: tuple
image_feature_size: int
#TODO(ywang96): make this a cached property once we refactor the
# VisionLanguageConfig class.
def get_image_token_text(
self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
"""Get the image token placeholder text to be inserted into the
text prompt and the string representation of the image token id.
"""
image_token_str = tokenizer.decode(self.image_token_id)
return image_token_str * self.image_feature_size, image_token_str
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
......
import codecs
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
List, Optional)
from typing import Sequence as GenericSequence
......@@ -10,7 +11,7 @@ from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb,
......@@ -27,8 +28,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_image,
get_full_image_text_prompt)
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
......@@ -97,6 +97,36 @@ class OpenAIServingChat(OpenAIServing):
logger.warning(
"No chat template provided. Chat API will not work.")
@cached_property
def image_token_str(self) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = self.model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
"paligemma"):
# These models do not use image tokens in the prompt
return None
# The default behaviour assumes that the image token is
# available to the tokenizer.
# (Suitable for LLaVA, Idefics2, DeepSeek-VL)
vlm_config = self.model_config.multimodal_config
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
image_token_id = vlm_config.image_token_id
if vlm_config.image_token_id is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model does not specify an image token.")
return self.tokenizer.decode(image_token_id)
def _parse_chat_message_content_parts(
self,
role: str,
......@@ -105,21 +135,26 @@ class OpenAIServingChat(OpenAIServing):
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
vlm_config: Optional[VisionLanguageConfig] = getattr(
self.engine.engine, "vision_language_config", None)
model_config = getattr(self.engine.engine, "model_config", None)
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
if vlm_config is None:
raise ValueError(
"'image_url' input is not supported as the loaded "
"model is not multimodal.")
assert self.tokenizer is not None
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported."
)
image_token_str = self.image_token_str
if image_token_str is not None:
if any(image_token_str in text for text in texts):
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
else:
texts.append(image_token_str)
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
......@@ -128,43 +163,13 @@ class OpenAIServingChat(OpenAIServing):
"'image_url.detail' is currently not supported and "
"will be ignored.")
mm_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(mm_future)
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if vlm_config is not None and len(mm_futures):
assert len(
mm_futures
) == 1, "Multiple 'image_url' input is currently not supported."
(image_token_prompt,
image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
# NOTE: If image token string (e.g, <image>) is already present
# in the text prompt, we assume it follows the same format required
# by the engine.
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
messages = [
ConversationMessage(role=role, content=text_prompt)
]
else:
full_prompt = get_full_image_text_prompt(
image_prompt=image_token_prompt,
text_prompt=text_prompt,
config=model_config)
messages = [
ConversationMessage(role=role, content=full_prompt)
]
else:
messages = [ConversationMessage(role=role, content=text_prompt)]
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
......@@ -267,7 +272,7 @@ class OpenAIServingChat(OpenAIServing):
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if mm_data is not None:
if mm_data:
inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
......
......@@ -36,6 +36,7 @@ class OpenAIServing:
super().__init__()
self.engine = engine
self.model_config = model_config
self.max_model_len = model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
......
......@@ -140,7 +140,8 @@ class InputRegistry:
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]
See also:
:ref:`adding_a_new_multimodal_model`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
......
......@@ -8,10 +8,14 @@ from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from vllm.config import ModelConfig
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
......@@ -64,6 +68,39 @@ def dummy_image_for_clip(
return {"image": image}
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment