Commit cb234955 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Clean up input processing (#17582)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 3a500cd0
...@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download ...@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets from ....conftest import ImageTestAssets
...@@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets ...@@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"] DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
@torch.inference_mode()
def run_intern_vit_test( def run_intern_vit_test(
image_assets: ImageTestAssets, image_assets: ImageTestAssets,
model_id: str, model_id: str,
...@@ -21,11 +23,12 @@ def run_intern_vit_test( ...@@ -21,11 +23,12 @@ def run_intern_vit_test(
dtype: str, dtype: str,
): ):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN) model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
img_processor = CLIPImageProcessor.from_pretrained(model) img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets] images = [asset.pil_image for asset in image_assets]
pixel_values = [ pixel_values = [
img_processor(images, return_tensors='pt').pixel_values.to(dtype) img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
for images in images for images in images
] ]
...@@ -34,7 +37,7 @@ def run_intern_vit_test( ...@@ -34,7 +37,7 @@ def run_intern_vit_test(
config.norm_type = "rms_norm" config.norm_type = "rms_norm"
hf_model = AutoModel.from_pretrained(model, hf_model = AutoModel.from_pretrained(model,
torch_dtype=dtype, torch_dtype=torch_dtype,
trust_remote_code=True).to("cuda") trust_remote_code=True).to("cuda")
hf_outputs_per_image = [ hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state hf_model(pixel_value.to("cuda")).last_hidden_state
...@@ -48,7 +51,7 @@ def run_intern_vit_test( ...@@ -48,7 +51,7 @@ def run_intern_vit_test(
del hf_model del hf_model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
vllm_model = vllm_model.to("cuda", dtype) vllm_model = vllm_model.to("cuda", torch_dtype)
vllm_outputs_per_image = [ vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda")) vllm_model(pixel_values=pixel_value.to("cuda"))
for pixel_value in pixel_values for pixel_value in pixel_values
...@@ -66,9 +69,8 @@ def run_intern_vit_test( ...@@ -66,9 +69,8 @@ def run_intern_vit_test(
"OpenGVLab/InternViT-300M-448px", "OpenGVLab/InternViT-300M-448px",
"OpenGVLab/InternViT-6B-448px-V1-5", "OpenGVLab/InternViT-6B-448px-V1-5",
]) ])
@pytest.mark.parametrize("dtype", [torch.half]) @pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode() def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
def test_models(image_assets, model_id, dtype: str) -> None:
run_intern_vit_test( run_intern_vit_test(
image_assets, image_assets,
model_id, model_id,
......
...@@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt["prompt_token_ids"] = [0 prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2] ] * prompt["prompt_embeds"].shape[-2]
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
processed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
......
...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import ( ...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors from vllm.logits_process import get_bad_words_logits_processors
...@@ -759,11 +759,6 @@ class LLMEngine: ...@@ -759,11 +759,6 @@ class LLMEngine:
seq_len = prompt["prompt_embeds"].shape[0] seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len prompt["prompt_token_ids"] = [0] * seq_len
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
processed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
...@@ -782,27 +777,6 @@ class LLMEngine: ...@@ -782,27 +777,6 @@ class LLMEngine:
priority=priority, priority=priority,
) )
def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
def _create_sequence_group_with_sampling( def _create_sequence_group_with_sampling(
self, self,
request_id: str, request_id: str,
...@@ -2049,6 +2023,12 @@ class LLMEngine: ...@@ -2049,6 +2023,12 @@ class LLMEngine:
else: else:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")
if tokenizer is not None:
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len: if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
......
...@@ -83,6 +83,9 @@ class EngineClient(ABC): ...@@ -83,6 +83,9 @@ class EngineClient(ABC):
else: else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
prompt_token_ids = processed_inputs["prompt_token_ids"] prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt") prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data") multi_modal_data = processed_inputs.get("multi_modal_data")
......
...@@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity, ...@@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens) _validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
...@@ -567,10 +567,12 @@ class LLM: ...@@ -567,10 +567,12 @@ class LLM:
mm_kwargs["mm_processor_kwargs"] = prompt[ mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"] "mm_processor_kwargs"]
if is_token_prompt(prompt): if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"] prompt_tokens = prompt["prompt_token_ids"]
else: else:
prompt_tokens = tokenizer.encode(prompt["prompt"]) prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append( instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs)) BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
......
...@@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict): ...@@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
""" """
...@@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict): ...@@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor prompt_embeds: torch.Tensor
"""The embeddings of the prompt.""" """The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs: def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values.""" """Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs( inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
type="embeds",
prompt_embeds=prompt_embeds, if cache_salt is not None:
) inputs["cache_salt"] = cache_salt
return inputs return inputs
......
...@@ -6,9 +6,9 @@ from typing_extensions import TypeIs ...@@ -6,9 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt, from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
ProcessorInputs, PromptType, SingletonInputs, PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
SingletonPrompt, TextPrompt, TokensPrompt) TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict): ...@@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
content: EmbedsPrompt content: EmbedsPrompt
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, ParsedEmbedsPrompt]
@overload @overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
... ...
...@@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: ...@@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
... ...
def parse_singleton_prompt( def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
if isinstance(prompt, str): if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt) return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict): elif isinstance(prompt, dict):
...@@ -131,23 +132,11 @@ def parse_singleton_prompt( ...@@ -131,23 +132,11 @@ def parse_singleton_prompt(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt
def is_explicit_encoder_decoder_prompt( def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"
def split_enc_dec_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs, inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]: ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
......
This diff is collapsed.
...@@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
placeholders = mm_placeholders.get(modality, []) placeholders = mm_placeholders.get(modality, [])
if len(placeholders) != item_count: if len(placeholders) != item_count:
# NOTE: If you are a model developer, this can also arise from
# an inconsistency between `_call_hf_processor` and
# `_get_mm_fields_config` implementations
raise RuntimeError( raise RuntimeError(
f"Expected there to be {item_count} prompt updates " f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but " f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt updates! " f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for " "This is likely because you forgot to include input "
"multi-modal inputs, or there is a problem with your " "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
"implementation of merged multi-modal processor for this " "in the prompt. If the model has a chat template, make "
"model (usually arising from an inconsistency between " "sure you have applied it before calling `LLM.generate`.")
"`_call_hf_processor` and `_get_prompt_updates`).")
def _maybe_apply_prompt_updates( def _maybe_apply_prompt_updates(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment