Unverified Commit 4ebc0b96 authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Proper input validation for multi-modal encoder-decoder models (#16156)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent dc96fd54
...@@ -56,7 +56,7 @@ def run_florence2(): ...@@ -56,7 +56,7 @@ def run_florence2():
def run_mllama(): def run_mllama():
engine_args = EngineArgs( engine_args = EngineArgs(
model="meta-llama/Llama-3.2-11B-Vision-Instruct", model="meta-llama/Llama-3.2-11B-Vision-Instruct",
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
limit_mm_per_prompt={"image": 1}, limit_mm_per_prompt={"image": 1},
dtype="half", dtype="half",
......
...@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: ...@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
# The configuration below has been confirmed to launch on a single L40 GPU. # The configuration below has been confirmed to launch on a single L40 GPU.
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
) )
......
...@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
# The configuration below has been confirmed to launch on a single L40 GPU. # The configuration below has been confirmed to launch on a single L40 GPU.
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=4096, max_model_len=8192,
max_num_seqs=16, max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
......
...@@ -18,7 +18,8 @@ models = ["llava-hf/llava-1.5-7b-hf"] ...@@ -18,7 +18,8 @@ models = ["llava-hf/llava-1.5-7b-hf"]
def test_context_length_too_short(vllm_runner, image_assets, model): def test_context_length_too_short(vllm_runner, image_assets, model):
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
with pytest.raises(ValueError, match="too long to fit into the model"): with pytest.raises(ValueError,
match="longer than the maximum model length"):
vllm_model = vllm_runner( vllm_model = vllm_runner(
model, model,
max_model_len=128, # LLaVA has a feature size of 576 max_model_len=128, # LLaVA has a feature size of 576
......
...@@ -15,7 +15,7 @@ def v1(run_with_both_engines): ...@@ -15,7 +15,7 @@ def v1(run_with_both_engines):
def test_empty_prompt(): def test_empty_prompt():
llm = LLM(model="openai-community/gpt2", enforce_eager=True) llm = LLM(model="openai-community/gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'): with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
llm.generate([""]) llm.generate([""])
......
...@@ -17,7 +17,7 @@ async def test_empty_prompt(): ...@@ -17,7 +17,7 @@ async def test_empty_prompt():
client = remote_server.get_async_client() client = remote_server.get_async_client()
with pytest.raises(openai.BadRequestError, with pytest.raises(openai.BadRequestError,
match=re.compile('.+Prompt cannot be empty.+')): match="decoder prompt cannot be empty"):
await client.completions.create(model=model_name, await client.completions.create(model=model_name,
prompt="", prompt="",
max_tokens=5, max_tokens=5,
......
...@@ -211,7 +211,7 @@ def _run_test( ...@@ -211,7 +211,7 @@ def _run_test(
# max_model_len should be greater than image_feature_size # max_model_len should be greater than image_feature_size
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=3, max_num_seqs=3,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
...@@ -422,7 +422,7 @@ def test_bnb_regression( ...@@ -422,7 +422,7 @@ def test_bnb_regression(
llm = LLM( llm = LLM(
model=model, model=model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
quantization="bitsandbytes", quantization="bitsandbytes",
) )
...@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt( ...@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt(
llm = LLM( llm = LLM(
model=model, model=model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
...@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens, ...@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
with global_force_attn_backend_context_manager(attn_backend), vllm_runner( with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
max_model_len=4096, max_model_len=8192,
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=1, tensor_parallel_size=1,
limit_mm_per_prompt={"image": limit_mm_per_prompt={"image":
......
...@@ -8,7 +8,7 @@ from contextlib import contextmanager ...@@ -8,7 +8,7 @@ from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional) Iterable, List, Literal, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload from typing import Set, Type, Union, cast, overload
...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import ( ...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType) PromptType, SingletonInputs)
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -40,6 +40,7 @@ from vllm.model_executor.guided_decoding import ( ...@@ -40,6 +40,7 @@ from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor) get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -2029,29 +2030,57 @@ class LLMEngine: ...@@ -2029,29 +2030,57 @@ class LLMEngine:
lora_request: Optional[LoRARequest]): lora_request: Optional[LoRARequest]):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len if encoder_inputs is not None:
# restricts the decoder prompt length self._validate_model_input(encoder_inputs,
if self.model_config.is_multimodal_model: lora_request,
prompt_inputs = decoder_inputs prompt_type="encoder")
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"] self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
if prompt_ids is None or len(prompt_ids) == 0: def _validate_model_input(
raise ValueError("Prompt cannot be empty") self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
if prompt_type == "encoder" and self.tokenizer is not None:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
model_config = self.model_config
if self.model_config.is_multimodal_model: if model_config.is_multimodal_model:
max_prompt_len = self.model_config.max_model_len mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config, tokenizer=tokenizer)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if len(prompt_ids) > max_prompt_len: if mm_processor.pad_dummy_encoder_prompt:
raise ValueError( return # Skip encoder length check for Whisper
f"The prompt (total length {len(prompt_ids)}) is too long "
f"to fit into the model (context length {max_prompt_len}). " prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if self.model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the " "Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image " "number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number " "inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.") "of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can # TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them # check that chunked prefill does not truncate them
......
...@@ -213,8 +213,12 @@ class MultiModalProfiler(Generic[_I]): ...@@ -213,8 +213,12 @@ class MultiModalProfiler(Generic[_I]):
total_len = len(encoder_prompt_token_ids) total_len = len(encoder_prompt_token_ids)
# Encoder-decoder multimodal models only support v0 processor = cast(EncDecMultiModalProcessor, self.processor)
if total_len > seq_len: if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
# NOTE: Whisper allows total_len > seq_len.
elif total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig` # `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once( logger.warning_once(
"The encoder sequence length used for profiling (" "The encoder sequence length used for profiling ("
...@@ -229,11 +233,6 @@ class MultiModalProfiler(Generic[_I]): ...@@ -229,11 +233,6 @@ class MultiModalProfiler(Generic[_I]):
"increase `max_model_len`, reduce `max_num_seqs`, " "increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.") "and/or reduce `mm_counts`.")
processor = cast(EncDecMultiModalProcessor, self.processor)
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyEncoderData(encoder_prompt_token_ids) return DummyEncoderData(encoder_prompt_token_ids)
def get_decoder_dummy_data( def get_decoder_dummy_data(
......
...@@ -2,16 +2,17 @@ ...@@ -2,16 +2,17 @@
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Optional, Union from typing import Literal, Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -287,41 +288,62 @@ class Processor: ...@@ -287,41 +288,62 @@ class Processor:
lora_request: Optional[LoRARequest] = None): lora_request: Optional[LoRARequest] = None):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len if encoder_inputs is not None:
# restricts the decoder prompt length self._validate_model_input(encoder_inputs,
if self.model_config.is_multimodal_model: lora_request,
prompt_inputs = decoder_inputs prompt_type="encoder")
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"] self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
if prompt_ids is None or len(prompt_ids) == 0: def _validate_model_input(
raise ValueError("Prompt cannot be empty") self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
max_input_id = max(prompt_ids) if prompt_type == "encoder":
max_allowed = self.tokenizer.get_lora_tokenizer( model_config = self.model_config
lora_request).max_token_id
if max_input_id > max_allowed:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
if len(prompt_ids) >= self.model_config.max_model_len: if model_config.is_multimodal_model:
raise ValueError( mm_registry = self.input_preprocessor.mm_registry
f"Prompt length of {len(prompt_ids)} is longer than the " mm_processor = mm_registry.create_processor(
f"maximum model length of {self.model_config.max_model_len}.") model_config, tokenizer=tokenizer)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if self.model_config.is_multimodal_model: if mm_processor.pad_dummy_encoder_prompt:
max_prompt_len = self.model_config.max_model_len return # Skip encoder length check for Whisper
if len(prompt_ids) > max_prompt_len: prompt_ids = prompt_inputs["prompt_token_ids"]
raise ValueError(
f"The prompt (total length {len(prompt_ids)}) is too long " if not prompt_ids:
f"to fit into the model (context length {max_prompt_len}). " raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if self.model_config.is_multimodal_model:
suggestion = (
"Make sure that `max_model_len` is no smaller than the " "Make sure that `max_model_len` is no smaller than the "
"number of text tokens plus multimodal tokens. For image " "number of text tokens plus multimodal tokens. For image "
"inputs, the number of image tokens depends on the number " "inputs, the number of image tokens depends on the number "
"of images, and possibly their aspect ratios as well.") "of images, and possibly their aspect ratios as well.")
else:
suggestion = (
"Make sure that `max_model_len` is no smaller than the "
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")
# TODO: Find out how many placeholder tokens are there so we can # TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them # check that chunked prefill does not truncate them
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment