Unverified Commit 0976711f authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Refactor] to simplify and extract the shared logic between chat completion and responses (#27961)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent e261d37c
...@@ -13,7 +13,6 @@ import partial_json_parser ...@@ -13,7 +13,6 @@ import partial_json_parser
import regex as re import regex as re
from fastapi import Request from fastapi import Request
from openai_harmony import Message as OpenAIMessage from openai_harmony import Message as OpenAIMessage
from pydantic import TypeAdapter
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
...@@ -47,8 +46,6 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -47,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (
DeltaMessage, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ErrorResponse, ErrorResponse,
FunctionCall,
FunctionDefinition,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
ToolCall, ToolCall,
...@@ -1394,6 +1391,16 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1394,6 +1391,16 @@ class OpenAIServingChat(OpenAIServing):
auto_tools_called = False auto_tools_called = False
# if auto tools are not enabled, and a named tool choice using # if auto tools are not enabled, and a named tool choice using
# outlines is not being used # outlines is not being used
tool_calls, content = self._parse_tool_calls_from_content(
request=request,
tokenizer=tokenizer,
content=content,
enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser,
)
tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
)
if (not self.enable_auto_tools or not self.tool_parser) and ( if (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required" and request.tool_choice != "required"
...@@ -1407,63 +1414,33 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1407,63 +1414,33 @@ class OpenAIServingChat(OpenAIServing):
request.tool_choice request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
): ):
tool_call_class = ( assert tool_calls is not None and len(tool_calls) > 0
MistralToolCall
if isinstance(tokenizer, MistralTokenizer)
else ToolCall
)
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content="", content="",
tool_calls=[ tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
tool_call_class(
function=FunctionCall(
name=request.tool_choice.function.name,
arguments=content,
)
)
],
) )
elif request.tool_choice and request.tool_choice == "required": elif request.tool_choice and request.tool_choice == "required":
tool_call_class = ( tool_call_class_items = []
MistralToolCall assert tool_calls is not None and len(tool_calls) > 0
if isinstance(tokenizer, MistralTokenizer)
else ToolCall
)
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
assert content is not None
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
content
)
tool_call_ids = []
for tool_call in tool_calls: for tool_call in tool_calls:
tool_call_ids.append( tool_call_class_items.append(
make_tool_call_id( tool_call_class(
id=make_tool_call_id(
id_type=self.tool_call_id_type, id_type=self.tool_call_id_type,
func_name=tool_call.name, func_name=tool_call.name,
idx=history_tool_call_cnt, idx=history_tool_call_cnt,
),
function=tool_call,
) )
) )
history_tool_call_cnt += 1 history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
content="", content="",
tool_calls=[ tool_calls=tool_call_class_items,
tool_call_class(
id=tool_call_ids[i],
function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(
tool_call.parameters, ensure_ascii=False
),
),
)
for i, tool_call in enumerate(tool_calls)
],
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
) )
...@@ -1481,25 +1458,22 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1481,25 +1458,22 @@ class OpenAIServingChat(OpenAIServing):
and self.enable_auto_tools and self.enable_auto_tools
and self.tool_parser and self.tool_parser
): ):
try:
tool_parser = self.tool_parser(tokenizer)
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
return self.create_error_response(str(e))
tool_call_info = tool_parser.extract_tool_calls(
content if content is not None else "", request=request
)
# In the OpenAI API the finish_reason is "tools_called" # In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool # if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls # call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_call_info.tools_called: if tool_calls:
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content=tool_call_info.content, content=content,
tool_calls=tool_call_info.tool_calls, tool_calls=[
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
) )
else: else:
...@@ -1509,8 +1483,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1509,8 +1483,8 @@ class OpenAIServingChat(OpenAIServing):
# try to use content return from tool parser first, # try to use content return from tool parser first,
# tool parser may do some modify for the content. # tool parser may do some modify for the content.
if tool_call_info.content and len(tool_call_info.content) > 0: if content and len(content) > 0:
ret_content = tool_call_info.content ret_content = content
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
......
...@@ -12,7 +12,7 @@ from typing import Any, ClassVar, Generic, TypeAlias, TypeVar ...@@ -12,7 +12,7 @@ from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import torch import torch
from fastapi import Request from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import TypeIs from typing_extensions import TypeIs
...@@ -21,6 +21,10 @@ if sys.version_info >= (3, 12): ...@@ -21,6 +21,10 @@ if sys.version_info >= (3, 12):
else: else:
from typing_extensions import TypedDict from typing_extensions import TypedDict
from openai.types.responses import (
ToolChoiceFunction,
)
import vllm.envs as envs import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -36,6 +40,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -36,6 +40,7 @@ from vllm.entrypoints.chat_utils import (
from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ClassificationRequest, ClassificationRequest,
...@@ -49,6 +54,8 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -49,6 +54,8 @@ from vllm.entrypoints.openai.protocol import (
EmbeddingResponse, EmbeddingResponse,
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
FunctionCall,
FunctionDefinition,
IOProcessorRequest, IOProcessorRequest,
PoolingResponse, PoolingResponse,
RerankRequest, RerankRequest,
...@@ -1305,6 +1312,75 @@ class OpenAIServing: ...@@ -1305,6 +1312,75 @@ class OpenAIServing:
except ValueError: except ValueError:
return None return None
@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
tokenizer: AnyTokenizer,
enable_auto_tools: bool,
tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
function_calls = list[FunctionCall]()
if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
assert content is not None
# Forced Function Call
function_calls.append(
FunctionCall(name=request.tool_choice.name, arguments=content)
)
content = None # Clear content since tool is called.
elif request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
):
assert content is not None
# Forced Function Call
function_calls.append(
FunctionCall(name=request.tool_choice.function.name, arguments=content)
)
content = None # Clear content since tool is called.
elif request.tool_choice == "required":
assert content is not None
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
function_calls.extend(
[
FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
)
for tool_call in tool_calls
]
)
content = None # Clear content since tool is called.
elif (
tool_parser_cls
and enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
):
# Automatic Tool Call Parsing
try:
tool_parser = tool_parser_cls(tokenizer)
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
raise e
tool_call_info = tool_parser.extract_tool_calls(
content if content is not None else "",
request=request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
else:
# No tool calls.
return None, content
return function_calls, content
@staticmethod @staticmethod
def _get_decoded_token( def _get_decoded_token(
logprob: Logprob, logprob: Logprob,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment