Unverified Commit da7bc54e authored by Andrew Xia's avatar Andrew Xia Committed by GitHub
Browse files

[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)


Signed-off-by: default avatarAndrew Xia <axia@fb.com>
Signed-off-by: default avatarAndrew Xia <axia@meta.com>
Co-authored-by: default avatarAndrew Xia <axia@fb.com>
parent 949a6a19
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import pytest
import pytest_asyncio
......@@ -13,12 +15,27 @@ MODEL_NAME = "Qwen/Qwen3-8B"
@pytest.fixture(scope="module")
def server():
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
assert importlib.util.find_spec("gpt_oss") is not None, (
"Harmony tests require gpt_oss package to be installed"
)
args = [
"--reasoning-parser",
"qwen3",
"--max_model_len",
"5000",
"--structured-outputs-config.backend",
"xgrammar",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--tool-server",
"demo",
]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
# uncomment for tool calling
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
)
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
......@@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
assert response.output[0].type == "reasoning"
assert response.output[1].type == "message"
assert type(response.output[1].content[0].text) is str
def get_horoscope(sign):
return f"{sign}: Next Tuesday you will befriend a baby otter."
def call_function(name, args):
if name == "get_horoscope":
return get_horoscope(**args)
else:
raise ValueError(f"Unknown function: {name}")
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
tools = [
{
"type": "function",
"name": "get_horoscope",
"description": "Get today's horoscope for an astrological sign.",
"parameters": {
"type": "object",
"properties": {
"sign": {"type": "string"},
},
"required": ["sign"],
"additionalProperties": False,
},
"strict": True,
}
]
response = await client.responses.create(
model=model_name,
input="What is the horoscope for Aquarius today?",
tools=tools,
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert len(response.output) == 2
assert response.output[0].type == "reasoning"
assert response.output[1].type == "function_call"
function_call = response.output[1]
assert function_call.name == "get_horoscope"
assert function_call.call_id is not None
args = json.loads(function_call.arguments)
assert "sign" in args
# the multi turn function call is tested above in
# test_reasoning_and_function_items
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is 13 * 24? Use python to calculate the result.",
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert response.output[0].type == "reasoning"
assert response.output[1].type == "mcp_call"
assert type(response.output[1].arguments) is str
assert type(response.output[1].output) is str
assert response.output[2].type == "reasoning"
# make sure the correct math is in the final output
assert response.output[3].type == "message"
assert "312" in response.output[3].content[0].text
......@@ -9,10 +9,16 @@ from collections.abc import Callable
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
......@@ -22,16 +28,20 @@ from vllm.entrypoints.openai.parser.responses_parser import (
get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponseRawMessageAndToken,
ResponsesRequest,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
if TYPE_CHECKING:
from mcp.client import ClientSession
......@@ -221,6 +231,10 @@ class ParsableContext(ConversationContext):
tokenizer: AnyTokenizer,
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
self.num_prompt_tokens = 0
self.num_output_tokens = 0
......@@ -238,12 +252,19 @@ class ParsableContext(ConversationContext):
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.tokenizer = tokenizer
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
......@@ -252,14 +273,50 @@ class ParsableContext(ConversationContext):
self.parser.process(output.outputs[0])
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
raise NotImplementedError("Should not be called.")
self.parser.response_messages.extend(output)
def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a MCP tool call"""
last_message = self.parser.response_messages[-1]
# TODO: figure out which tools are MCP tools
if ( # noqa: SIM103
last_message.type == "function_call"
and last_message.name in ("code_interpreter", "python")
):
return True
return False
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result_parsable_context(self)
args = json.loads(last_msg.arguments)
param = {
"code": args["code"],
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text
message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
status="completed",
)
return [message]
async def call_tool(self) -> list[ResponseInputOutputItem]:
raise NotImplementedError("Should not be called.")
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
return []
def render_for_completion(self):
raise NotImplementedError("Should not be called.")
......@@ -271,11 +328,38 @@ class ParsableContext(ConversationContext):
request_id: str,
mcp_tools: dict[str, Mcp],
):
pass
if tool_server:
for tool_name in self.available_tools:
if tool_name in self._tool_sessions:
continue
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
raise NotImplementedError("Should not be called.")
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)
class HarmonyContext(ConversationContext):
......
......@@ -3,6 +3,7 @@
import logging
from collections.abc import Callable
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
......@@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
)
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
......@@ -29,6 +32,7 @@ class ResponsesParser:
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
......@@ -39,6 +43,9 @@ class ResponsesParser:
self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)
def process(self, output: CompletionOutput) -> "ResponsesParser":
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
......@@ -59,6 +66,29 @@ class ResponsesParser:
)
)
function_calls: list[ResponseFunctionToolCall] = []
if self.tool_parser_instance is not None:
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None
if content:
self.response_messages.append(
ResponseOutputMessage(
......@@ -76,6 +106,8 @@ class ResponsesParser:
],
)
)
if len(function_calls) > 0:
self.response_messages.extend(function_calls)
return self
......@@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
) -> ResponsesParser:
"""Factory function to create a ResponsesParser with
optional reasoning parser.
......@@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
......@@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
from vllm.entrypoints.context import (
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
......@@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreRequest,
ScoreResponse,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
if sys.version_info >= (3, 12):
from typing import TypedDict
......@@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
DetokenizeRequest,
ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
ResponsesRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
......@@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.responses_utils import (
construct_input_messages,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType
......@@ -1224,6 +1236,31 @@ class OpenAIServing:
)
return engine_request, tokenization_kwargs
async def _render_next_turn(
self,
request: ResponsesRequest,
tokenizer: AnyTokenizer,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
new_messages = construct_input_messages(
request_input=messages,
)
_, request_prompts, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
new_messages,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
return request_prompts, engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
......@@ -1286,11 +1323,27 @@ class OpenAIServing:
# Create inputs for the next turn.
# Render the next prompt token ids.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion()
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
elif isinstance(context, ParsableContext):
request_prompts, engine_prompts = await self._render_next_turn(
context.request,
context.tokenizer,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
context.chat_template,
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0]
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1
......
......@@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
generators: list[AsyncGenerator[ConversationContext, None]] = []
builtin_tool_list: list[str] = []
if self.use_harmony and self.tool_server is not None:
if self.tool_server is not None:
if self.tool_server.has_tool("browser"):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
......@@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser,
request=request,
tool_parser_cls=self.tool_parser,
available_tools=available_tools,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
else:
context = SimpleContext()
......
......@@ -16,6 +16,7 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
......@@ -25,6 +26,7 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ResponseInputOutputItem,
)
from vllm.utils import random_uuid
def make_response_output_items_from_parsable_context(
......@@ -36,7 +38,24 @@ def make_response_output_items_from_parsable_context(
if not isinstance(message, ResponseFunctionToolCallOutputItem):
output_messages.append(message)
else:
raise NotImplementedError("tool calls not supported for response context")
if len(output_messages) == 0:
raise ValueError(
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"mcp_{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type="mcp_call",
status="completed",
output=message.output,
# TODO: support error output
)
output_messages[-1] = mcp_message
return output_messages
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai_harmony import Author, Message, Role, TextContent
from vllm.logger import init_logger
from vllm.utils import random_uuid
if TYPE_CHECKING:
# Avoid circular import.
......@@ -46,6 +51,10 @@ class Tool(ABC):
async def get_result(self, context: "ConversationContext") -> Any:
pass
@abstractmethod
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
pass
class HarmonyBrowserTool(Tool):
def __init__(self):
......@@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool):
tool_output_msgs.append(msg)
return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
raise NotImplementedError("Not implemented yet")
@property
def tool_config(self) -> Any:
return self.browser_tool.tool_config
......@@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool):
tool_output_msgs.append(msg)
return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
"""
This function converts parsable context types to harmony and
back so we can use GPTOSS demo python tool
"""
from vllm.entrypoints.context import ParsableContext
assert isinstance(context, ParsableContext)
last_msg = context.parser.response_messages[-1]
args = json.loads(last_msg.arguments)
last_msg_harmony = Message(
author=Author(role="assistant", name=None),
content=[TextContent(text=args["code"])],
channel="analysis",
recipient="python",
content_type="code",
)
tool_output_msgs = []
async for msg in self.python_tool.process(last_msg_harmony):
processed = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=msg.content[0].text,
status="completed",
)
tool_output_msgs.append(processed)
return tool_output_msgs
@property
def tool_config(self) -> Any:
return self.python_tool.tool_config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment