"vscode:/vscode.git/clone" did not exist on "19f0d2579695e518c9bfc166544cf23775772bf8"
Unverified Commit f0a1c845 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Use new Renderer for Completions and Tokenize API (#32863)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8980001c
...@@ -60,9 +60,7 @@ def main(): ...@@ -60,9 +60,7 @@ def main():
completion = client.completions.create( completion = client.completions.create(
model=model_name, model=model_name,
# NOTE: The OpenAI client does not allow `None` as an input to prompt=None,
# `prompt`. Use an empty string if you have no text prompts.
prompt="",
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
# NOTE: The OpenAI client allows passing in extra JSON body via the # NOTE: The OpenAI client allows passing in extra JSON body via the
......
...@@ -22,7 +22,11 @@ def test_context_length_too_short(vllm_runner, image_assets, model): ...@@ -22,7 +22,11 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
with pytest.raises(ValueError, match="longer than the maximum model length"): with pytest.raises(ValueError, match="longer than the maximum model length"):
vllm_model = vllm_runner( vllm_model = vllm_runner(
model, model,
max_model_len=128, # LLaVA has a feature size of 576 # LLaVA has a feature size of 576
# For the HF processor to execute successfully but still
# failing the overall context length check, we need the
# max_model_len to at least contain all image tokens
max_model_len=579,
enforce_eager=True, enforce_eager=True,
load_format="dummy", load_format="dummy",
) )
......
...@@ -205,7 +205,7 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test): ...@@ -205,7 +205,7 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg, valid_msg,
] ]
sampling_params = SamplingParams(temperature=0, max_tokens=10) sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="longer than the maximum model length"): with pytest.raises(ValueError, match="context length is only"):
llm.chat(batch_1, sampling_params=sampling_params) llm.chat(batch_1, sampling_params=sampling_params)
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params) outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2) assert len(outputs_2) == len(batch_2)
......
...@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse ...@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -57,6 +58,15 @@ class MockModelConfig: ...@@ -57,6 +58,15 @@ class MockModelConfig:
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels( models = OpenAIServingModels(
engine_client=engine, engine_client=engine,
...@@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
chat_template_content_format="auto", chat_template_content_format="auto",
) )
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
async def _fake_preprocess_chat(*args, **kwargs): async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, engine_prompts # return conversation, engine_prompts
return ( return (
...@@ -90,7 +88,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -90,7 +88,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
[{"prompt_token_ids": [1, 2, 3]}], [{"prompt_token_ids": [1, 2, 3]}],
) )
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat) serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
return serving_chat return serving_chat
...@@ -99,11 +96,11 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -99,11 +96,11 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
async def test_chat_error_non_stream(): async def test_chat_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)""" """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -153,11 +150,11 @@ async def test_chat_error_non_stream(): ...@@ -153,11 +150,11 @@ async def test_chat_error_non_stream():
async def test_chat_error_stream(): async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)""" """test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import MagicMock
import pytest import pytest
...@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse ...@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: ...@@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
engine_client=engine, engine_client=engine,
base_model_paths=BASE_MODEL_PATHS, base_model_paths=BASE_MODEL_PATHS,
) )
serving_completion = OpenAIServingCompletion( return OpenAIServingCompletion(
engine, engine,
models, models,
request_logger=None, request_logger=None,
) )
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs) def _build_renderer(model_config: MockModelConfig):
return serving_completion _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completion_error_non_stream(): async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)""" """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine) serving_completion = _build_serving_completion(mock_engine)
...@@ -141,11 +136,11 @@ async def test_completion_error_non_stream(): ...@@ -141,11 +136,11 @@ async def test_completion_error_non_stream():
async def test_completion_error_stream(): async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)""" """test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine) serving_completion = _build_serving_completion(mock_engine)
......
...@@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds( ...@@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds(
# Test case: Single prompt embeds input # Test case: Single prompt embeds input
completion = await client_with_prompt_embeds.completions.create( completion = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds}, extra_body={"prompt_embeds": encoded_embeds},
...@@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds( ...@@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds(
# Test case: batch completion with prompt_embeds # Test case: batch completion with prompt_embeds
completion = await client_with_prompt_embeds.completions.create( completion = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
...@@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds( ...@@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds(
# Test case: streaming with prompt_embeds # Test case: streaming with prompt_embeds
single_completion = await client_with_prompt_embeds.completions.create( single_completion = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds}, extra_body={"prompt_embeds": encoded_embeds},
...@@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds( ...@@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds(
stream = await client_with_prompt_embeds.completions.create( stream = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
stream=True, stream=True,
...@@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds( ...@@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds(
# Test case: batch streaming with prompt_embeds # Test case: batch streaming with prompt_embeds
stream = await client_with_prompt_embeds.completions.create( stream = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
stream=True, stream=True,
...@@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds( ...@@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds(
) )
completion_embeds_only = await client_with_prompt_embeds.completions.create( completion_embeds_only = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds}, extra_body={"prompt_embeds": encoded_embeds},
...@@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds( ...@@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds(
# Test error case: invalid prompt_embeds # Test error case: invalid prompt_embeds
with pytest.raises(BadRequestError): with pytest.raises(BadRequestError):
await client_with_prompt_embeds.completions.create( await client_with_prompt_embeds.completions.create(
prompt="", prompt=None,
model=model_name, model=model_name,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
...@@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds( ...@@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
# Test case: Logprobs using prompt_embeds # Test case: Logprobs using prompt_embeds
completion = await client_with_prompt_embeds.completions.create( completion = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
echo=False, echo=False,
...@@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds( ...@@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
# Test case: Log probs with batch completion and prompt_embeds # Test case: Log probs with batch completion and prompt_embeds
completion = await client_with_prompt_embeds.completions.create( completion = await client_with_prompt_embeds.completions.create(
model=model_name, model=model_name,
prompt="", # Add empty prompt as required parameter prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
echo=False, echo=False,
...@@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error( ...@@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error(
with pytest.raises(BadRequestError, match="not compatible"): with pytest.raises(BadRequestError, match="not compatible"):
await client_with_prompt_embeds.completions.create( await client_with_prompt_embeds.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
prompt="", prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
......
...@@ -7,7 +7,7 @@ Tests verify that embeddings with correct ndim but incorrect hidden_size ...@@ -7,7 +7,7 @@ Tests verify that embeddings with correct ndim but incorrect hidden_size
are rejected before they can cause crashes during model inference. are rejected before they can cause crashes during model inference.
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
classes, not by CompletionRenderer or MediaIO classes. classes, not by MediaIO classes.
""" """
import pytest import pytest
......
...@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath ...@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.tokenizers import get_tokenizer from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -35,6 +36,7 @@ class MockModelConfig: ...@@ -35,6 +36,7 @@ class MockModelConfig:
"""Minimal mock ModelConfig for testing.""" """Minimal mock ModelConfig for testing."""
model: str = MODEL_NAME model: str = MODEL_NAME
runner_type = "generate"
tokenizer: str = MODEL_NAME tokenizer: str = MODEL_NAME
trust_remote_code: bool = False trust_remote_code: bool = False
tokenizer_mode: str = "auto" tokenizer_mode: str = "auto"
...@@ -85,15 +87,21 @@ def register_mock_resolver(): ...@@ -85,15 +87,21 @@ def register_mock_resolver():
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME] del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.fixture @pytest.fixture
def mock_serving_setup(): def mock_serving_setup():
"""Provides a mocked engine and serving completion instance.""" """Provides a mocked engine and serving completion instance."""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False mock_engine.errored = False
tokenizer = get_tokenizer(MODEL_NAME)
mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)
async def mock_add_lora_side_effect(lora_request: LoRARequest): async def mock_add_lora_side_effect(lora_request: LoRARequest):
"""Simulate engine behavior when adding LoRAs.""" """Simulate engine behavior when adding LoRAs."""
if lora_request.lora_name == "test-lora": if lora_request.lora_name == "test-lora":
...@@ -118,6 +126,7 @@ def mock_serving_setup(): ...@@ -118,6 +126,7 @@ def mock_serving_setup():
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels( models = OpenAIServingModels(
engine_client=mock_engine, engine_client=mock_engine,
...@@ -128,10 +137,6 @@ def mock_serving_setup(): ...@@ -128,10 +137,6 @@ def mock_serving_setup():
mock_engine, models, request_logger=None mock_engine, models, request_logger=None
) )
serving_completion._process_inputs = AsyncMock(
return_value=(MagicMock(name="engine_request"), {})
)
return mock_engine, serving_completion return mock_engine, serving_completion
......
...@@ -12,7 +12,7 @@ import regex as re ...@@ -12,7 +12,7 @@ import regex as re
import torch import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.renderer import CompletionRenderer from vllm.renderers.embed_utils import safe_load_prompt_embeds
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
...@@ -30,7 +30,7 @@ async def test_empty_prompt(): ...@@ -30,7 +30,7 @@ async def test_empty_prompt():
): ):
await client.completions.create( await client.completions.create(
model=model_name, model=model_name,
prompt="", prompt=None,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
extra_body={"prompt_embeds": []}, extra_body={"prompt_embeds": []},
...@@ -63,7 +63,6 @@ def test_load_prompt_embeds( ...@@ -63,7 +63,6 @@ def test_load_prompt_embeds(
): ):
model_config = Mock(spec=ModelConfig) model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True model_config.enable_prompt_embeds = True
renderer = CompletionRenderer(model_config, tokenizer=None)
# construct arbitrary tensors of various dtypes, layouts, and sizes. # construct arbitrary tensors of various dtypes, layouts, and sizes.
# We need to check against different layouts to make sure that if a user # We need to check against different layouts to make sure that if a user
...@@ -89,9 +88,7 @@ def test_load_prompt_embeds( ...@@ -89,9 +88,7 @@ def test_load_prompt_embeds(
buffer.seek(0) buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue()) encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor) loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
assert loaded_tensor.device.type == "cpu" assert loaded_tensor.device.type == "cpu"
assert loaded_tensor.layout == torch.strided assert loaded_tensor.layout == torch.strided
torch.testing.assert_close( torch.testing.assert_close(
...@@ -105,7 +102,6 @@ def test_load_prompt_embeds( ...@@ -105,7 +102,6 @@ def test_load_prompt_embeds(
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int): def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
model_config = Mock(spec=ModelConfig) model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = False model_config.enable_prompt_embeds = False
renderer = CompletionRenderer(model_config, tokenizer=None)
tensor = torch.randn((seq_len, hidden_size), dtype=dtype) tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
...@@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in ...@@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in
encoded_tensor = pybase64.b64encode(buffer.getvalue()) encoded_tensor = pybase64.b64encode(buffer.getvalue())
with pytest.raises(ValueError, match="--enable-prompt-embeds"): with pytest.raises(ValueError, match="--enable-prompt-embeds"):
renderer.load_prompt_embeds(encoded_tensor) safe_load_prompt_embeds(model_config, encoded_tensor)
...@@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
request_logger=None, request_logger=None,
) )
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_chat return serving_chat
...@@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): ...@@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
resp = await serving_chat.create_chat_completion(req) resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse) assert isinstance(resp, ErrorResponse)
assert "max_tokens" in resp.error.message assert "context length is only" in resp.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): ...@@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
resp = await serving_chat.create_chat_completion(req) resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse) assert isinstance(resp, ErrorResponse)
assert "maximum context length" in resp.error.message assert "context length is only" in resp.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): ...@@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
orig_render_chat_request = serving_chat.render_chat_request
captured_prompts = []
async def render_chat_request(request):
result = await orig_render_chat_request(request)
assert isinstance(result, tuple)
conversation, engine_prompts = result
captured_prompts.extend(engine_prompts)
return result
serving_chat.render_chat_request = render_chat_request
# Test cache_salt # Test cache_salt
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
...@@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): ...@@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
# By default, cache_salt in the engine prompt is not set # By default, cache_salt in the engine prompt is not set
with suppress(Exception): with suppress(Exception):
await serving_chat.create_chat_completion(req) await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
assert "cache_salt" not in engine_prompt assert len(captured_prompts) == 1
assert "cache_salt" not in captured_prompts[0]
captured_prompts.clear()
# Test with certain cache_salt # Test with certain cache_salt
req.cache_salt = "test_salt" req.cache_salt = "test_salt"
with suppress(Exception): with suppress(Exception):
await serving_chat.create_chat_completion(req) await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt" assert len(captured_prompts) == 1
assert captured_prompts[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -1007,11 +1012,11 @@ class TestServingChatWithHarmony: ...@@ -1007,11 +1012,11 @@ class TestServingChatWithHarmony:
@pytest.fixture() @pytest.fixture()
def mock_engine(self) -> AsyncLLM: def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
return mock_engine return mock_engine
@pytest.fixture() @pytest.fixture()
...@@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser(): ...@@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser():
"""Test that tool_choice='required' or named tool without tool_parser """Test that tool_choice='required' or named tool without tool_parser
returns an appropriate error message.""" returns an appropriate error message."""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels( models = OpenAIServingModels(
engine_client=mock_engine, engine_client=mock_engine,
......
...@@ -67,20 +67,6 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): ...@@ -67,20 +67,6 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
assert response["usage"]["prompt_tokens"] == truncation_size assert response["usage"]["prompt_tokens"] == truncation_size
@pytest.mark.asyncio
async def test_zero_truncation_size(client: openai.AsyncOpenAI):
truncation_size = 0
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size,
}
response = await client.post(path="embeddings", cast_to=object, body={**kwargs})
assert response["usage"]["prompt_tokens"] == truncation_size
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI): async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
truncation_size = max_model_len + 1 truncation_size = max_model_len + 1
......
...@@ -128,12 +128,10 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): ...@@ -128,12 +128,10 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
server.url_for("classify"), server.url_for("classify"),
json={"model": model_name, "input": []}, json={"model": model_name, "input": []},
) )
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(classification_response.json())
assert output.object == "list" error = classification_response.json()
assert isinstance(output.data, list) assert classification_response.status_code == 400
assert len(output.data) == 0 assert "error" in error
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
......
...@@ -247,7 +247,7 @@ class TestModel: ...@@ -247,7 +247,7 @@ class TestModel:
}, },
) )
assert score_response.status_code == 400 assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in score_response.text assert "Please request a smaller truncation size." in score_response.text
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]): def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]):
queries = "What is the capital of France?" queries = "What is the capital of France?"
......
...@@ -96,7 +96,7 @@ def test_gemma_multimodal( ...@@ -96,7 +96,7 @@ def test_gemma_multimodal(
dtype="bfloat16", dtype="bfloat16",
) as vllm_model: ) as vllm_model:
llm = vllm_model.get_llm() llm = vllm_model.get_llm()
prompts = llm.preprocess_chat(messages) prompts = llm._preprocess_chat([messages])
result = llm.classify(prompts) result = llm.classify(prompts)
assert result[0].outputs.probs[0] > 0.95 assert result[0].outputs.probs[0] > 0.95
......
...@@ -29,7 +29,8 @@ def test_smaller_truncation_size( ...@@ -29,7 +29,8 @@ def test_smaller_truncation_size(
model_name, runner="pooling", max_model_len=max_model_len model_name, runner="pooling", max_model_len=max_model_len
) as vllm_model: ) as vllm_model:
vllm_output = vllm_model.llm.embed( vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens input_str,
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
) )
prompt_tokens = vllm_output[0].prompt_token_ids prompt_tokens = vllm_output[0].prompt_token_ids
...@@ -44,7 +45,8 @@ def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input ...@@ -44,7 +45,8 @@ def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input
model_name, runner="pooling", max_model_len=max_model_len model_name, runner="pooling", max_model_len=max_model_len
) as vllm_model: ) as vllm_model:
vllm_output = vllm_model.llm.embed( vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens input_str,
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
) )
prompt_tokens = vllm_output[0].prompt_token_ids prompt_tokens = vllm_output[0].prompt_token_ids
...@@ -64,7 +66,8 @@ def test_bigger_truncation_size( ...@@ -64,7 +66,8 @@ def test_bigger_truncation_size(
) as vllm_model, ) as vllm_model,
): ):
llm_output = vllm_model.llm.embed( llm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens input_str,
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
) )
assert ( assert (
......
...@@ -187,7 +187,10 @@ def mteb_test_embed_models( ...@@ -187,7 +187,10 @@ def mteb_test_embed_models(
head_dtype = model_config.head_dtype head_dtype = model_config.head_dtype
# Test embedding_size, isnan and whether to use normalize # Test embedding_size, isnan and whether to use normalize
vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1) vllm_outputs = vllm_model.embed(
example_prompts,
tokenization_kwargs=dict(truncate_prompt_tokens=-1),
)
outputs_tensor = torch.tensor(vllm_outputs) outputs_tensor = torch.tensor(vllm_outputs)
assert not torch.any(torch.isnan(outputs_tensor)) assert not torch.any(torch.isnan(outputs_tensor))
embedding_size = model_config.embedding_size embedding_size = model_config.embedding_size
......
...@@ -79,9 +79,9 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin): ...@@ -79,9 +79,9 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
outputs = self.llm.score( outputs = self.llm.score(
queries, queries,
corpus, corpus,
truncate_prompt_tokens=-1,
use_tqdm=False, use_tqdm=False,
chat_template=self.chat_template, chat_template=self.chat_template,
tokenization_kwargs={"truncate_prompt_tokens": -1},
) )
scores = np.array(outputs) scores = np.array(outputs)
scores = scores[np.argsort(r)] scores = scores[np.argsort(r)]
......
...@@ -3,26 +3,39 @@ ...@@ -3,26 +3,39 @@
import io import io
from dataclasses import dataclass from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock from typing import Any
from unittest.mock import AsyncMock
import pybase64 import pybase64
import pytest import pytest
import torch import torch
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
from vllm.inputs.data import is_embeds_prompt from vllm.inputs.data import is_embeds_prompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass @dataclass
class MockModelConfig: class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100 max_model_len: int = 100
encoder_config: dict | None = None tokenizer_revision = None
tokenizer_mode = "auto"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
class MockTokenizerResult:
def __init__(self, input_ids):
self.input_ids = input_ids
@pytest.fixture @pytest.fixture
...@@ -30,35 +43,80 @@ def mock_model_config(): ...@@ -30,35 +43,80 @@ def mock_model_config():
return MockModelConfig() return MockModelConfig()
@pytest.fixture
def mock_tokenizer():
tokenizer = MagicMock()
return tokenizer
@pytest.fixture @pytest.fixture
def mock_async_tokenizer(): def mock_async_tokenizer():
async_tokenizer = AsyncMock() return AsyncMock()
return async_tokenizer
@pytest.fixture @pytest.fixture
def renderer(mock_model_config, mock_tokenizer): def renderer(mock_model_config):
return CompletionRenderer( _, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config)
model_config=mock_model_config,
tokenizer=mock_tokenizer, return HfRenderer(
async_tokenizer_pool={}, mock_model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
) )
class TestRenderPrompt: class TestValidatePrompt:
"""Test Category A: Basic Functionality Tests""" STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that an empty list of prompts raises a ValueError (the nested mixed-type list-of-lists TypeError case is covered by test_invalid_type below).
def test_empty_input(self, renderer):
with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([])
def test_invalid_type(self, renderer):
with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, renderer, string_input: str):
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, renderer, token_input: list[int]):
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, renderer, inputs_slice: slice):
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
class TestRenderPrompt:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_input(self, renderer): async def test_token_input(self, renderer):
tokens = [101, 7592, 2088] tokens = [101, 7592, 2088]
results = await renderer.render_prompt( prompts = await renderer.render_completions_async(tokens)
prompt_or_prompts=tokens, config=RenderConfig(max_length=100) results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
...@@ -67,8 +125,10 @@ class TestRenderPrompt: ...@@ -67,8 +125,10 @@ class TestRenderPrompt:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_list_input(self, renderer): async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt( prompts = await renderer.render_completions_async(token_lists)
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100) results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 3 assert len(results) == 3
...@@ -78,43 +138,49 @@ class TestRenderPrompt: ...@@ -78,43 +138,49 @@ class TestRenderPrompt:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer): async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer renderer._async_tokenizer = mock_async_tokenizer
results = await renderer.render_prompt( prompts = await renderer.render_completions_async("Hello world")
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088] assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.assert_called_once() mock_async_tokenizer.encode.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer): async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer renderer._async_tokenizer = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"] text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt( prompts = await renderer.render_completions_async(text_list_input)
prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100) results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 3 assert len(results) == 3
for result in results: for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088] assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.call_count == 3 assert mock_async_tokenizer.encode.call_count == 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer): async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer renderer._async_tokenizer = mock_async_tokenizer
results = await renderer.render_prompt( prompts = await renderer.render_completions_async("Hello world")
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
assert len(results) == 1 assert len(results) == 1
call_args = mock_async_tokenizer.call_args call_args = mock_async_tokenizer.encode.call_args
assert ( assert (
"truncation" not in call_args.kwargs "truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False or call_args.kwargs["truncation"] is False
...@@ -122,46 +188,58 @@ class TestRenderPrompt: ...@@ -122,46 +188,58 @@ class TestRenderPrompt:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer): async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult( mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated
[101, 7592, 2088] renderer._async_tokenizer = mock_async_tokenizer
) # Truncated
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
results = await renderer.render_prompt( prompts,
prompt_or_prompts="Hello world", TokenizeParams(
config=RenderConfig(max_length=100, truncate_prompt_tokens=50), max_total_tokens=200,
truncate_prompt_tokens=50,
),
) )
assert len(results) == 1 assert len(results) == 1
call_args = mock_async_tokenizer.call_args call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50 assert call_args.kwargs["max_length"] == 50
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_truncation_negative(self, renderer, mock_async_tokenizer): async def test_truncation_negative(self, renderer, mock_async_tokenizer):
# Test that negative truncation uses model's max_model_len # Test that negative truncation uses model's max_model_len
mock_async_tokenizer.return_value = MockTokenizerResult( mock_async_tokenizer.encode.return_value = [
[101, 7592, 2088] 101,
) # Truncated to max_model_len 7592,
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer 2088,
] # Truncated to max_total_tokens
results = await renderer.render_prompt( renderer._async_tokenizer = mock_async_tokenizer
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=200, truncate_prompt_tokens=-1), prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=200,
truncate_prompt_tokens=-1,
),
) )
assert len(results) == 1 assert len(results) == 1
call_args = mock_async_tokenizer.call_args call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 100 # model's max_model_len assert call_args.kwargs["max_length"] == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer): async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements # Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
results = await renderer.render_prompt( prompts = await renderer.render_completions_async(long_tokens)
prompt_or_prompts=long_tokens, results = await renderer.tokenize_prompts_async(
config=RenderConfig(max_length=100, truncate_prompt_tokens=5), prompts,
TokenizeParams(
max_total_tokens=100,
truncate_prompt_tokens=5,
),
) )
assert len(results) == 1 assert len(results) == 1
...@@ -172,20 +250,27 @@ class TestRenderPrompt: ...@@ -172,20 +250,27 @@ class TestRenderPrompt:
async def test_max_length_exceeded(self, renderer): async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100 long_tokens = list(range(150)) # Exceeds max_model_len=100
with pytest.raises(ValueError, match="maximum context length"): prompts = await renderer.render_completions_async(long_tokens)
await renderer.render_prompt(
prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100) with pytest.raises(ValueError, match="context length is only"):
await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config): async def test_no_tokenizer_for_text(self, renderer):
renderer_no_tokenizer = CompletionRenderer( renderer_no_tokenizer = HfRenderer.from_config(
model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={} MockModelConfig(skip_tokenizer_init=True),
tokenizer_kwargs={},
) )
with pytest.raises(ValueError, match="No tokenizer available"): prompts = await renderer_no_tokenizer.render_completions_async("Hello world")
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
await renderer_no_tokenizer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
) )
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -196,12 +281,16 @@ class TestRenderPrompt: ...@@ -196,12 +281,16 @@ class TestRenderPrompt:
# use the async tokenizer to decode and include the original text # use the async tokenizer to decode and include the original text
# in the returned prompt object. # in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer renderer._async_tokenizer = mock_async_tokenizer
tokens = [1, 2, 3, 4] tokens = [1, 2, 3, 4]
results = await renderer.render_prompt( prompts = await renderer.render_completions_async(tokens)
prompt_or_prompts=tokens, results = await renderer.tokenize_prompts_async(
config=RenderConfig(needs_detokenization=True), prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
needs_detokenization=True,
),
) )
assert len(results) == 1 assert len(results) == 1
...@@ -224,15 +313,15 @@ class TestRenderEmbedPrompt: ...@@ -224,15 +313,15 @@ class TestRenderEmbedPrompt:
test_tensor = torch.randn(10, 768, dtype=torch.float32) test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor) embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds( prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes, results = await renderer.tokenize_prompts_async(
config=RenderConfig(cache_salt="test_salt"), prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
) )
assert len(results) == 1 assert len(results) == 1
assert is_embeds_prompt(results[0]) assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor) assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
assert results[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer): async def test_multiple_prompt_embeds(self, renderer):
...@@ -243,9 +332,12 @@ class TestRenderEmbedPrompt: ...@@ -243,9 +332,12 @@ class TestRenderEmbedPrompt:
] ]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors] embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
results = await renderer.render_prompt_and_embeds( prompts = await renderer.render_completions_async(
prompt_embeds=embed_bytes_list, prompt_embeds=embed_bytes_list
config=RenderConfig(), )
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
) )
assert len(results) == 2 assert len(results) == 2
...@@ -259,9 +351,13 @@ class TestRenderEmbedPrompt: ...@@ -259,9 +351,13 @@ class TestRenderEmbedPrompt:
test_tensor = torch.randn(20, 768, dtype=torch.float32) test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor) embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds( prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes, results = await renderer.tokenize_prompts_async(
config=RenderConfig(truncate_prompt_tokens=10), prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
truncate_prompt_tokens=10,
),
) )
assert len(results) == 1 assert len(results) == 1
...@@ -278,9 +374,10 @@ class TestRenderEmbedPrompt: ...@@ -278,9 +374,10 @@ class TestRenderEmbedPrompt:
test_tensor = torch.randn(5, 256, dtype=dtype) test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor) embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds( prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes, results = await renderer.tokenize_prompts_async(
config=RenderConfig(), prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
) )
assert len(results) == 1 assert len(results) == 1
...@@ -292,9 +389,10 @@ class TestRenderEmbedPrompt: ...@@ -292,9 +389,10 @@ class TestRenderEmbedPrompt:
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32) test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor) embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds( prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes, results = await renderer.tokenize_prompts_async(
config=RenderConfig(), prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
) )
assert len(results) == 1 assert len(results) == 1
...@@ -304,17 +402,20 @@ class TestRenderEmbedPrompt: ...@@ -304,17 +402,20 @@ class TestRenderEmbedPrompt:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer): async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization # Set up text tokenization
mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103]) mock_async_tokenizer.encode.return_value = [101, 102, 103]
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer renderer._async_tokenizer = mock_async_tokenizer
# Create embed # Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32) test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor) embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds( prompts = await renderer.render_completions_async(
prompt_or_prompts="Hello world", "Hello world",
prompt_embeds=embed_bytes, prompt_embeds=embed_bytes,
config=RenderConfig(), )
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
) )
assert len(results) == 2 assert len(results) == 2
......
...@@ -9,6 +9,7 @@ import pytest ...@@ -9,6 +9,7 @@ import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
...@@ -27,7 +28,7 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(): ...@@ -27,7 +28,7 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={}) mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([]) task = mock_renderer.render_messages_async([], ChatParams())
# Ensure the event loop is not blocked # Ensure the event loop is not blocked
blocked_count = 0 blocked_count = 0
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
Sparse tensor validation in embedding APIs.
Tests verify that malicious sparse tensors are rejected before they can trigger Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations. out-of-bounds memory writes during to_dense() operations.
""" """
...@@ -13,8 +11,24 @@ import io ...@@ -13,8 +11,24 @@ import io
import pytest import pytest
import torch import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO
from vllm.renderers.embed_utils import safe_load_prompt_embeds
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig
return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)
def _encode_tensor(tensor: torch.Tensor) -> bytes: def _encode_tensor(tensor: torch.Tensor) -> bytes:
...@@ -63,15 +77,12 @@ class TestPromptEmbedsValidation: ...@@ -63,15 +77,12 @@ class TestPromptEmbedsValidation:
def test_valid_dense_tensor_accepted(self, model_config): def test_valid_dense_tensor_accepted(self, model_config):
"""Baseline: Valid dense tensors should work normally.""" """Baseline: Valid dense tensors should work normally."""
renderer = CompletionRenderer(model_config)
valid_tensor = _create_valid_dense_tensor() valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor) encoded = _encode_tensor(valid_tensor)
# Should not raise any exception # Should not raise any exception
result = renderer.load_prompt_embeds(encoded) result = safe_load_prompt_embeds(model_config, encoded)
assert len(result) == 1 assert result.shape == valid_tensor.shape
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self): def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully.""" """Baseline: Valid sparse tensors should load successfully."""
...@@ -86,14 +97,12 @@ class TestPromptEmbedsValidation: ...@@ -86,14 +97,12 @@ class TestPromptEmbedsValidation:
def test_malicious_sparse_tensor_rejected(self, model_config): def test_malicious_sparse_tensor_rejected(self, model_config):
"""Security: Malicious sparse tensors should be rejected.""" """Security: Malicious sparse tensors should be rejected."""
renderer = CompletionRenderer(model_config)
malicious_tensor = _create_malicious_sparse_tensor() malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor) encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor # Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info: with pytest.raises((RuntimeError, ValueError)) as exc_info:
renderer.load_prompt_embeds(encoded) safe_load_prompt_embeds(model_config, encoded)
# Error should indicate sparse tensor validation failure # Error should indicate sparse tensor validation failure
error_msg = str(exc_info.value).lower() error_msg = str(exc_info.value).lower()
...@@ -101,8 +110,6 @@ class TestPromptEmbedsValidation: ...@@ -101,8 +110,6 @@ class TestPromptEmbedsValidation:
def test_extremely_large_indices_rejected(self, model_config): def test_extremely_large_indices_rejected(self, model_config):
"""Security: Sparse tensors with extremely large indices should be rejected.""" """Security: Sparse tensors with extremely large indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with indices far beyond reasonable bounds # Create tensor with indices far beyond reasonable bounds
indices = torch.tensor([[999999], [999999]]) indices = torch.tensor([[999999], [999999]])
values = torch.tensor([1.0]) values = torch.tensor([1.0])
...@@ -114,12 +121,10 @@ class TestPromptEmbedsValidation: ...@@ -114,12 +121,10 @@ class TestPromptEmbedsValidation:
encoded = _encode_tensor(malicious_tensor) encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)): with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded) safe_load_prompt_embeds(model_config, encoded)
def test_negative_indices_rejected(self, model_config): def test_negative_indices_rejected(self, model_config):
"""Security: Sparse tensors with negative indices should be rejected.""" """Security: Sparse tensors with negative indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with negative indices # Create tensor with negative indices
indices = torch.tensor([[-1], [-1]]) indices = torch.tensor([[-1], [-1]])
values = torch.tensor([1.0]) values = torch.tensor([1.0])
...@@ -131,7 +136,7 @@ class TestPromptEmbedsValidation: ...@@ -131,7 +136,7 @@ class TestPromptEmbedsValidation:
encoded = _encode_tensor(malicious_tensor) encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)): with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded) safe_load_prompt_embeds(model_config, encoded)
class TestImageEmbedsValidation: class TestImageEmbedsValidation:
...@@ -253,14 +258,12 @@ class TestSparseTensorValidationIntegration: ...@@ -253,14 +258,12 @@ class TestSparseTensorValidationIntegration:
3. Sends to /v1/completions with prompt_embeds parameter 3. Sends to /v1/completions with prompt_embeds parameter
4. Server should reject before memory corruption occurs 4. Server should reject before memory corruption occurs
""" """
renderer = CompletionRenderer(model_config)
# Step 1-2: Attacker creates malicious payload # Step 1-2: Attacker creates malicious payload
attack_payload = _encode_tensor(_create_malicious_sparse_tensor()) attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
# Step 3-4: Server processes and should reject # Step 3-4: Server processes and should reject
with pytest.raises((RuntimeError, ValueError)): with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(attack_payload) safe_load_prompt_embeds(model_config, attack_payload)
def test_attack_scenario_chat_api_image(self): def test_attack_scenario_chat_api_image(self):
""" """
...@@ -285,57 +288,3 @@ class TestSparseTensorValidationIntegration: ...@@ -285,57 +288,3 @@ class TestSparseTensorValidationIntegration:
with pytest.raises((RuntimeError, ValueError)): with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8")) io_handler.load_base64("", attack_payload.decode("utf-8"))
def test_multiple_valid_embeddings_in_batch(self, model_config):
"""
Regression test: Multiple valid embeddings should still work.
Ensures the fix doesn't break legitimate batch processing.
"""
renderer = CompletionRenderer(model_config)
valid_tensors = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
]
# Should process all without error
result = renderer.load_prompt_embeds(valid_tensors)
assert len(result) == 3
def test_mixed_valid_and_malicious_rejected(self, model_config):
"""
Security: Batch with one malicious tensor should be rejected.
Even if most tensors are valid, a single malicious one should
cause rejection of the entire batch.
"""
renderer = CompletionRenderer(model_config)
mixed_batch = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
_encode_tensor(_create_valid_dense_tensor()),
]
# Should fail on the malicious tensor
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(mixed_batch)
# Pytest fixtures
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig
return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment