Unverified commit 6a09612b, authored by PlatinumGod, committed by GitHub
Browse files

[Bugfix] Fix tool_choice="none" being ignored by GPT-OSS/harmony models (#30867)


Signed-off-by: yujiepu <pyjapple@gmail.com>
Signed-off-by: PlatinumGod <pyjapple@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
parent 45c0526a
...@@ -52,8 +52,19 @@ def with_tool_parser(request) -> bool: ...@@ -52,8 +52,19 @@ def with_tool_parser(request) -> bool:
return request.param return request.param
@pytest.fixture(
    scope="module",
    params=[True],
    ids=["exclude_tools_when_tool_choice_none"],
)
def exclude_tools_when_tool_choice_none(request) -> bool:
    """Return the current fixture parameter for this module.

    The fixture is parametrized with the single value ``True``, so every
    consumer currently sees ``True``.
    """
    flag_enabled: bool = request.param
    return flag_enabled
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool): def default_server_args(
with_tool_parser: bool, exclude_tools_when_tool_choice_none: bool
):
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--enforce-eager", "--enforce-eager",
...@@ -72,6 +83,8 @@ def default_server_args(with_tool_parser: bool): ...@@ -72,6 +83,8 @@ def default_server_args(with_tool_parser: bool):
"--enable-auto-tool-choice", "--enable-auto-tool-choice",
] ]
) )
if exclude_tools_when_tool_choice_none:
args.append("--exclude-tools-when-tool-choice-none")
return args return args
...@@ -335,6 +348,69 @@ async def test_gpt_oss_tool_message_array_content( ...@@ -335,6 +348,69 @@ async def test_gpt_oss_tool_message_array_content(
assert response_multi_array.choices[0].message is not None assert response_multi_array.choices[0].message is not None
@pytest.mark.asyncio
async def test_gpt_oss_tool_choice_none(
    gptoss_client: OpenAI,
    with_tool_parser: bool,
    exclude_tools_when_tool_choice_none: bool,
):
    """Check that ``tool_choice="none"`` suppresses tool calls.

    The same prompt and tool list are sent twice: once with
    ``tool_choice="auto"``, where exactly one tool call is expected, and once
    with ``tool_choice="none"``, where no tool call is expected.

    The test is skipped unless both ``with_tool_parser`` and
    ``exclude_tools_when_tool_choice_none`` are truthy.
    """
    if not (with_tool_parser and exclude_tools_when_tool_choice_none):
        pytest.skip(
            "skip tool_choice tests when non-tool or "
            "--exclude-tools-when-tool-choice-none not set"
        )

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "state": {"type": "string"},
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["city", "state", "unit"],
                },
            },
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What's the temperature(in degrees Celsius) in Dallas?",
        },
    ]

    # Baseline: with tool_choice="auto" the request should yield a tool call.
    tool_choice_auto = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        temperature=0.0,
    )
    auto_msg = tool_choice_auto.choices[0].message
    # Check for None first so a missing field fails as an assertion rather
    # than as a TypeError raised by len(None).
    assert auto_msg.tool_calls is not None, (
        "expected a tool call with tool_choice='auto', got tool_calls=None"
    )
    assert len(auto_msg.tool_calls) == 1

    # Same request with tool_choice="none": no tool call may be returned.
    tool_choice_none = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=messages,
        tools=tools,
        tool_choice="none",
        temperature=0.0,
    )
    none_msg = tool_choice_none.choices[0].message
    # "No tool calls" may be reported as either None or an empty list.
    assert not none_msg.tool_calls, (
        f"expected no tool calls with tool_choice='none', "
        f"got {none_msg.tool_calls!r}"
    )
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2" MODEL_NAME_SHORT = "gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}" CHAT_TEMPLATE = "Dummy chat template for testing {}"
......
...@@ -299,7 +299,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -299,7 +299,10 @@ class OpenAIServingChat(OpenAIServing):
) )
else: else:
# For GPT-OSS. # For GPT-OSS.
conversation, engine_prompts = self._make_request_with_harmony(request) should_include_tools = tool_dicts is not None
conversation, engine_prompts = self._make_request_with_harmony(
request, should_include_tools
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}") return self.create_error_response(f"{e} {e.__cause__}")
...@@ -1833,6 +1836,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1833,6 +1836,7 @@ class OpenAIServingChat(OpenAIServing):
def _make_request_with_harmony( def _make_request_with_harmony(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
should_include_tools: bool = True,
): ):
messages: list[OpenAIMessage] = [] messages: list[OpenAIMessage] = []
...@@ -1850,12 +1854,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1850,12 +1854,14 @@ class OpenAIServingChat(OpenAIServing):
reasoning_effort=request.reasoning_effort, reasoning_effort=request.reasoning_effort,
browser_description=None, browser_description=None,
python_description=None, python_description=None,
with_custom_tools=request.tools is not None, with_custom_tools=should_include_tools,
) )
messages.append(sys_msg) messages.append(sys_msg)
# Add developer message. # Add developer message.
dev_msg = get_developer_message(tools=request.tools) dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None
)
messages.append(dev_msg) messages.append(dev_msg)
# Add user message. # Add user message.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment