Unverified Commit 5a2d420c authored by Sergey Zinchenko's avatar Sergey Zinchenko Committed by GitHub
Browse files

[Bugfix] Use dedicated MM processor cache in /tokenize to prevent sender-cache pollution (#38545)


Signed-off-by: default avatarSergey Zinchenko <sergey.zinchenko.rnd@gmail.com>
parent 5f96f9af
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Regression test: calling ``/tokenize`` with multimodal data followed by
``/v1/chat/completions`` with the same data must not cause an error.
Ensures that the ``/tokenize`` endpoint does not pollute internal caches
(e.g. multimodal feature caches) and that a subsequent
``/v1/chat/completions`` request with the same multimodal payload
completes successfully.
"""
import json
import openai
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--max-num-seqs",
"5",
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"image": 1}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_tokenize_then_chat_completion_with_image(
client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
local_asset_server,
):
"""Tokenize a multimodal message, then send the same message to chat
completions. The chat completion must succeed (not 500)."""
image_url = local_asset_server.url_for("stop_sign.jpg")
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Describe this image briefly."},
],
}
]
tok_resp = requests.post(
server.url_for("tokenize"),
json={"model": MODEL_NAME, "messages": messages},
)
tok_resp.raise_for_status()
tok_data = tok_resp.json()
assert tok_data["count"] > 0, "Tokenization must return tokens"
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
temperature=0.0,
)
assert chat_completion.choices[0].message.content, (
"Chat completion must produce non-empty content after tokenize"
)
...@@ -451,6 +451,8 @@ class OpenAIServingRender: ...@@ -451,6 +451,8 @@ class OpenAIServingRender:
request: Any, request: Any,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
*,
skip_mm_cache: bool = False,
) -> list[EngineInput]: ) -> list[EngineInput]:
"""Copied from OpenAIServing._preprocess_completion.""" """Copied from OpenAIServing._preprocess_completion."""
prompts = list[SingletonPrompt | bytes]() prompts = list[SingletonPrompt | bytes]()
...@@ -458,12 +460,14 @@ class OpenAIServingRender: ...@@ -458,12 +460,14 @@ class OpenAIServingRender:
prompts.extend(prompt_to_seq(prompt_embeds)) prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None: if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input)) prompts.extend(prompt_to_seq(prompt_input))
return await self.preprocess_cmpl(request, prompts) return await self.preprocess_cmpl(request, prompts, skip_mm_cache=skip_mm_cache)
async def preprocess_cmpl( async def preprocess_cmpl(
self, self,
request: Any, request: Any,
prompts: Sequence[PromptType | bytes], prompts: Sequence[PromptType | bytes],
*,
skip_mm_cache: bool = False,
) -> list[EngineInput]: ) -> list[EngineInput]:
"""Copied from OpenAIServing._preprocess_cmpl.""" """Copied from OpenAIServing._preprocess_cmpl."""
renderer = self.renderer renderer = self.renderer
...@@ -487,6 +491,7 @@ class OpenAIServingRender: ...@@ -487,6 +491,7 @@ class OpenAIServingRender:
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
}, },
skip_mm_cache=skip_mm_cache,
) )
async def preprocess_chat( async def preprocess_chat(
...@@ -498,6 +503,8 @@ class OpenAIServingRender: ...@@ -498,6 +503,8 @@ class OpenAIServingRender:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None, tool_parser: type[ToolParser] | None = None,
*,
skip_mm_cache: bool = False,
) -> tuple[list[ConversationMessage], list[EngineInput]]: ) -> tuple[list[ConversationMessage], list[EngineInput]]:
"""Copied from OpenAIServing._preprocess_chat.""" """Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer renderer = self.renderer
...@@ -529,6 +536,7 @@ class OpenAIServingRender: ...@@ -529,6 +536,7 @@ class OpenAIServingRender:
for k in ("mm_processor_kwargs", "cache_salt") for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None if (v := getattr(request, k, None)) is not None
}, },
skip_mm_cache=skip_mm_cache,
) )
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
......
...@@ -86,12 +86,14 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -86,12 +86,14 @@ class OpenAIServingTokenization(OpenAIServing):
default_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs, default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
skip_mm_cache=True,
) )
else: else:
engine_inputs = await self.openai_serving_render.preprocess_completion( engine_inputs = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=None, prompt_embeds=None,
skip_mm_cache=True,
) )
input_ids: list[int] = [] input_ids: list[int] = []
......
...@@ -97,6 +97,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -97,6 +97,7 @@ class BaseRenderer(ABC, Generic[_T]):
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
self.mm_processor: BaseMultiModalProcessor | None = None self.mm_processor: BaseMultiModalProcessor | None = None
self._readonly_mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None self._mm_cache_stats: MultiModalCacheStats | None = None
self._clear_mm_cache_async = make_async( self._clear_mm_cache_async = make_async(
self.clear_mm_cache, executor=self._executor self.clear_mm_cache, executor=self._executor
...@@ -124,6 +125,19 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -124,6 +125,19 @@ class BaseRenderer(ABC, Generic[_T]):
if mm_processor_cache: if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats() self._mm_cache_stats = MultiModalCacheStats()
# A second processor with its own processor-only cache.
# Used by the tokenize endpoint so that tokenize-only
# requests don't pollute the sender cache.
ro_cache = mm_registry.processor_only_cache_from_config(config)
if ro_cache is not None:
ro_tokenizer = copy.deepcopy(tokenizer)
with set_default_torch_num_threads():
self._readonly_mm_processor = mm_registry.create_processor(
config.model_config,
tokenizer=ro_tokenizer,
cache=ro_cache,
)
# This is used to generate internal request ID for MM processing # This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core # It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter() self._mm_req_counter = AtomicCounter()
...@@ -625,9 +639,14 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -625,9 +639,14 @@ class BaseRenderer(ABC, Generic[_T]):
mm_uuids: MultiModalUUIDDict | None, mm_uuids: MultiModalUUIDDict | None,
mm_processor_kwargs: Mapping[str, object] | None, mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None, tokenization_kwargs: dict[str, Any] | None,
*,
skip_mm_cache: bool = False,
) -> "MultiModalInput": ) -> "MultiModalInput":
mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}" mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}"
if skip_mm_cache and self._readonly_mm_processor is not None:
mm_processor = self._readonly_mm_processor
else:
mm_processor = self.get_mm_processor() mm_processor = self.get_mm_processor()
mm_data_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
...@@ -656,6 +675,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -656,6 +675,8 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_tokens( def _process_tokens(
self, self,
prompt: TokensPrompt, prompt: TokensPrompt,
*,
skip_mm_cache: bool = False,
) -> TokensInput | MultiModalInput: ) -> TokensInput | MultiModalInput:
"""Process token inputs, with multimodal preprocessing offloaded """Process token inputs, with multimodal preprocessing offloaded
to the shared thread pool in the async variant. to the shared thread pool in the async variant.
...@@ -670,6 +691,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -670,6 +691,7 @@ class BaseRenderer(ABC, Generic[_T]):
mm_processor_kwargs=prompt.get("mm_processor_kwargs"), mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, # Tokenization already done in Step 2 tokenization_kwargs=None, # Tokenization already done in Step 2
mm_uuids=prompt.get("multi_modal_uuids"), mm_uuids=prompt.get("multi_modal_uuids"),
skip_mm_cache=skip_mm_cache,
) )
else: else:
engine_input = tokens_input(prompt_token_ids) engine_input = tokens_input(prompt_token_ids)
...@@ -712,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -712,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
async def _process_tokens_async( async def _process_tokens_async(
self, self,
prompt: TokensPrompt, prompt: TokensPrompt,
*,
skip_mm_cache: bool = False,
) -> TokensInput | MultiModalInput: ) -> TokensInput | MultiModalInput:
prompt_token_ids = prompt["prompt_token_ids"] prompt_token_ids = prompt["prompt_token_ids"]
...@@ -723,6 +747,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -723,6 +747,7 @@ class BaseRenderer(ABC, Generic[_T]):
mm_processor_kwargs=prompt.get("mm_processor_kwargs"), mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None, tokenization_kwargs=None,
mm_uuids=prompt.get("multi_modal_uuids"), mm_uuids=prompt.get("multi_modal_uuids"),
skip_mm_cache=skip_mm_cache,
) )
else: else:
engine_input = tokens_input(prompt_token_ids) engine_input = tokens_input(prompt_token_ids)
...@@ -734,24 +759,33 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -734,24 +759,33 @@ class BaseRenderer(ABC, Generic[_T]):
return engine_input return engine_input
def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput: def _process_singleton(
self,
prompt: SingletonTokPrompt,
*,
skip_mm_cache: bool = False,
) -> SingletonInput:
if "prompt_embeds" in prompt: if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type] return self._process_embeds(prompt) # type: ignore[arg-type]
return self._process_tokens(prompt) # type: ignore[arg-type] return self._process_tokens(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
async def _process_singleton_async( async def _process_singleton_async(
self, self,
prompt: SingletonTokPrompt, prompt: SingletonTokPrompt,
*,
skip_mm_cache: bool = False,
) -> SingletonInput: ) -> SingletonInput:
if "prompt_embeds" in prompt: if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type] return self._process_embeds(prompt) # type: ignore[arg-type]
return await self._process_tokens_async(prompt) # type: ignore[arg-type] return await self._process_tokens_async(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
def _process_enc_dec( def _process_enc_dec(
self, self,
prompt: EncoderDecoderTokPrompt, prompt: EncoderDecoderTokPrompt,
*,
skip_mm_cache: bool = False,
) -> EncoderDecoderInput: ) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"] enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
...@@ -764,9 +798,13 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -764,9 +798,13 @@ class BaseRenderer(ABC, Generic[_T]):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_input( return build_enc_dec_input(
encoder_input=self._process_singleton(enc_prompt), encoder_input=self._process_singleton(
enc_prompt, skip_mm_cache=skip_mm_cache
),
decoder_input=( decoder_input=(
None if dec_prompt is None else self._process_singleton(dec_prompt) None
if dec_prompt is None
else self._process_singleton(dec_prompt, skip_mm_cache=skip_mm_cache)
), ),
decoder_start_token_id=self.get_dec_start_token_id(), decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token, skip_decoder_start_token=skip_decoder_start_token,
...@@ -775,16 +813,20 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -775,16 +813,20 @@ class BaseRenderer(ABC, Generic[_T]):
async def _process_enc_dec_async( async def _process_enc_dec_async(
self, self,
prompt: EncoderDecoderTokPrompt, prompt: EncoderDecoderTokPrompt,
*,
skip_mm_cache: bool = False,
) -> EncoderDecoderInput: ) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"] enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
encoder_input, decoder_input = await asyncio.gather( encoder_input, decoder_input = await asyncio.gather(
self._process_singleton_async(enc_prompt), self._process_singleton_async(enc_prompt, skip_mm_cache=skip_mm_cache),
( (
asyncio.sleep(0) asyncio.sleep(0)
if dec_prompt is None if dec_prompt is None
else self._process_singleton_async(dec_prompt) else self._process_singleton_async(
dec_prompt, skip_mm_cache=skip_mm_cache
)
), ),
) )
...@@ -794,27 +836,40 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -794,27 +836,40 @@ class BaseRenderer(ABC, Generic[_T]):
decoder_start_token_id=self.get_dec_start_token_id(), decoder_start_token_id=self.get_dec_start_token_id(),
) )
def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput: def process_for_engine(
self,
prompt: TokPrompt,
arrival_time: float,
*,
skip_mm_cache: bool = False,
) -> EngineInput:
engine_input: EngineInput engine_input: EngineInput
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
engine_input = self._process_enc_dec(prompt) # type: ignore[arg-type] engine_input = self._process_enc_dec(prompt, skip_mm_cache=skip_mm_cache) # type: ignore[arg-type]
else: else:
engine_input = self._process_singleton(prompt) engine_input = self._process_singleton(prompt, skip_mm_cache=skip_mm_cache)
engine_input["arrival_time"] = arrival_time engine_input["arrival_time"] = arrival_time
return engine_input return engine_input
async def process_for_engine_async( async def process_for_engine_async(
self, prompt: TokPrompt, arrival_time: float self,
prompt: TokPrompt,
arrival_time: float,
*,
skip_mm_cache: bool = False,
) -> EngineInput: ) -> EngineInput:
engine_input: EngineInput engine_input: EngineInput
if "encoder_prompt" in prompt: if "encoder_prompt" in prompt:
engine_input = await self._process_enc_dec_async( engine_input = await self._process_enc_dec_async(
prompt # type: ignore[arg-type] prompt, # type: ignore[arg-type]
skip_mm_cache=skip_mm_cache,
) )
else: else:
engine_input = await self._process_singleton_async(prompt) engine_input = await self._process_singleton_async(
prompt, skip_mm_cache=skip_mm_cache
)
engine_input["arrival_time"] = arrival_time engine_input["arrival_time"] = arrival_time
...@@ -827,6 +882,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -827,6 +882,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
): ):
arrival_time = time.time() arrival_time = time.time()
...@@ -838,7 +894,10 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -838,7 +894,10 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts] return [
self.process_for_engine(prompt, arrival_time, skip_mm_cache=skip_mm_cache)
for prompt in tok_prompts
]
async def render_cmpl_async( async def render_cmpl_async(
self, self,
...@@ -846,6 +905,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -846,6 +905,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
): ):
arrival_time = time.time() arrival_time = time.time()
...@@ -858,7 +918,12 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -858,7 +918,12 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
return await asyncio.gather( return await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts) *(
self.process_for_engine_async(
p, arrival_time, skip_mm_cache=skip_mm_cache
)
for p in tok_prompts
)
) )
def render_chat( def render_chat(
...@@ -868,6 +933,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -868,6 +933,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
): ):
arrival_time = time.time() arrival_time = time.time()
...@@ -890,7 +956,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -890,7 +956,8 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = [ eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts self.process_for_engine(prompt, arrival_time, skip_mm_cache=skip_mm_cache)
for prompt in tok_prompts
] ]
return out_conversations, eng_prompts return out_conversations, eng_prompts
...@@ -902,6 +969,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -902,6 +969,7 @@ class BaseRenderer(ABC, Generic[_T]):
tok_params: TokenizeParams | None = None, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
skip_mm_cache: bool = False,
): ):
arrival_time = time.time() arrival_time = time.time()
...@@ -924,7 +992,12 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -924,7 +992,12 @@ class BaseRenderer(ABC, Generic[_T]):
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
eng_prompts = await asyncio.gather( eng_prompts = await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts) *(
self.process_for_engine_async(
p, arrival_time, skip_mm_cache=skip_mm_cache
)
for p in tok_prompts
)
) )
return out_conversations, eng_prompts return out_conversations, eng_prompts
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment