Unverified Commit e7261315 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

bugfix(tool call ebnf): Fix EBNF generation for optional function parameters (#7283)

parent 8c16da33
...@@ -211,20 +211,74 @@ class EBNFComposer: ...@@ -211,20 +211,74 @@ class EBNFComposer:
properties = params.get("properties", {}) properties = params.get("properties", {})
required_props = set(params.get("required", [])) required_props = set(params.get("required", []))
# Build argument rules for this tool # The generated pattern ensures:
arg_rules = [] # 1. Required properties appear first, joined by commas
# 2. Optional properties are wrapped with comma included: ( "," ( "prop" : value )? )?
# 3. For multiple optional properties, we allow flexible ordering:
# - Each optional can be skipped entirely
# - They can appear in any combination
#
# Example patterns generated:
# - One required, one optional:
# "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
# Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
#
# - Multiple optional properties with flexible ordering:
# "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
# Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
# {"req": "x", "opt1": "y", "opt2": "z"}
#
# - All optional properties with flexible ordering:
# "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
# Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
prop_kv_pairs = {}
ordered_props = list(properties.keys())
for prop_name, prop_schema in properties.items(): for prop_name, prop_schema in properties.items():
value_rule = EBNFComposer.get_value_rule(prop_schema, function_format) value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
# Create key=value pair # Create key=value pair
pair = key_value_template.format(key=prop_name, valrule=value_rule) pair = key_value_template.format(key=prop_name, valrule=value_rule)
prop_kv_pairs[prop_name] = pair
# Separate into required and optional while preserving order
required = [p for p in ordered_props if p in required_props]
optional = [p for p in ordered_props if p not in required_props]
# Build the combined rule
rule_parts = []
# Add required properties joined by commas
if required:
rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
# Add optional properties with flexible ordering
if optional:
# Build alternatives where any optional property can appear first
opt_alternatives = []
for i in range(len(optional)):
# Build pattern for optional[i] appearing first
opt_parts = []
for j in range(i, len(optional)):
if j == i:
opt_parts.append(prop_kv_pairs[optional[j]])
else:
opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
opt_alternatives.append("".join(opt_parts))
# Wrap with appropriate comma handling based on whether we have required properties
if required:
# Required properties exist, so optional group needs outer comma
rule_parts.append(' ( "," ( ')
rule_parts.append(" | ".join(opt_alternatives))
rule_parts.append(" ) )?")
else:
# All properties are optional
rule_parts.append("( ")
rule_parts.append(" | ".join(opt_alternatives))
rule_parts.append(" )?")
if prop_name not in required_props: combined_args = "".join(rule_parts)
pair = f"[ {pair} ]"
arg_rules.append(pair)
# Combine all argument rules
combined_args = ' "," '.join(arg_rules) if arg_rules else ""
arguments_rule = args_template.format(arg_rules=combined_args) arguments_rule = args_template.format(arg_rules=combined_args)
# Add the function call rule and its arguments rule # Add the function call rule and its arguments rule
......
""" """
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_all_optional_function_params
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_all_optional_function_params
""" """
import json import json
...@@ -237,6 +239,38 @@ class TestEBNFConstrained(CustomTestCase): ...@@ -237,6 +239,38 @@ class TestEBNFConstrained(CustomTestCase):
n=3, n=3,
) )
def test_ebnf_generate_all_optional_function_params(self):
"""Test function call with all optional parameters - verifies flexible ordering."""
self.__class__.ebnf_grammar = """
root ::= function_call
function_call ::= call_config_service
call_config_service ::= "{" "\\"name\\"" ":" "\\"config_service\\"" ", " "\\"arguments\\"" ":" arguments_config_service "}"
arguments_config_service ::= "{" ( "\\"theme\\"" ":" ("\\"light\\"" | "\\"dark\\"") ( "," "\\"language\\"" ":" ("\\"en\\"" | "\\"es\\"" | "\\"fr\\"") )? ( "," "\\"notifications\\"" ":" ("true" | "false") )? | "\\"language\\"" ":" ("\\"en\\"" | "\\"es\\"" | "\\"fr\\"") ( "," "\\"notifications\\"" ":" ("true" | "false") )? | "\\"notifications\\"" ":" ("true" | "false") )? "}"
"""
# Test patterns that should match - flexible ordering of optional parameters
allowed_patterns = [
# Empty arguments
r'^\{"name":"config_service", "arguments":\{\}\}$',
# Single optional parameters (any can appear first)
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)"\}\}$',
r'^\{"name":"config_service", "arguments":\{"language":"(en|es|fr)"\}\}$',
r'^\{"name":"config_service", "arguments":\{"notifications":(true|false)\}\}$',
# Two optional parameters (in any order)
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)", "language":"(en|es|fr)"\}\}$',
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)", "notifications":(true|false)\}\}$',
r'^\{"name":"config_service", "arguments":\{"language":"(en|es|fr)", "notifications":(true|false)\}\}$',
# All three optional parameters
r'^\{"name":"config_service", "arguments":\{"theme":"(light|dark)", "language":"(en|es|fr)", "notifications":(true|false)\}\}$',
]
prompt = "Configure the service with optional settings:"
self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=5,
)
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained): class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod @classmethod
......
...@@ -515,7 +515,7 @@ class TestEBNFGeneration(unittest.TestCase): ...@@ -515,7 +515,7 @@ class TestEBNFGeneration(unittest.TestCase):
# Check that the EBNF contains expected patterns # Check that the EBNF contains expected patterns
self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf) self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
self.assertIn('"location" "=" basic_string', ebnf) self.assertIn('"location" "=" basic_string', ebnf)
self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf) self.assertIn('( "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") )', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler # Validate that the EBNF can be compiled by GrammarCompiler
try: try:
...@@ -591,6 +591,224 @@ class TestEBNFGeneration(unittest.TestCase): ...@@ -591,6 +591,224 @@ class TestEBNFGeneration(unittest.TestCase):
except RuntimeError as e: except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}") self.fail(f"Failed to compile EBNF: {e}")
def test_weather_function_optional_parameter_handling(self):
"""Test that weather function with optional unit parameter generates correct EBNF without trailing commas."""
# Create a weather tool with required location and optional unit
weather_tool = Tool(
type="function",
function=Function(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
),
)
# Test all detectors with the weather tool
detectors = {
"pythonic": self.pythonic_detector,
"deepseekv3": self.deepseekv3_detector,
"llama32": self.llama32_detector,
"mistral": self.mistral_detector,
"qwen25": self.qwen25_detector,
}
for name, detector in detectors.items():
with self.subTest(detector=name):
ebnf = detector.build_ebnf([weather_tool])
self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")
# Check that the EBNF properly handles optional parameters
if name == "pythonic":
# Pythonic format: location="Paris" ( , ( unit=("celsius" | "fahrenheit") )?
self.assertIn('"location" "=" basic_string', ebnf)
# The comma should be inside the optional brackets for unit
self.assertIn('( "," ( "unit" "=" ', ebnf)
else:
# JSON format: "location": "Paris" ( , ( "unit": ("celsius" | "fahrenheit") )?
self.assertIn('"location\\"" ":" basic_string', ebnf)
# The comma should be part of the optional group
# This pattern ensures no trailing comma when unit is omitted
self.assertIn('( "," ( "\\"unit\\"" ":"', ebnf)
# Validate that the EBNF can be compiled
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(
ctx, f"{name} EBNF should compile successfully"
)
except RuntimeError as e:
self.fail(f"Failed to compile {name} EBNF: {e}")
def test_multiple_optional_parameters_flexible_ordering(self):
"""Test that multiple optional parameters allow flexible ordering using llama.cpp approach."""
# Create a tool with one required and multiple optional parameters
test_tool = Tool(
type="function",
function=Function(
name="test_func",
description="Test function with multiple optional parameters",
parameters={
"type": "object",
"properties": {
"required_field": {"type": "string"},
"opt1": {"type": "number"},
"opt2": {"type": "boolean"},
"opt3": {"type": "string"},
},
"required": ["required_field"],
},
),
)
# Test JSON-based detectors (not pythonic)
json_detectors = {
"deepseekv3": self.deepseekv3_detector,
"llama32": self.llama32_detector,
"mistral": self.mistral_detector,
"qwen25": self.qwen25_detector,
}
for name, detector in json_detectors.items():
with self.subTest(detector=name):
ebnf = detector.build_ebnf([test_tool])
self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")
# Print the arguments rule for debugging
lines = ebnf.split("\n")
args_rule = None
for line in lines:
if line.startswith("arguments_test_func ::="):
args_rule = line
break
self.assertIsNotNone(
args_rule, f"{name} should have arguments_test_func rule"
)
# Check required field
self.assertIn('"required_field\\"" ":" basic_string', ebnf)
# Check the structure for optional parameters
# The pattern should be: required_field ( "," ( opt1 ... | opt2 ... | opt3 ... ) )?
# This allows flexible ordering where any optional can be first
# Check that optional parameters are in a group with comma
if args_rule: # Only check if args_rule was found
self.assertIn(
'( ","',
args_rule,
f"{name} should have comma grouped with optional parameters",
)
# Check for the alternation pattern that allows flexible ordering
# Should contain patterns like: opt1 ... | opt2 ... | opt3
self.assertIn('"opt1\\"" ":" basic_number', args_rule)
self.assertIn('"opt2\\"" ":" basic_boolean', args_rule)
self.assertIn('"opt3\\"" ":" basic_string', args_rule)
# Check for alternation (|) which allows skipping optional parameters
self.assertIn(
"|",
args_rule,
f"{name} should use alternation for flexible optional ordering",
)
# Check that the pattern ends properly with closing braces
self.assertTrue(
args_rule.endswith('"}"'),
f"{name} arguments rule should end with closing brace",
)
# Validate compilation
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(
ctx, f"{name} EBNF should compile successfully"
)
except RuntimeError as e:
self.fail(f"Failed to compile {name} EBNF: {e}")
def test_all_optional_parameters_ordering(self):
"""Test the behavior when ALL parameters are optional - verifies ordering constraints."""
# Create a tool with only optional parameters
all_optional_tool = Tool(
type="function",
function=Function(
name="optional_func",
description="Function with all optional parameters",
parameters={
"type": "object",
"properties": {
"opt1": {"type": "string"},
"opt2": {"type": "number"},
"opt3": {"type": "boolean"},
},
"required": [], # No required parameters
},
),
)
# Test JSON-based detectors
json_detectors = {
"deepseekv3": self.deepseekv3_detector,
"llama32": self.llama32_detector,
"mistral": self.mistral_detector,
"qwen25": self.qwen25_detector,
}
for name, detector in json_detectors.items():
with self.subTest(detector=name):
ebnf = detector.build_ebnf([all_optional_tool])
self.assertIsNotNone(ebnf, f"{name} detector should generate EBNF")
# Extract the arguments rule
lines = ebnf.split("\n")
args_rule = None
for line in lines:
if line.startswith("arguments_optional_func ::="):
args_rule = line
break
self.assertIsNotNone(
args_rule, f"{name} should have arguments_optional_func rule"
)
if args_rule:
# When all parameters are optional, the pattern now uses alternation:
# "{" ( opt1 ... | opt2 ... | opt3 ... )? "}"
# This allows flexible ordering where any optional can appear first
# Check the structure
self.assertIn('"opt1\\"" ":" basic_string', args_rule)
self.assertIn('"opt2\\"" ":" basic_number', args_rule)
self.assertIn('"opt3\\"" ":" basic_boolean', args_rule)
# The pattern SHOULD have alternation (|) for flexible ordering
self.assertIn(
"|",
args_rule,
f"{name} should use alternation for flexible ordering even when all properties are optional",
)
# Validate compilation
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(
ctx, f"{name} EBNF should compile successfully"
)
except RuntimeError as e:
self.fail(f"Failed to compile {name} EBNF: {e}")
class TestBaseFormatDetector(unittest.TestCase): class TestBaseFormatDetector(unittest.TestCase):
"""Test buffer management and sequential tool index assignment in BaseFormatDetector.""" """Test buffer management and sequential tool index assignment in BaseFormatDetector."""
......
...@@ -77,7 +77,11 @@ class TestToolChoiceLlama32(CustomTestCase): ...@@ -77,7 +77,11 @@ class TestToolChoiceLlama32(CustomTestCase):
"city": { "city": {
"type": "string", "type": "string",
"description": "name of the city to get weather for", "description": "name of the city to get weather for",
} },
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
}, },
"required": ["city"], "required": ["city"],
}, },
...@@ -152,7 +156,7 @@ class TestToolChoiceLlama32(CustomTestCase): ...@@ -152,7 +156,7 @@ class TestToolChoiceLlama32(CustomTestCase):
"enum": ["celsius", "fahrenheit"], "enum": ["celsius", "fahrenheit"],
}, },
}, },
"required": ["location", "unit"], "required": ["location"],
}, },
}, },
}, },
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment