Unverified commit 5a94a198, authored by Zeyu Zhang, committed by GitHub
Browse files

[Bugfix] Normalize malformed dict prompts that carry token IDs in `prompt` (#40339)


Signed-off-by: Alchuang22-dev <2584829494@qq.com>
parent f95c11a8
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.renderers.inputs.preprocess import prompt_to_seq import pytest
from vllm.renderers.inputs.preprocess import (
parse_dec_only_prompt,
parse_enc_dec_prompt,
prompt_to_seq,
)
def test_empty_input(): def test_empty_input():
...@@ -39,3 +45,23 @@ def test_dict_input(): ...@@ -39,3 +45,23 @@ def test_dict_input():
{"prompt": "foo"}, {"prompt": "foo"},
{"prompt_token_ids": [1, 2]}, {"prompt_token_ids": [1, 2]},
] ]
def test_parse_dec_only_prompt_rejects_non_string_prompt_field():
    """A dict prompt whose ``prompt`` value is a list of ints raises TypeError.

    The extra ``cache_salt`` key shows that the rejection still happens when
    other fields are present alongside the malformed ``prompt``.
    """
    malformed = {"prompt": [1, 2, 3], "cache_salt": "abc"}

    with pytest.raises(TypeError, match="Prompt text should be a string"):
        parse_dec_only_prompt(malformed)
def test_parse_dec_only_prompt_rejects_non_string_prompt_list():
    """A ``prompt`` value that is a mixed int/str list raises TypeError."""
    malformed = {"prompt": [1, "x"]}

    with pytest.raises(TypeError, match="Prompt text should be a string"):
        parse_dec_only_prompt(malformed)
def test_parse_enc_dec_prompt_rejects_nested_non_string_prompt_field():
    """Nested encoder/decoder dicts with list-valued ``prompt`` raise TypeError.

    Both the encoder and the decoder sub-prompts are malformed; the test only
    requires that parsing fails with the expected message.
    """
    malformed = {
        "encoder_prompt": {"prompt": [1, 2, 3]},
        "decoder_prompt": {"prompt": [4, 5]},
    }

    with pytest.raises(TypeError, match="Prompt text should be a string"):
        parse_enc_dec_prompt(malformed)
...@@ -435,6 +435,11 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -435,6 +435,11 @@ class BaseRenderer(ABC, Generic[_T]):
params: TokenizeParams, params: TokenizeParams,
) -> SingletonTokPrompt: ) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
if not isinstance(prompt.get("prompt"), str):
raise TypeError(
"Expected prompt['prompt'] to be a string before tokenization; "
"use 'prompt_token_ids' for token ID inputs"
)
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = self._tokenize_prompt(prompt, params) prompt = self._tokenize_prompt(prompt, params)
...@@ -466,6 +471,11 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -466,6 +471,11 @@ class BaseRenderer(ABC, Generic[_T]):
params: TokenizeParams, params: TokenizeParams,
) -> SingletonTokPrompt: ) -> SingletonTokPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
if not isinstance(prompt.get("prompt"), str):
raise TypeError(
"Expected prompt['prompt'] to be a string before tokenization; "
"use 'prompt_token_ids' for token ID inputs"
)
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
prompt = await self._tokenize_prompt_async(prompt, params) prompt = await self._tokenize_prompt_async(prompt, params)
......
...@@ -4,7 +4,7 @@ Schemas and utilities for preprocessing inputs. ...@@ -4,7 +4,7 @@ Schemas and utilities for preprocessing inputs.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import ( from vllm.inputs import (
...@@ -116,6 +116,19 @@ that has been standardized into a dictionary. ...@@ -116,6 +116,19 @@ that has been standardized into a dictionary.
""" """
def _validate_prompt_dict(prompt: Mapping[str, object]) -> None:
    """Reject malformed dict prompts before renderer tokenization.

    Only text-style prompts are checked: a mapping that has a ``"prompt"``
    key and neither a ``"prompt_token_ids"`` nor a ``"prompt_embeds"`` key.
    Every other mapping is accepted without inspection.

    Args:
        prompt: The dict prompt to check.

    Raises:
        TypeError: If the ``"prompt"`` value of a text-style prompt is not
            a ``str``. The message keeps the prefix
            ``"Prompt text should be a string"`` and additionally names the
            received type and the key to use for token ID inputs.
    """
    if "prompt" not in prompt:
        return
    if "prompt_token_ids" in prompt or "prompt_embeds" in prompt:
        # Token/embedding prompts are not validated by this helper.
        return

    text = prompt["prompt"]
    if not isinstance(text, str):
        raise TypeError(
            "Prompt text should be a string, but got "
            f"{type(text).__name__}; "
            "use 'prompt_token_ids' for token ID inputs"
        )
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt: def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
""" """
Parse a prompt for a decoder-only model and normalize it to a dictionary. Parse a prompt for a decoder-only model and normalize it to a dictionary.
...@@ -133,6 +146,8 @@ def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt: ...@@ -133,6 +146,8 @@ def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models") raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")
_validate_prompt_dict(prompt)
if ( if (
"prompt" in prompt "prompt" in prompt
or "prompt_token_ids" in prompt or "prompt_token_ids" in prompt
...@@ -156,6 +171,8 @@ def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt: ...@@ -156,6 +171,8 @@ def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
return TokensPrompt(prompt_token_ids=prompt) return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
_validate_prompt_dict(prompt)
if "prompt_embeds" in prompt: if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models") raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
...@@ -178,6 +195,8 @@ def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt: ...@@ -178,6 +195,8 @@ def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
return TokensPrompt(prompt_token_ids=prompt) return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
_validate_prompt_dict(prompt)
if "prompt_embeds" in prompt: if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models") raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment