"vscode:/vscode.git/clone" did not exist on "6bbee1048bc5519fce91ebd81a592449b4f5f6c0"
Unverified Commit 4fac524b authored by Chao Yang's avatar Chao Yang Committed by GitHub
Browse files

update llama4 chat template and pythonic parser (#6679)


Co-authored-by: default avatarChang Su <chang.s.su@oracle.com>
parent b581b225
{# Copied from https://github.com/yeqcharlotte/vllm/blob/4fcf68a948bbe0498dc8a98feafa102cfb1dd210/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}
{# Copied from https://github.com/wukaixingxp/vllm/blob/8a32e2a6e452a03c0e8222e3876ad6086cbf581f/examples/tool_chat_template_llama4_pythonic.jinja to enable better model response. #}
{{- bos_token }}
{%- if custom_tools is defined %}
{%- if custom_tools is defined and custom_tools %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = false %}
{%- endif %}
{%- if not tools is defined %}
{%- if tools is defined and tools %}
{%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %}
{%- else %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set user_provided_system_message = true %}
{%- if messages[0]['content'] is string %}
{%- set system_message = messages[0]['content']|trim %}
{%- else %}
......@@ -19,68 +20,33 @@
{%- endif %}
{%- set messages = messages[1:] %}
{%- else %}
{%- if tools is not none %}
{#- Add default tool system message when tools are provided #}
{%- set system_message = "You are a helpful assistant with tool calling "
"capabilities. Only reply with a tool call if the function exists in the "
"library provided by the user. If it doesn't exist, just reply directly in "
"natural language. When you receive a tool call response, use the output to "
"format an answer to the original user question." %}
{%- if tools is not none %}
{#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #}
{#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #}
{%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}
{%- endif %}
{#- System message if the user supplied one, or if tools are used (default tool system message) #}
{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #}
{%- if system_message %}
{#- always use user provided system message to override default tool system message #}
{{- "<|header_start|>system<|header_end|>\n\n" }}
{{- system_message }}
{%- if tools is not none and not tools_in_user_message %}
{{- "Tools: You have access to the following tools. You might need to use one "
"or more function/tool calls to fulfill the task. \n"
"If none are needed, then proceed to the response.\n\n"
"Tool Call Syntax: You can call tools using the following syntax:\n"
"[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n"
"Do not include anything else when calling the tools with the syntax above.\n\n"
"Here is a list of functions in JSON format that you can invoke.\n " }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- if user_provided_system_message and tools %}
{{- "\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }}
{{- tool_definition -}}
{%- elif tool_definition %}
{{- tool_definition -}}
{%- endif %}
{{- "<|eot|>" }}
{%- endif %}
{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and tools is not none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- if messages[0]['content'] is string %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- else %}
{%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
{%- endif %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|header_start|>user<|header_end|>\n\n' -}}
{{- first_user_message}}
{{- "\nHere is a list of functions in JSON format that you can invoke:"}}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- "Should you decide to return the function call(s), put them in the format "
"of [func_name1(params_name1=params_value1, params_name2=params_value2, "
"...), ...]\nDo not include anything else when calling the tools with the "
"syntax above." }}
{%- endif %}
{#- Now deal with all other messages #}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{#- Base case: messages that are not from tool role and has empty tool_call list #}
{%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %}
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
......@@ -92,10 +58,12 @@
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eot|>" }}
{%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}
{%- set tool_call = message.tool_calls[0].function %}
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
{{- "<|eot|>" }}
{#- Tool case: messages has non-empty tool_call list, must from assistant #}
{%- elif 'tool_calls' in message %}
{#- assume tool_calls are always coming from assistant #}
{%- if message.role == 'assistant' %}
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
......@@ -107,20 +75,24 @@
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "[" }}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- tool_call.name + '(' -}}
{{- tool_call.name + '(' -}}
{%- for param in tool_call.arguments %}
{{- param + '=' -}}
{{- param + '="' -}}
{{- "%s" | format(tool_call.arguments[param]) -}}
{{- '"' -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- ')' -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- "<|eom|>" }}
{{- "]<|eot|>" }}
{%- endif %}
{#- Tool_response case: messages are from tool_response #}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
{%- if message.content is string %}
......@@ -132,7 +104,7 @@
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eom|>" }}
{{- "<|eot|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
......
......@@ -32,13 +32,24 @@ class PythonicDetector(BaseFormatDetector):
re.DOTALL,
)
@staticmethod
def _text_strip(text: str) -> str:
# Llama 4 model sometime will output <|python_start|> and <|python_end|> tokens
# remove those tokens
text = text.replace("<|python_start|>", "")
text = text.replace("<|python_end|>", "")
return text
def has_tool_call(self, text: str) -> bool:
return bool(self.tool_call_regex.search(text.strip()))
return bool(self.tool_call_regex.search(self._text_strip(text.strip())))
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
# Try parsing the text as a Python list of function calls
text = text.strip()
# Remove unexpected <|python_start|> and <|python_end|> for llama4
text = self._text_strip(text)
match = self.tool_call_regex.search(text)
if match is None:
return StreamingParseResult(normal_text=text, calls=[])
......@@ -117,6 +128,30 @@ class PythonicDetector(BaseFormatDetector):
return i
return -1 # No matching bracket found
def _strip_and_split_buffer(self, buffer: str) -> tuple[str, str]:
"""
Strip special tokens from buffer and split into safe_text and held_back_text.
Returns:
tuple of (safe_text_to_output, text_to_hold_in_buffer)
"""
# Check if original buffer ends with a partial token at the end
special_tokens = ["<|python_start|>", "<|python_end|>"]
for token in special_tokens:
partial_length = self._ends_with_partial_token(buffer, token)
if partial_length > 0:
# Split buffer: safe part + held back partial token
safe_text = buffer[:-partial_length]
held_back = buffer[-partial_length:]
# Strip complete special tokens from safe part only
safe_text = self._text_strip(safe_text)
return safe_text, held_back
# No partial tokens found, strip complete tokens from entire buffer
safe_text = self._text_strip(buffer)
return safe_text, ""
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
......@@ -126,20 +161,28 @@ class PythonicDetector(BaseFormatDetector):
then parses and emits any detected calls.
"""
self._buffer += new_text
start = self._buffer.find("[")
# Strip special tokens from entire buffer and handle partial tokens
stripped_buffer, held_back = self._strip_and_split_buffer(self._buffer)
start = stripped_buffer.find("[")
if start == -1:
normal_text = self._buffer
self._buffer = ""
return StreamingParseResult(normal_text=normal_text)
# No tool call bracket found
self._buffer = held_back
return StreamingParseResult(normal_text=stripped_buffer)
normal_text = self._buffer[:start] if start > 0 else ""
normal_text = stripped_buffer[:start] if start > 0 else ""
end = self._find_matching_bracket(self._buffer, start)
end = self._find_matching_bracket(stripped_buffer, start)
if end != -1:
call_text = self._buffer[start : end + 1]
# Found complete tool call
call_text = stripped_buffer[start : end + 1]
result = self.detect_and_parse(call_text, tools)
self._buffer = self._buffer[end + 1 :]
# Update buffer with remaining text after tool call plus any held back text
remaining_text = stripped_buffer[end + 1 :] + held_back
self._buffer = remaining_text
# If we had normal text before the tool call, add it to the result
if normal_text:
......@@ -148,8 +191,10 @@ class PythonicDetector(BaseFormatDetector):
return result
# We have an opening bracket but no closing bracket yet
# Put back everything from the bracket onwards plus held back text
self._buffer = stripped_buffer[start:] + held_back
if normal_text:
self._buffer = self._buffer[start:]
return StreamingParseResult(normal_text=normal_text)
# Otherwise, we're still accumulating a potential tool call
......
......@@ -265,6 +265,81 @@ class TestPythonicDetector(unittest.TestCase):
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
def test_parse_streaming_with_python_start_and_end_token(self):
"""Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks."""
chunks = [
"Here's a call: ",
"<|python_",
"start|>[get_weather(location=",
"'Tokyo', data=[1, 2",
", 3])]<|python_end|>",
]
normal_text = ""
call_name = ""
parameters = ""
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
if result.normal_text:
normal_text += result.normal_text
if result.calls:
call_name += result.calls[0].name
parameters += result.calls[0].parameters
self.assertEqual(normal_text, "Here's a call: ")
self.assertEqual(call_name, "get_weather")
self.assertEqual(self.detector._buffer, "")
self.assertEqual(
result.normal_text, "", "Final result should have no normal text"
)
# Check the parameters
params = json.loads(parameters)
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
chunks = [
"Here's a call: <|python_start|>[get_weather(location='Tokyo', data=[1, 2, 3])]<|python_end|>"
]
normal_text = ""
call_name = ""
parameters = ""
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
if result.normal_text:
normal_text += result.normal_text
if result.calls:
call_name += result.calls[0].name
parameters += result.calls[0].parameters
self.assertEqual(normal_text, "Here's a call: ")
self.assertEqual(call_name, "get_weather")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(parameters)
self.assertEqual(params["location"], "Tokyo")
self.assertEqual(params["data"], [1, 2, 3])
def test_detect_and_parse_with_python_start_and_end_token(self):
"""Test parsing a message that starts with <|python_start|> and contains a valid tool call."""
text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars."
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(
result.normal_text,
"User wants to get the weather in Mars. In this way we will get the weather in Mars.",
)
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(self.detector._buffer, "")
# Check the parameters
params = json.loads(result.calls[0].parameters)
self.assertEqual(params["location"], "Mars")
self.assertEqual(params["unit"], "celsius")
class TestMistralDetector(unittest.TestCase):
def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment