"vscode:/vscode.git/clone" did not exist on "ea3370b428e1b192d29a4451d439f4ed0895f1f3"
Unverified Commit 07351e08 authored by Chenguang Zheng's avatar Chenguang Zheng Committed by GitHub
Browse files

[Feature] Warm up readonly multimodal processor during renderer startup (#40797)


Signed-off-by: default avatarChenguang ZHENG <645327136@qq.com>
Co-authored-by: default avatarOpenAI Codex <codex@openai.com>
parent 428b988c
...@@ -164,6 +164,58 @@ async def test_chat_error_non_stream(): ...@@ -164,6 +164,58 @@ async def test_chat_error_non_stream():
await serving_chat.create_chat_completion(request) await serving_chat.create_chat_completion(request)
@pytest.mark.asyncio
async def test_openai_chat_keeps_mm_cache_for_engine_execution():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
)
result = await serving_chat.render_chat_request(request)
assert isinstance(result, tuple)
assert (
serving_chat.openai_serving_render.preprocess_chat.call_args.kwargs[
"skip_mm_cache"
]
is False
)
@pytest.mark.asyncio
async def test_renderer_only_chat_request_skips_mm_cache():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
)
result = await serving_chat.openai_serving_render.render_chat_request(request)
assert result.token_ids == [1, 2, 3]
assert (
serving_chat.openai_serving_render.preprocess_chat.call_args.kwargs[
"skip_mm_cache"
]
is True
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_error_stream(): async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)""" """test finish_reason='error' returns 500 InternalServerError (streaming)"""
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
...@@ -148,6 +148,66 @@ async def test_completion_error_non_stream(): ...@@ -148,6 +148,66 @@ async def test_completion_error_non_stream():
await serving_completion.create_completion(request) await serving_completion.create_completion(request)
@pytest.mark.asyncio
async def test_openai_completion_keeps_mm_cache_for_engine_execution():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
serving_completion.openai_serving_render.preprocess_completion = AsyncMock(
return_value=[{"prompt_token_ids": [1, 2, 3]}]
)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
)
result = await serving_completion.render_completion_request(request)
assert isinstance(result, list)
assert (
serving_completion.openai_serving_render.preprocess_completion.call_args.kwargs[
"skip_mm_cache"
]
is False
)
@pytest.mark.asyncio
async def test_renderer_only_completion_request_skips_mm_cache():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
serving_completion.openai_serving_render.preprocess_completion = AsyncMock(
return_value=[{"prompt_token_ids": [1, 2, 3]}]
)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
)
result = await serving_completion.openai_serving_render.render_completion_request(
request
)
assert isinstance(result, list)
assert (
serving_completion.openai_serving_render.preprocess_completion.call_args.kwargs[
"skip_mm_cache"
]
is True
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completion_error_stream(): async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)""" """test finish_reason='error' returns 500 InternalServerError (streaming)"""
......
...@@ -12,7 +12,10 @@ from vllm.config.multimodal import MultiModalConfig ...@@ -12,7 +12,10 @@ from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.engine.protocol import StreamOptions from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import ServingTokens from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
...@@ -164,6 +167,36 @@ def _parse_sse_chunks(chunks: list[str]) -> list[Any]: ...@@ -164,6 +167,36 @@ def _parse_sse_chunks(chunks: list[str]) -> list[Any]:
return parsed return parsed
@pytest.mark.asyncio
async def test_serve_tokens_skips_mm_cache_for_remote_engine_execution():
engine = _mock_engine()
async def mock_generate(*args, **kwargs):
yield _make_request_output(
"req-1", token_ids=[10], finish_reason="stop", finished=True
)
engine.generate = MagicMock(side_effect=mock_generate)
serving = _build_serving_tokens(engine)
request = GenerateRequest(
token_ids=[1, 2, 3],
sampling_params=SamplingParams(max_tokens=1),
model=MODEL_NAME,
stream=False,
)
response = await serving.serve_tokens(request)
assert isinstance(response, GenerateResponse)
assert (
serving.openai_serving_render.preprocess_completion.call_args.kwargs[
"skip_mm_cache"
]
is True
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_basic(): async def test_stream_basic():
"""Streaming returns SSE chunks with correct token_ids and ends with [DONE].""" """Streaming returns SSE chunks with correct token_ids and ends with [DONE]."""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeChatRequest,
TokenizeCompletionRequest,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
model = MODEL_NAME
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
hf_text_config = MockHFConfig()
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_tokenization(engine: AsyncLLM) -> OpenAIServingTokenization:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_render = OpenAIServingRender(
model_config=engine.model_config,
renderer=engine.renderer,
model_registry=models.registry,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
return OpenAIServingTokenization(
engine,
models,
openai_serving_render=serving_render,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
@pytest.mark.asyncio
async def test_tokenize_chat_skips_mm_cache_for_renderer_only_path():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.renderer = MagicMock()
serving = _build_serving_tokenization(mock_engine)
serving.openai_serving_render.preprocess_chat = AsyncMock(
return_value=(
[{"role": "user", "content": "Test"}],
[{"prompt_token_ids": [1, 2, 3]}],
)
)
request = TokenizeChatRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
)
response = await serving.create_tokenize(request, MagicMock(headers={}))
assert response.tokens == [1, 2, 3]
assert (
serving.openai_serving_render.preprocess_chat.call_args.kwargs["skip_mm_cache"]
is True
)
@pytest.mark.asyncio
async def test_tokenize_completion_skips_mm_cache_for_renderer_only_path():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.renderer = MagicMock()
serving = _build_serving_tokenization(mock_engine)
serving.openai_serving_render.preprocess_completion = AsyncMock(
return_value=[{"prompt_token_ids": [1, 2, 3]}]
)
request = TokenizeCompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
)
response = await serving.create_tokenize(request, MagicMock(headers={}))
assert response.tokens == [1, 2, 3]
assert (
serving.openai_serving_render.preprocess_completion.call_args.kwargs[
"skip_mm_cache"
]
is True
)
...@@ -34,6 +34,14 @@ def _make_renderer_mock(mm_limits: dict[str, int]) -> MagicMock: ...@@ -34,6 +34,14 @@ def _make_renderer_mock(mm_limits: dict[str, int]) -> MagicMock:
mm_processor = MagicMock() mm_processor = MagicMock()
mm_processor.info.allowed_mm_limits = mm_limits mm_processor.info.allowed_mm_limits = mm_limits
renderer.mm_processor = mm_processor renderer.mm_processor = mm_processor
renderer._readonly_mm_processor = None
renderer._warmup_mm_processor = BaseRenderer._warmup_mm_processor.__get__(
renderer, BaseRenderer
)
renderer._clear_processor_cache = BaseRenderer._clear_processor_cache
renderer.clear_mm_cache = MagicMock()
renderer.model_config.max_model_len = 128
renderer.model_config.get_multimodal_config.return_value.limit_per_prompt = {}
return renderer return renderer
...@@ -109,3 +117,19 @@ class TestMmWarmupSkippedWhenNoProcessor: ...@@ -109,3 +117,19 @@ class TestMmWarmupSkippedWhenNoProcessor:
BaseRenderer.warmup(renderer, ChatParams()) BaseRenderer.warmup(renderer, ChatParams())
renderer.model_config.get_multimodal_config.assert_not_called() renderer.model_config.get_multimodal_config.assert_not_called()
class TestReadonlyMmWarmup:
"""Readonly MM processor warmup must mirror the render path behavior."""
def test_readonly_processor_apply_called_and_cache_cleared(self):
renderer = _make_renderer_mock({"image": 1})
readonly_mm_processor = MagicMock()
readonly_mm_processor.info.allowed_mm_limits = {"image": 1}
renderer._readonly_mm_processor = readonly_mm_processor
with patch("vllm.multimodal.processing.TimingContext", autospec=True):
BaseRenderer.warmup(renderer, ChatParams())
readonly_mm_processor.apply.assert_called_once()
readonly_mm_processor.cache.clear_cache.assert_called_once()
...@@ -136,7 +136,7 @@ class OpenAIServingRender: ...@@ -136,7 +136,7 @@ class OpenAIServingRender:
"Beam search is not supported by the render endpoint" "Beam search is not supported by the render endpoint"
) )
result = await self.render_chat(request) result = await self.render_chat(request, skip_mm_cache=True)
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
...@@ -184,6 +184,8 @@ class OpenAIServingRender: ...@@ -184,6 +184,8 @@ class OpenAIServingRender:
async def render_chat( async def render_chat(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
*,
skip_mm_cache: bool = False,
) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse: ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
"""Core preprocessing logic for chat requests (no model/engine check). """Core preprocessing logic for chat requests (no model/engine check).
...@@ -252,7 +254,7 @@ class OpenAIServingRender: ...@@ -252,7 +254,7 @@ class OpenAIServingRender:
default_template_kwargs=self.default_chat_template_kwargs, default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=tool_parser, tool_parser=tool_parser,
skip_mm_cache=True, skip_mm_cache=skip_mm_cache,
reasoning_parser=self.reasoning_parser, reasoning_parser=self.reasoning_parser,
) )
else: else:
...@@ -276,7 +278,7 @@ class OpenAIServingRender: ...@@ -276,7 +278,7 @@ class OpenAIServingRender:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
result = await self.render_completion(request) result = await self.render_completion(request, skip_mm_cache=True)
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
generate_requests: list[GenerateRequest] = [] generate_requests: list[GenerateRequest] = []
...@@ -322,6 +324,8 @@ class OpenAIServingRender: ...@@ -322,6 +324,8 @@ class OpenAIServingRender:
async def render_completion( async def render_completion(
self, self,
request: CompletionRequest, request: CompletionRequest,
*,
skip_mm_cache: bool = False,
) -> list[EngineInput] | ErrorResponse: ) -> list[EngineInput] | ErrorResponse:
"""Core preprocessing logic for completion requests (no model/engine check). """Core preprocessing logic for completion requests (no model/engine check).
...@@ -344,7 +348,7 @@ class OpenAIServingRender: ...@@ -344,7 +348,7 @@ class OpenAIServingRender:
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
skip_mm_cache=True, skip_mm_cache=skip_mm_cache,
) )
return engine_inputs return engine_inputs
......
...@@ -198,6 +198,40 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -198,6 +198,40 @@ class BaseRenderer(ABC, Generic[_T]):
if self._mm_cache_stats is not None: if self._mm_cache_stats is not None:
self._mm_cache_stats.reset = True self._mm_cache_stats.reset = True
@staticmethod
def _clear_processor_cache(
processor: "BaseMultiModalProcessor | None",
) -> None:
if processor is None:
return
processor_cache = processor.cache
if processor_cache is not None:
processor_cache.clear_cache()
def _warmup_mm_processor(
self,
processor: "BaseMultiModalProcessor",
*,
log_prefix: str,
) -> None:
from vllm.multimodal.processing import TimingContext
model_config = self.model_config
mm_config = model_config.get_multimodal_config()
mm_limits = {k: v for k, v in processor.info.allowed_mm_limits.items() if v > 0}
start_time = time.perf_counter()
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=model_config.max_model_len,
mm_counts=dict.fromkeys(mm_limits, 1),
mm_options=mm_config.limit_per_prompt,
)
_ = processor.apply(processor_inputs, timing_ctx=TimingContext(enabled=False))
elapsed = time.perf_counter() - start_time
logger.info("%s warmup completed in %.3fs", log_prefix, elapsed)
def warmup(self, chat_params: ChatParams) -> None: def warmup(self, chat_params: ChatParams) -> None:
""" """
Warm up this renderer to avoid first-request latency. Warm up this renderer to avoid first-request latency.
...@@ -221,35 +255,29 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -221,35 +255,29 @@ class BaseRenderer(ABC, Generic[_T]):
logger.warning("Chat template warmup failed", exc_info=True) logger.warning("Chat template warmup failed", exc_info=True)
if self.mm_processor: if self.mm_processor:
from vllm.multimodal.processing import TimingContext
model_config = self.model_config
mm_config = model_config.get_multimodal_config()
processor = self.mm_processor
mm_limits = {
k: v for k, v in processor.info.allowed_mm_limits.items() if v > 0
}
try: try:
logger.debug("Warming up multi-modal processing...") logger.debug("Warming up multi-modal processing...")
start_time = time.perf_counter() self._warmup_mm_processor(
self.mm_processor,
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs( log_prefix="Multi-modal",
seq_len=model_config.max_model_len,
mm_counts=dict.fromkeys(mm_limits, 1),
mm_options=mm_config.limit_per_prompt,
)
_ = processor.apply(
processor_inputs, timing_ctx=TimingContext(enabled=False)
) )
elapsed = time.perf_counter() - start_time
logger.info("Multi-modal warmup completed in %.3fs", elapsed)
except Exception: except Exception:
logger.warning("Multi-modal warmup failed") logger.warning("Multi-modal warmup failed")
finally: finally:
self.clear_mm_cache() self.clear_mm_cache()
if self._readonly_mm_processor is not None:
try:
logger.debug("Warming up readonly multi-modal processing...")
self._warmup_mm_processor(
self._readonly_mm_processor,
log_prefix="Readonly multi-modal",
)
except Exception:
logger.warning("Readonly multi-modal warmup failed")
finally:
self._clear_processor_cache(self._readonly_mm_processor)
async def clear_mm_cache_async(self) -> None: async def clear_mm_cache_async(self) -> None:
"""Serialize clear_mm_cache through the shared executor to avoid """Serialize clear_mm_cache through the shared executor to avoid
races with concurrent process_inputs on the mm_processor_cache.""" races with concurrent process_inputs on the mm_processor_cache."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment