Unverified Commit 6abb0454 authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Perf] Optimize the performance of structured output + reasoning (#33557)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent db6f71d4
......@@ -72,6 +72,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
......@@ -132,7 +133,7 @@ class OpenAIServingChat(OpenAIServing):
self.logits_processors = self.model_config.logits_processors
# set up reasoning parser
self.reasoning_parser = ParserManager.get_reasoning_parser(
self.reasoning_parser_cls = ParserManager.get_reasoning_parser(
reasoning_parser_name=reasoning_parser
)
# set up tool use
......@@ -330,6 +331,24 @@ class OpenAIServingChat(OpenAIServing):
for the API specification. This API mimics the OpenAI
Chat Completion API.
"""
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
reasoning_parser: ReasoningParser | None = None
try:
if self.reasoning_parser_cls:
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
result = await self.render_chat_request(request)
if isinstance(result, ErrorResponse):
return result
......@@ -427,7 +446,12 @@ class OpenAIServingChat(OpenAIServing):
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
reasoning_ended = None
if reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
engine_request.prompt_token_ids or [] # type: ignore[attr-defined]
)
engine_request.reasoning_ended = reasoning_ended
generator = self.engine_client.generate(
engine_request,
sampling_params,
......@@ -447,10 +471,6 @@ class OpenAIServingChat(OpenAIServing):
assert len(generators) == 1
(result_generator,) = generators
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream:
return self.chat_completion_stream_generator(
request,
......@@ -460,6 +480,7 @@ class OpenAIServingChat(OpenAIServing):
conversation,
tokenizer,
request_metadata,
reasoning_parser,
)
try:
......@@ -471,6 +492,7 @@ class OpenAIServingChat(OpenAIServing):
conversation,
tokenizer,
request_metadata,
reasoning_parser,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
......@@ -630,6 +652,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage],
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
......@@ -673,7 +696,7 @@ class OpenAIServingChat(OpenAIServing):
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or self.reasoning_parser:
if tool_choice_auto or reasoning_parser:
# These are only required in "auto" tool choice case
all_previous_token_ids = [[]] * num_choices
# For reasoning parser and tool call all enabled
......@@ -683,28 +706,6 @@ class OpenAIServingChat(OpenAIServing):
else:
all_previous_token_ids = None
try:
if self.reasoning_parser:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
......@@ -826,7 +827,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = tool_parsers[i]
if (
self.reasoning_parser
reasoning_parser
and res.prompt_token_ids
and prompt_is_reasoning_end_arr[i] is None
):
......@@ -888,7 +889,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message: DeltaMessage | None
# just update previous_texts and previous_token_ids
if tool_choice_auto or self.reasoning_parser:
if tool_choice_auto or reasoning_parser:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
......@@ -915,7 +916,7 @@ class OpenAIServingChat(OpenAIServing):
# handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name:
if (
self.reasoning_parser
reasoning_parser
and not reasoning_end_arr[i]
and not reasoning_parser.is_reasoning_end(
previous_token_ids
......@@ -952,7 +953,7 @@ class OpenAIServingChat(OpenAIServing):
current_text = ""
else:
# Just to add remaining `content`
if self.reasoning_parser:
if reasoning_parser:
delta_text = previous_text + delta_text
current_text = ""
......@@ -998,13 +999,13 @@ class OpenAIServingChat(OpenAIServing):
output_token_ids = as_list(output.token_ids)
if (
self.reasoning_parser is not None
reasoning_parser is not None
and not reasoning_end_arr[i]
and prompt_is_reasoning_end_arr[i]
):
reasoning_end_arr[i] = True
if self.reasoning_parser and not reasoning_end_arr[i]:
if reasoning_parser and not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
......@@ -1047,9 +1048,8 @@ class OpenAIServingChat(OpenAIServing):
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.reasoning_parser:
elif tool_choice_auto and reasoning_parser:
assert tool_parser is not None
assert reasoning_parser is not None
assert added_content_delta_arr is not None
assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids)
......@@ -1130,7 +1130,7 @@ class OpenAIServingChat(OpenAIServing):
tools_streamed[i] = True
# when only reasoning
elif self.reasoning_parser:
elif reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
......@@ -1144,9 +1144,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message = DeltaMessage(content=delta_text)
# update the previous values for the next iteration
if (
tool_choice_auto or self.reasoning_parser
) and not self.use_harmony:
if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_texts[i] = current_text
......@@ -1400,6 +1398,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage],
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
......@@ -1494,25 +1493,7 @@ class OpenAIServingChat(OpenAIServing):
choices.append(choice_data)
continue
if self.reasoning_parser:
try:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
if reasoning_parser:
# If the reasoning parser is enabled,
# tool calls are extracted exclusively from the content.
reasoning, content = reasoning_parser.extract_reasoning(
......
......@@ -83,6 +83,8 @@ class EngineCoreRequest(
# Used in outputs and to support abort(req_id, internal=False).
external_req_id: str | None = None
reasoning_ended: bool | None = None
@property
def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling)."""
......
......@@ -74,6 +74,7 @@ class Request:
trace_headers: Mapping[str, str] | None = None,
block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
resumable: bool = False,
reasoning_ended: bool | None = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
......@@ -86,6 +87,8 @@ class Request:
self.structured_output_request = StructuredOutputRequest.from_sampling_params(
sampling_params
)
if self.structured_output_request is not None:
self.structured_output_request.reasoning_ended = reasoning_ended
self.arrival_time = arrival_time if arrival_time is not None else time.time()
self.status = RequestStatus.WAITING
......@@ -195,6 +198,7 @@ class Request:
trace_headers=request.trace_headers,
block_hasher=block_hasher,
resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
)
def append_output_token_ids(
......
......@@ -284,12 +284,15 @@ class StructuredOutputManager:
# NOTE (Hanchen) if enable_in_reasoning is True, it means that
# the model needs to be constrained in reasoning. So we should always
# enable the bitmask filling.
if self.reasoner is not None:
if self.enable_in_reasoning:
return True
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
# This should be removed here, but since `openai_gptoss`
# is an independent code path, it is kept for now.
# After unifying the `openai_gptoss` and non-`openai_gptoss` styles,
# it can be removed.
request.structured_output_request.reasoning_ended = (
self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment