Commit 037a6487 (unverified), authored by vanshil shah, committed by GitHub
Browse files

apply _validate_input to MistralTokenizer token-id chat prompts (#32448)


Signed-off-by: Vanshil Shah <vanshilshah@gmail.com>
parent 5a3050a0
...@@ -731,6 +731,101 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -731,6 +731,101 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 5 assert mock_engine.generate.call_args.args[1].max_tokens == 5
@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module):
    """Regression test: when the Mistral tokenizer path returns token IDs
    directly, we must still apply input length + max_tokens validation.

    The chat-template renderer is stubbed to return a list of token IDs that
    is shorter than the model context window on its own, but that overflows
    it once the requested ``max_tokens`` is added. The request must come back
    as an ``ErrorResponse`` whose message mentions ``max_tokens``.
    """
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    class DummyMistralTokenizer:
        def decode(self, token_ids):
            # Only used for logging/validation error messages.
            return "dummy"

    dummy_tokenizer = DummyMistralTokenizer()
    mock_engine.get_tokenizer.return_value = dummy_tokenizer

    # Patch the OpenAI engine serving module to treat our dummy tokenizer
    # as a MistralTokenizer. This forces the code path where chat template
    # rendering can return a list[int] (token IDs).
    import vllm.entrypoints.openai.engine.serving as engine_serving

    monkeypatch_module.setattr(
        engine_serving, "MistralTokenizer", DummyMistralTokenizer
    )

    serving_chat = _build_serving_chat(mock_engine)

    # Derive the prompt length from the configured context window instead of
    # hard-coding it, so the test keeps exercising the same condition if the
    # mock config changes: the prompt alone fits, prompt + max_tokens does not.
    max_model_len = mock_engine.model_config.max_model_len
    max_tokens = 10
    prompt_len = max_model_len - max_tokens // 2
    assert 0 < prompt_len < max_model_len
    assert prompt_len + max_tokens > max_model_len

    # Force the Mistral chat template renderer to return token IDs.
    serving_chat._apply_mistral_chat_template_async = AsyncMock(
        return_value=list(range(prompt_len))
    )

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=max_tokens,
    )

    resp = await serving_chat.create_chat_completion(req)

    assert isinstance(resp, ErrorResponse)
    assert "max_tokens" in resp.error.message
@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(
    monkeypatch_module,
):
    """Regression test: token-id prompts produced by the Mistral path must
    still be checked against the maximum context length.

    The stubbed chat-template renderer yields exactly ``max_model_len`` token
    IDs; the request is expected to come back as an ``ErrorResponse`` whose
    message mentions the maximum context length.
    """
    import vllm.entrypoints.openai.engine.serving as engine_serving

    class DummyMistralTokenizer:
        def decode(self, token_ids):
            return "dummy"

    # Minimal engine double exposing only the attributes this test sets.
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()
    engine.get_tokenizer.return_value = DummyMistralTokenizer()

    # Make the serving module regard the dummy class as MistralTokenizer.
    monkeypatch_module.setattr(
        engine_serving, "MistralTokenizer", DummyMistralTokenizer
    )

    serving_chat = _build_serving_chat(engine)

    # A prompt whose length equals max_model_len is the case under test.
    oversized_prompt = list(range(engine.model_config.max_model_len))
    serving_chat._apply_mistral_chat_template_async = AsyncMock(
        return_value=oversized_prompt
    )

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=1,
    )
    response = await serving_chat.create_chat_completion(request)

    assert isinstance(response, ErrorResponse)
    assert "maximum context length" in response.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config(): async def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig() mock_model_config = MockModelConfig()
......
...@@ -1277,9 +1277,11 @@ class OpenAIServing: ...@@ -1277,9 +1277,11 @@ class OpenAIServing:
assert is_list_of(request_prompt, int), ( assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids" "Prompt has to be either a string or a list of token ids"
) )
prompt_inputs = TokensPrompt( input_text = tokenizer.decode(request_prompt)
prompt=tokenizer.decode(request_prompt), prompt_inputs = self._validate_input(
prompt_token_ids=request_prompt, request=request,
input_ids=request_prompt,
input_text=input_text,
) )
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"]) engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment