Unverified Commit bc55d130 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[VLM] Implement merged multimodal processor for Mllama (#11427)

parent d88c8666
......@@ -7,11 +7,11 @@ import torch
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
MllamaForConditionalGeneration)
from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -21,6 +21,7 @@ from ....utils import large_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 3
MLLAMA_IMAGE_TOKEN_ID = 128256
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
......@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
)
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
image_assets: _ImageAssets,
model: str,
dtype: str,
max_tokens: int,
):
stop_sign = image_assets[0].pil_image
# yapf: disable
prompts = [
# explicit prompt
{
"encoder_prompt": {
"prompt": "<|image|>",
"multi_modal_data": {"image": stop_sign},
},
"decoder_prompt": {
"prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374], # noqa: E501
}
},
{
"encoder_prompt": "Not <|image|>",
"decoder_prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501
},
# implicit prompt
{
"prompt": "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501
"multi_modal_data": {"image": stop_sign},
},
{
"prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501
},
]
# yapf: enable
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_num_seqs=2,
tensor_parallel_size=1,
enforce_eager=True,
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=max_tokens,
)
outputs = llm.generate(prompts, sampling_params)
n_prompts = len(prompts)
explicit_outputs = outputs[:n_prompts // 2]
implicit_outputs = outputs[n_prompts // 2:]
for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
assert exp_output.outputs[0].text == imp_output.outputs[0].text
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
......@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
images=images)
class DummyModel:
image_token_id = MLLAMA_IMAGE_TOKEN_ID
@pytest.mark.core_model
@pytest.mark.parametrize(
"input_indices_and_output",
......@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
use_cuda_graph=False,
)
dummy: dict[str, str] = {}
dummy = DummyModel()
cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
.get_cross_attention_mask(dummy,
......@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
use_cuda_graph=False,
)
dummy: dict[str, str] = {}
dummy = DummyModel()
full_text_row_masked_out_mask = MllamaForConditionalGeneration\
.get_full_text_row_masked_out_mask(dummy,
......
......@@ -85,6 +85,14 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama":
# For Mllama, tokenizer will always add bos_token at the beginning of
# prompt by default, causing hf_processor outputs incorrect token ids.
# So we need use `add_special_tokens=False` here to leave bos_token
# to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches):
mm_data = {
k:
......@@ -122,7 +130,7 @@ def _test_processing_correctness(
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
......@@ -131,7 +139,7 @@ def _test_processing_correctness(
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
......@@ -155,6 +163,7 @@ def _test_processing_correctness(
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
......
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Tuple, Union, cast
from typing_extensions import assert_never
......@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
......@@ -495,6 +496,51 @@ class InputPreprocessor:
decoder=decoder_inputs,
)
def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> Tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
# Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
......@@ -539,7 +585,6 @@ class InputPreprocessor:
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
else:
......@@ -547,11 +592,26 @@ class InputPreprocessor:
decoder_input,
request_id=request_id,
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = self._prompt_to_llm_inputs(
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
......@@ -583,11 +643,27 @@ class InputPreprocessor:
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = await self._prompt_to_llm_inputs_async(
inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
......
......@@ -350,7 +350,8 @@ class InputRegistry:
)
processor = mm_registry.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len)
dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
......
This diff is collapsed.
......@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""
class MultiModalEncDecInputs(MultiModalInputs):
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""
encoder_prompt: str
"""The processed encoder prompt text."""
encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt."""
encoder_token_type_ids: NotRequired[list[int]]
"""The token type IDs of the encoder prompt."""
......@@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
if TYPE_CHECKING:
......@@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
@abstractmethod
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
"""Create input prompt for the encoder."""
raise NotImplementedError
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
encoder_inputs = super().apply(
encoder_prompt,
mm_data,
hf_processor_mm_kwargs,
)
# We assumed the decoder prompt text is copied from
# the original encoder prompt without extra process
tokenizer = self.info.get_tokenizer()
if isinstance(prompt, str):
decoder_prompt = prompt
decoder_prompt_ids = encode_tokens(tokenizer,
prompt,
add_special_tokens=False)
else:
decoder_prompt = decode_tokens(tokenizer, prompt)
decoder_prompt_ids = prompt
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs.update({
"prompt": decoder_prompt,
"prompt_token_ids": decoder_prompt_ids
})
return mm_inputs
......@@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
def get_dummy_data(
self,
seq_len: int,
is_encoder_data: bool = False,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
......@@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment