Unverified Commit 6abb0454 authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Perf] Optimize the performance of structured output + reasoning (#33557)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent db6f71d4
...@@ -72,6 +72,7 @@ from vllm.logger import init_logger ...@@ -72,6 +72,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import ( from vllm.tokenizers.mistral import (
...@@ -132,7 +133,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -132,7 +133,7 @@ class OpenAIServingChat(OpenAIServing):
self.logits_processors = self.model_config.logits_processors self.logits_processors = self.model_config.logits_processors
# set up reasoning parser # set up reasoning parser
self.reasoning_parser = ParserManager.get_reasoning_parser( self.reasoning_parser_cls = ParserManager.get_reasoning_parser(
reasoning_parser_name=reasoning_parser reasoning_parser_name=reasoning_parser
) )
# set up tool use # set up tool use
...@@ -330,6 +331,24 @@ class OpenAIServingChat(OpenAIServing): ...@@ -330,6 +331,24 @@ class OpenAIServingChat(OpenAIServing):
for the API specification. This API mimics the OpenAI for the API specification. This API mimics the OpenAI
Chat Completion API. Chat Completion API.
""" """
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
reasoning_parser: ReasoningParser | None = None
try:
if self.reasoning_parser_cls:
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
result = await self.render_chat_request(request) result = await self.render_chat_request(request)
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
...@@ -427,7 +446,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -427,7 +446,12 @@ class OpenAIServingChat(OpenAIServing):
priority=request.priority, priority=request.priority,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )
reasoning_ended = None
if reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
engine_request.prompt_token_ids or [] # type: ignore[attr-defined]
)
engine_request.reasoning_ended = reasoning_ended
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_request, engine_request,
sampling_params, sampling_params,
...@@ -447,10 +471,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -447,10 +471,6 @@ class OpenAIServingChat(OpenAIServing):
assert len(generators) == 1 assert len(generators) == 1
(result_generator,) = generators (result_generator,) = generators
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, request,
...@@ -460,6 +480,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -460,6 +480,7 @@ class OpenAIServingChat(OpenAIServing):
conversation, conversation,
tokenizer, tokenizer,
request_metadata, request_metadata,
reasoning_parser,
) )
try: try:
...@@ -471,6 +492,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -471,6 +492,7 @@ class OpenAIServingChat(OpenAIServing):
conversation, conversation,
tokenizer, tokenizer,
request_metadata, request_metadata,
reasoning_parser,
) )
except GenerationError as e: except GenerationError as e:
return self._convert_generation_error_to_response(e) return self._convert_generation_error_to_response(e)
...@@ -630,6 +652,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -630,6 +652,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
...@@ -673,7 +696,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -673,7 +696,7 @@ class OpenAIServingChat(OpenAIServing):
# Only one of these will be used, thus previous_texts and # Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration. # all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or self.reasoning_parser: if tool_choice_auto or reasoning_parser:
# These are only required in "auto" tool choice case # These are only required in "auto" tool choice case
all_previous_token_ids = [[]] * num_choices all_previous_token_ids = [[]] * num_choices
# For reasoning parser and tool call all enabled # For reasoning parser and tool call all enabled
...@@ -683,28 +706,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -683,28 +706,6 @@ class OpenAIServingChat(OpenAIServing):
else: else:
all_previous_token_ids = None all_previous_token_ids = None
try:
if self.reasoning_parser:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
# Prepare the tool parser if it's needed # Prepare the tool parser if it's needed
try: try:
if tool_choice_auto and self.tool_parser: if tool_choice_auto and self.tool_parser:
...@@ -826,7 +827,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -826,7 +827,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = tool_parsers[i] tool_parser = tool_parsers[i]
if ( if (
self.reasoning_parser reasoning_parser
and res.prompt_token_ids and res.prompt_token_ids
and prompt_is_reasoning_end_arr[i] is None and prompt_is_reasoning_end_arr[i] is None
): ):
...@@ -888,7 +889,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -888,7 +889,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message: DeltaMessage | None delta_message: DeltaMessage | None
# just update previous_texts and previous_token_ids # just update previous_texts and previous_token_ids
if tool_choice_auto or self.reasoning_parser: if tool_choice_auto or reasoning_parser:
assert previous_texts is not None assert previous_texts is not None
assert all_previous_token_ids is not None assert all_previous_token_ids is not None
previous_text = previous_texts[i] previous_text = previous_texts[i]
...@@ -915,7 +916,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -915,7 +916,7 @@ class OpenAIServingChat(OpenAIServing):
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name: elif tool_choice_function_name:
if ( if (
self.reasoning_parser reasoning_parser
and not reasoning_end_arr[i] and not reasoning_end_arr[i]
and not reasoning_parser.is_reasoning_end( and not reasoning_parser.is_reasoning_end(
previous_token_ids previous_token_ids
...@@ -952,7 +953,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -952,7 +953,7 @@ class OpenAIServingChat(OpenAIServing):
current_text = "" current_text = ""
else: else:
# Just to add remaining `content` # Just to add remaining `content`
if self.reasoning_parser: if reasoning_parser:
delta_text = previous_text + delta_text delta_text = previous_text + delta_text
current_text = "" current_text = ""
...@@ -998,13 +999,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -998,13 +999,13 @@ class OpenAIServingChat(OpenAIServing):
output_token_ids = as_list(output.token_ids) output_token_ids = as_list(output.token_ids)
if ( if (
self.reasoning_parser is not None reasoning_parser is not None
and not reasoning_end_arr[i] and not reasoning_end_arr[i]
and prompt_is_reasoning_end_arr[i] and prompt_is_reasoning_end_arr[i]
): ):
reasoning_end_arr[i] = True reasoning_end_arr[i] = True
if self.reasoning_parser and not reasoning_end_arr[i]: if reasoning_parser and not reasoning_end_arr[i]:
delta_message = ( delta_message = (
reasoning_parser.extract_reasoning_streaming( reasoning_parser.extract_reasoning_streaming(
previous_text, previous_text,
...@@ -1047,9 +1048,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1047,9 +1048,8 @@ class OpenAIServingChat(OpenAIServing):
# handle streaming deltas for tools with "auto" tool choice # handle streaming deltas for tools with "auto" tool choice
# and reasoning parser # and reasoning parser
elif tool_choice_auto and self.reasoning_parser: elif tool_choice_auto and reasoning_parser:
assert tool_parser is not None assert tool_parser is not None
assert reasoning_parser is not None
assert added_content_delta_arr is not None assert added_content_delta_arr is not None
assert reasoning_end_arr is not None assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids) output_token_ids = as_list(output.token_ids)
...@@ -1130,7 +1130,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1130,7 +1130,7 @@ class OpenAIServingChat(OpenAIServing):
tools_streamed[i] = True tools_streamed[i] = True
# when only reasoning # when only reasoning
elif self.reasoning_parser: elif reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming( delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text, previous_text,
current_text, current_text,
...@@ -1144,9 +1144,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1144,9 +1144,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message = DeltaMessage(content=delta_text) delta_message = DeltaMessage(content=delta_text)
# update the previous values for the next iteration # update the previous values for the next iteration
if ( if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
tool_choice_auto or self.reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None assert previous_texts is not None
assert all_previous_token_ids is not None assert all_previous_token_ids is not None
previous_texts[i] = current_text previous_texts[i] = current_text
...@@ -1400,6 +1398,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1400,6 +1398,7 @@ class OpenAIServingChat(OpenAIServing):
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
...@@ -1494,25 +1493,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1494,25 +1493,7 @@ class OpenAIServingChat(OpenAIServing):
choices.append(choice_data) choices.append(choice_data)
continue continue
if self.reasoning_parser: if reasoning_parser:
try:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
# If the reasoning parser is enabled, # If the reasoning parser is enabled,
# tool calls are extracted exclusively from the content. # tool calls are extracted exclusively from the content.
reasoning, content = reasoning_parser.extract_reasoning( reasoning, content = reasoning_parser.extract_reasoning(
......
...@@ -83,6 +83,8 @@ class EngineCoreRequest( ...@@ -83,6 +83,8 @@ class EngineCoreRequest(
# Used in outputs and to support abort(req_id, internal=False). # Used in outputs and to support abort(req_id, internal=False).
external_req_id: str | None = None external_req_id: str | None = None
reasoning_ended: bool | None = None
@property @property
def params(self) -> SamplingParams | PoolingParams: def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling).""" """Return the processed params (sampling or pooling)."""
......
...@@ -74,6 +74,7 @@ class Request: ...@@ -74,6 +74,7 @@ class Request:
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
block_hasher: Callable[["Request"], list["BlockHash"]] | None = None, block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
resumable: bool = False, resumable: bool = False,
reasoning_ended: bool | None = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.client_index = client_index self.client_index = client_index
...@@ -86,6 +87,8 @@ class Request: ...@@ -86,6 +87,8 @@ class Request:
self.structured_output_request = StructuredOutputRequest.from_sampling_params( self.structured_output_request = StructuredOutputRequest.from_sampling_params(
sampling_params sampling_params
) )
if self.structured_output_request is not None:
self.structured_output_request.reasoning_ended = reasoning_ended
self.arrival_time = arrival_time if arrival_time is not None else time.time() self.arrival_time = arrival_time if arrival_time is not None else time.time()
self.status = RequestStatus.WAITING self.status = RequestStatus.WAITING
...@@ -195,6 +198,7 @@ class Request: ...@@ -195,6 +198,7 @@ class Request:
trace_headers=request.trace_headers, trace_headers=request.trace_headers,
block_hasher=block_hasher, block_hasher=block_hasher,
resumable=request.resumable, resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
) )
def append_output_token_ids( def append_output_token_ids(
......
...@@ -284,12 +284,15 @@ class StructuredOutputManager: ...@@ -284,12 +284,15 @@ class StructuredOutputManager:
# NOTE (Hanchen) if enable_in_reasoning is True, it means that # NOTE (Hanchen) if enable_in_reasoning is True, it means that
# the model needs to be constrained in reasoning. So we should always # the model needs to be constrained in reasoning. So we should always
# enable the bitmask filling. # enable the bitmask filling.
if self.reasoner is not None: if self.reasoner is not None:
if self.enable_in_reasoning: if self.enable_in_reasoning:
return True return True
assert request.structured_output_request is not None assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None: if request.structured_output_request.reasoning_ended is None:
# This should be removed here, but since `openai_gptoss`
# is an independent code path, it is kept for now.
# After unifying the `openai_gptoss` and non-`openai_gptoss` styles,
# it can be removed.
request.structured_output_request.reasoning_ended = ( request.structured_output_request.reasoning_ended = (
self.reasoner.is_reasoning_end(request.prompt_token_ids or []) self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment