Unverified commit 07351e08, authored by Chenguang Zheng, committed by GitHub
Browse files

[Feature] Warm up readonly multimodal processor during renderer startup (#40797)


Signed-off-by: Chenguang ZHENG <645327136@qq.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
parent 428b988c
......@@ -164,6 +164,58 @@ async def test_chat_error_non_stream():
await serving_chat.create_chat_completion(request)
@pytest.mark.asyncio
async def test_openai_chat_keeps_mm_cache_for_engine_execution():
    """Chat serving's ``render_chat_request`` must preprocess with
    ``skip_mm_cache=False``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)
    serving_chat = _build_serving_chat(engine)

    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
    )
    rendered = await serving_chat.render_chat_request(chat_request)

    assert isinstance(rendered, tuple)
    # Inspect the keyword arguments of the most recent preprocess_chat call.
    preprocess_chat = serving_chat.openai_serving_render.preprocess_chat
    assert preprocess_chat.call_args.kwargs["skip_mm_cache"] is False
@pytest.mark.asyncio
async def test_renderer_only_chat_request_skips_mm_cache():
    """Calling the render service's ``render_chat_request`` directly must
    preprocess with ``skip_mm_cache=True``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)
    serving_render = _build_serving_chat(engine).openai_serving_render

    chat_request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
    )
    rendered = await serving_render.render_chat_request(chat_request)

    assert rendered.token_ids == [1, 2, 3]
    # Inspect the keyword arguments of the most recent preprocess_chat call.
    recorded_kwargs = serving_render.preprocess_chat.call_args.kwargs
    assert recorded_kwargs["skip_mm_cache"] is True
@pytest.mark.asyncio
async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
......
......@@ -3,7 +3,7 @@
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
......@@ -148,6 +148,66 @@ async def test_completion_error_non_stream():
await serving_completion.create_completion(request)
@pytest.mark.asyncio
async def test_openai_completion_keeps_mm_cache_for_engine_execution():
    """Completion serving's ``render_completion_request`` must preprocess with
    ``skip_mm_cache=False``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)
    serving_completion = _build_serving_completion(engine)

    # Replace preprocessing with a stub so its call arguments can be inspected.
    preprocess_stub = AsyncMock(return_value=[{"prompt_token_ids": [1, 2, 3]}])
    serving_completion.openai_serving_render.preprocess_completion = preprocess_stub

    completion_request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
    )
    rendered = await serving_completion.render_completion_request(
        completion_request
    )

    assert isinstance(rendered, list)
    recorded = serving_completion.openai_serving_render.preprocess_completion
    assert recorded.call_args.kwargs["skip_mm_cache"] is False
@pytest.mark.asyncio
async def test_renderer_only_completion_request_skips_mm_cache():
    """Calling the render service's ``render_completion_request`` directly must
    preprocess with ``skip_mm_cache=True``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)
    serving_completion = _build_serving_completion(engine)

    # Replace preprocessing with a stub so its call arguments can be inspected.
    preprocess_stub = AsyncMock(return_value=[{"prompt_token_ids": [1, 2, 3]}])
    serving_completion.openai_serving_render.preprocess_completion = preprocess_stub

    completion_request = CompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
    )
    serving_render = serving_completion.openai_serving_render
    rendered = await serving_render.render_completion_request(completion_request)

    assert isinstance(rendered, list)
    recorded = serving_completion.openai_serving_render.preprocess_completion
    assert recorded.call_args.kwargs["skip_mm_cache"] is True
@pytest.mark.asyncio
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
......
......@@ -12,7 +12,10 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logprobs import Logprob
......@@ -164,6 +167,36 @@ def _parse_sse_chunks(chunks: list[str]) -> list[Any]:
return parsed
@pytest.mark.asyncio
async def test_serve_tokens_skips_mm_cache_for_remote_engine_execution():
    """A non-streaming ``serve_tokens`` call must preprocess with
    ``skip_mm_cache=True``."""
    engine = _mock_engine()

    async def fake_generate(*args, **kwargs):
        # Emit a single, already-finished output.
        yield _make_request_output(
            "req-1", token_ids=[10], finish_reason="stop", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_generate)
    serving = _build_serving_tokens(engine)

    generate_request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=1),
        model=MODEL_NAME,
        stream=False,
    )
    response = await serving.serve_tokens(generate_request)

    assert isinstance(response, GenerateResponse)
    # Inspect the keyword arguments of the most recent preprocess_completion call.
    preprocess_completion = serving.openai_serving_render.preprocess_completion
    assert preprocess_completion.call_args.kwargs["skip_mm_cache"] is True
@pytest.mark.asyncio
async def test_stream_basic():
"""Streaming returns SSE chunks with correct token_ids and ends with [DONE]."""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeChatRequest,
TokenizeCompletionRequest,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.v1.engine.async_llm import AsyncLLM
# Model identifier used as the `model` field of every request built in this file.
MODEL_NAME = "openai-community/gpt2"
# Base model list handed to OpenAIServingModels by _build_serving_tokenization.
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
]
@dataclass
class MockHFConfig:
    """Stand-in for a Hugging Face config object; carries only ``model_type``."""

    model_type: str = "any"
@dataclass
class MockModelConfig:
    """Stand-in model configuration assigned to the mocked engine in these tests.

    Names without a type annotation are plain class attributes shared by every
    instance; annotated names are dataclass fields and can be overridden through
    the generated ``__init__``.
    """

    # --- plain class attributes (not dataclass fields) ---
    task = "generate"
    runner_type = "generate"
    model = MODEL_NAME
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    hf_text_config = MockHFConfig()

    # --- dataclass fields, interleaved with class attributes as in the original ---
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None  # class attribute (no annotation)
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False  # class attribute (no annotation)
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False
    renderer_num_workers: int = 1

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or a new empty dict when it is falsy."""
        if self.diff_sampling_param:
            return self.diff_sampling_param
        return {}
def _build_serving_tokenization(engine: AsyncLLM) -> OpenAIServingTokenization:
    """Assemble an ``OpenAIServingTokenization`` around ``engine``.

    The models service and the render service are built first; the render
    service reads ``engine.model_config`` and ``engine.renderer`` and is handed
    to the tokenization service as ``openai_serving_render``.
    """
    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    render_kwargs = dict(
        model_config=engine.model_config,
        renderer=engine.renderer,
        model_registry=serving_models.registry,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
    render_service = OpenAIServingRender(**render_kwargs)
    return OpenAIServingTokenization(
        engine,
        serving_models,
        openai_serving_render=render_service,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
@pytest.mark.asyncio
async def test_tokenize_chat_skips_mm_cache_for_renderer_only_path():
    """Tokenizing a chat request must preprocess with ``skip_mm_cache=True``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = MagicMock()
    serving = _build_serving_tokenization(engine)

    # Stub chat preprocessing: returns (conversation, engine inputs).
    conversation = [{"role": "user", "content": "Test"}]
    engine_inputs = [{"prompt_token_ids": [1, 2, 3]}]
    serving.openai_serving_render.preprocess_chat = AsyncMock(
        return_value=(conversation, engine_inputs)
    )

    tokenize_request = TokenizeChatRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
    )
    raw_request = MagicMock(headers={})
    response = await serving.create_tokenize(tokenize_request, raw_request)

    assert response.tokens == [1, 2, 3]
    recorded_kwargs = serving.openai_serving_render.preprocess_chat.call_args.kwargs
    assert recorded_kwargs["skip_mm_cache"] is True
@pytest.mark.asyncio
async def test_tokenize_completion_skips_mm_cache_for_renderer_only_path():
    """Tokenizing a completion request must preprocess with
    ``skip_mm_cache=True``."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.renderer = MagicMock()
    serving = _build_serving_tokenization(engine)

    # Stub completion preprocessing so its call arguments can be inspected.
    preprocess_stub = AsyncMock(return_value=[{"prompt_token_ids": [1, 2, 3]}])
    serving.openai_serving_render.preprocess_completion = preprocess_stub

    tokenize_request = TokenizeCompletionRequest(
        model=MODEL_NAME,
        prompt="Test prompt",
    )
    raw_request = MagicMock(headers={})
    response = await serving.create_tokenize(tokenize_request, raw_request)

    assert response.tokens == [1, 2, 3]
    recorded = serving.openai_serving_render.preprocess_completion
    assert recorded.call_args.kwargs["skip_mm_cache"] is True
......@@ -34,6 +34,14 @@ def _make_renderer_mock(mm_limits: dict[str, int]) -> MagicMock:
mm_processor = MagicMock()
mm_processor.info.allowed_mm_limits = mm_limits
renderer.mm_processor = mm_processor
renderer._readonly_mm_processor = None
renderer._warmup_mm_processor = BaseRenderer._warmup_mm_processor.__get__(
renderer, BaseRenderer
)
renderer._clear_processor_cache = BaseRenderer._clear_processor_cache
renderer.clear_mm_cache = MagicMock()
renderer.model_config.max_model_len = 128
renderer.model_config.get_multimodal_config.return_value.limit_per_prompt = {}
return renderer
......@@ -109,3 +117,19 @@ class TestMmWarmupSkippedWhenNoProcessor:
BaseRenderer.warmup(renderer, ChatParams())
renderer.model_config.get_multimodal_config.assert_not_called()
class TestReadonlyMmWarmup:
    """Readonly MM processor warmup must mirror the render path behavior."""

    def test_readonly_processor_apply_called_and_cache_cleared(self):
        """Warmup applies the readonly processor once and clears its cache once."""
        readonly_processor = MagicMock()
        readonly_processor.info.allowed_mm_limits = {"image": 1}

        renderer = _make_renderer_mock({"image": 1})
        renderer._readonly_mm_processor = readonly_processor

        # TimingContext is patched at its defining module for the warmup call.
        timing_patch = patch("vllm.multimodal.processing.TimingContext", autospec=True)
        with timing_patch:
            BaseRenderer.warmup(renderer, ChatParams())

        readonly_processor.apply.assert_called_once()
        readonly_processor.cache.clear_cache.assert_called_once()
......@@ -136,7 +136,7 @@ class OpenAIServingRender:
"Beam search is not supported by the render endpoint"
)
result = await self.render_chat(request)
result = await self.render_chat(request, skip_mm_cache=True)
if isinstance(result, ErrorResponse):
return result
......@@ -184,6 +184,8 @@ class OpenAIServingRender:
async def render_chat(
self,
request: ChatCompletionRequest,
*,
skip_mm_cache: bool = False,
) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
"""Core preprocessing logic for chat requests (no model/engine check).
......@@ -252,7 +254,7 @@ class OpenAIServingRender:
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
skip_mm_cache=True,
skip_mm_cache=skip_mm_cache,
reasoning_parser=self.reasoning_parser,
)
else:
......@@ -276,7 +278,7 @@ class OpenAIServingRender:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
result = await self.render_completion(request)
result = await self.render_completion(request, skip_mm_cache=True)
if isinstance(result, ErrorResponse):
return result
generate_requests: list[GenerateRequest] = []
......@@ -322,6 +324,8 @@ class OpenAIServingRender:
async def render_completion(
self,
request: CompletionRequest,
*,
skip_mm_cache: bool = False,
) -> list[EngineInput] | ErrorResponse:
"""Core preprocessing logic for completion requests (no model/engine check).
......@@ -344,7 +348,7 @@ class OpenAIServingRender:
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
skip_mm_cache=True,
skip_mm_cache=skip_mm_cache,
)
return engine_inputs
......
......@@ -198,6 +198,40 @@ class BaseRenderer(ABC, Generic[_T]):
if self._mm_cache_stats is not None:
self._mm_cache_stats.reset = True
@staticmethod
def _clear_processor_cache(
processor: "BaseMultiModalProcessor | None",
) -> None:
if processor is None:
return
processor_cache = processor.cache
if processor_cache is not None:
processor_cache.clear_cache()
def _warmup_mm_processor(
    self,
    processor: "BaseMultiModalProcessor",
    *,
    log_prefix: str,
) -> None:
    """Push one set of dummy inputs through ``processor`` and log the duration.

    Args:
        processor: Multi-modal processor to exercise.
        log_prefix: Text placed at the start of the completion log message.

    Exceptions raised while building or applying the dummy inputs are not
    caught here; they propagate to the caller.
    """
    from vllm.multimodal.processing import TimingContext

    config = self.model_config
    mm_config = config.get_multimodal_config()
    # Request exactly one dummy item for each modality with a positive limit.
    dummy_counts = {
        modality: 1
        for modality, limit in processor.info.allowed_mm_limits.items()
        if limit > 0
    }

    started = time.perf_counter()
    dummy_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
        seq_len=config.max_model_len,
        mm_counts=dummy_counts,
        mm_options=mm_config.limit_per_prompt,
    )
    processor.apply(dummy_inputs, timing_ctx=TimingContext(enabled=False))
    duration = time.perf_counter() - started
    logger.info("%s warmup completed in %.3fs", log_prefix, duration)
def warmup(self, chat_params: ChatParams) -> None:
"""
Warm up this renderer to avoid first-request latency.
......@@ -221,35 +255,29 @@ class BaseRenderer(ABC, Generic[_T]):
logger.warning("Chat template warmup failed", exc_info=True)
if self.mm_processor:
from vllm.multimodal.processing import TimingContext
model_config = self.model_config
mm_config = model_config.get_multimodal_config()
processor = self.mm_processor
mm_limits = {
k: v for k, v in processor.info.allowed_mm_limits.items() if v > 0
}
try:
logger.debug("Warming up multi-modal processing...")
start_time = time.perf_counter()
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=model_config.max_model_len,
mm_counts=dict.fromkeys(mm_limits, 1),
mm_options=mm_config.limit_per_prompt,
)
_ = processor.apply(
processor_inputs, timing_ctx=TimingContext(enabled=False)
self._warmup_mm_processor(
self.mm_processor,
log_prefix="Multi-modal",
)
elapsed = time.perf_counter() - start_time
logger.info("Multi-modal warmup completed in %.3fs", elapsed)
except Exception:
logger.warning("Multi-modal warmup failed")
finally:
self.clear_mm_cache()
if self._readonly_mm_processor is not None:
try:
logger.debug("Warming up readonly multi-modal processing...")
self._warmup_mm_processor(
self._readonly_mm_processor,
log_prefix="Readonly multi-modal",
)
except Exception:
logger.warning("Readonly multi-modal warmup failed")
finally:
self._clear_processor_cache(self._readonly_mm_processor)
async def clear_mm_cache_async(self) -> None:
"""Serialize clear_mm_cache through the shared executor to avoid
races with concurrent process_inputs on the mm_processor_cache."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment