Unverified Commit 1dab4d57 authored by Will Eaton's avatar Will Eaton Committed by GitHub
Browse files

Tool parser regex timeout handling (#18960)


Signed-off-by: default avatarWill Eaton <weaton@redhat.com>
parent 7f21e805
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from unittest.mock import MagicMock from unittest.mock import MagicMock, patch
import pytest import pytest
...@@ -191,3 +191,27 @@ def test_streaming_tool_call_with_large_steps(): ...@@ -191,3 +191,27 @@ def test_streaming_tool_call_with_large_steps():
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
# create a mock regex that raises TimeoutError
mock_regex = MagicMock()
mock_regex.match.side_effect = TimeoutError("Regex timeout")
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
content, tool_calls = run_tool_extraction(tool_parser,
fake_problematic_input,
streaming=streaming)
# should treat as regular text when regex times out
assert content == fake_problematic_input
assert len(tool_calls) == 0
mock_regex.match.assert_called_once()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from unittest.mock import MagicMock from unittest.mock import MagicMock, patch
import pytest import pytest
...@@ -159,3 +159,27 @@ def test_streaming_tool_call_with_large_steps(): ...@@ -159,3 +159,27 @@ def test_streaming_tool_call_with_large_steps():
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
# create a mock regex that raises TimeoutError
mock_regex = MagicMock()
mock_regex.match.side_effect = TimeoutError("Regex timeout")
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
content, tool_calls = run_tool_extraction(tool_parser,
fake_problematic_input,
streaming=streaming)
# should treat as regular text when regex times out
assert content == fake_problematic_input
assert len(tool_calls) == 0
mock_regex.match.assert_called_once()
...@@ -7,6 +7,7 @@ from typing import Any, Union ...@@ -7,6 +7,7 @@ from typing import Any, Union
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -64,7 +65,19 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -64,7 +65,19 @@ class Llama4PythonicToolParser(ToolParser):
if model_output.startswith("<|python_start|>"): if model_output.startswith("<|python_start|>"):
model_output = model_output[len("<|python_start|>"):] model_output = model_output[len("<|python_start|>"):]
model_output = model_output.replace("<|python_end|>", "") model_output = model_output.replace("<|python_end|>", "")
if not (self.TOOL_CALL_REGEX.match(model_output)):
is_tool_call_pattern = False
try:
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
model_output,
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
except TimeoutError:
logger.warning(
"Regex timeout occurred when matching tool call pattern.")
logger.debug("Regex timeout occurred when matching user input: %s",
model_output)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(tools_called=False, return ExtractedToolCallInformation(tools_called=False,
tool_calls=[], tool_calls=[],
content=model_output) content=model_output)
......
...@@ -8,6 +8,7 @@ from typing import Any, Union ...@@ -8,6 +8,7 @@ from typing import Any, Union
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -61,8 +62,18 @@ class PythonicToolParser(ToolParser): ...@@ -61,8 +62,18 @@ class PythonicToolParser(ToolParser):
""" """
Extract the tool calls from a complete model response. Extract the tool calls from a complete model response.
""" """
is_tool_call_pattern = False
if not (self.TOOL_CALL_REGEX.match(model_output)): try:
is_tool_call_pattern = self.TOOL_CALL_REGEX.match(
model_output,
timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None
except TimeoutError:
logger.warning(
"Regex timeout occurred when matching tool call pattern.")
logger.debug("Regex timeout occurred when matching user input: %s",
model_output)
if not is_tool_call_pattern:
return ExtractedToolCallInformation(tools_called=False, return ExtractedToolCallInformation(tools_called=False,
tool_calls=[], tool_calls=[],
content=model_output) content=model_output)
......
...@@ -119,6 +119,7 @@ if TYPE_CHECKING: ...@@ -119,6 +119,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
def get_default_cache_root(): def get_default_cache_root():
...@@ -828,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -828,6 +829,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# This is used to prevent the kernel from running out of memory. # This is used to prevent the kernel from running out of memory.
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
# Regex timeout for use by the vLLM tool parsing plugins.
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment