Unverified commit e36dfaa8 authored by drbh, committed by GitHub

feat: allow tool calling to respond without a tool (#2614)



* feat: process token stream before returning to client

* fix: expect content in test

* fix: improve comparison via ruff lint

* fix: return event in all cases

* fix: always send event on error, avoid unwraps, refactor and improve tests

* fix: prefer no_tool over notify_error to improve response

* fix: adjust chat input test for no_tool

* fix: adjust test expected content

---------
Co-authored-by: System administrator <root@ip-10-90-0-186.ec2.internal>
parent 43f39f68
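Before the diff itself, a quick illustration of what this change does from a client's point of view. This is a hedged sketch, not part of the commit: it assumes a local text-generation-inference server exposing the OpenAI-compatible `/v1/chat/completions` route at `http://localhost:3000/v1` and the `openai` Python client; the weather tool definition is illustrative. With `tool_choice="auto"` and a question no tool can answer, the server now returns plain assistant content instead of forcing a `notify_error` tool call.

```python
# Sketch only: server URL, API key, and the weather tool are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:3000/v1", api_key="-")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and country"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

# "Who are you?" cannot be answered by the weather tool; after this change the
# server replies with plain content and tool_calls stays None.
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You're a helpful assistant! Answer the users question best you can."},
        {"role": "user", "content": "Who are you?"},
    ],
    tools=tools,
    tool_choice="auto",
    max_tokens=100,
)

assert response.choices[0].message.tool_calls is None
print(response.choices[0].message.content)  # e.g. "I am an AI assistant"
```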
 {
   "choices": [
     {
-      "finish_reason": "eos_token",
+      "finish_reason": "stop",
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": null,
+        "content": "I am an AI assistant",
         "name": null,
         "role": "assistant",
-        "tool_calls": [
-          {
-            "function": {
-              "arguments": {
-                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-              },
-              "description": null,
-              "name": "notify_error"
-            },
-            "id": 0,
-            "type": "function"
-          }
-        ]
+        "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1712852597,
+  "created": 1728497062,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-  "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.3.2-dev0-native",
   "usage": {
-    "completion_tokens": 39,
-    "prompt_tokens": 496,
-    "total_tokens": 535
+    "completion_tokens": 23,
+    "prompt_tokens": 604,
+    "total_tokens": 627
   }
 }
{
"choices": [
{
"delta": {
"content": " assistant",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497531,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
{
"choices": [
{
"delta": {
"content": " fans",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497461,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
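The two `chat.completion.chunk` snapshots above carry ordinary content deltas where a tool call used to be. A minimal consumption sketch under the same assumptions as the earlier example (local server, `openai` client, illustrative tool), accumulating `delta.content` the way the updated tests below do:

```python
# Sketch: consume the stream and accumulate delta.content.
# Server URL, model, and tool definition are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:3000/v1", api_key="-")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Who are you?"}],
    tools=tools,
    tool_choice="auto",
    max_tokens=100,
    stream=True,
)

content_generated = ""
for chunk in stream:
    delta = chunk.choices[0].delta
    # When no tool is selected, chunks carry plain content and tool_calls is None.
    content_generated += delta.content or ""

print(content_generated)  # e.g. "I am an AI assistant"
```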
@@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
     )
 
     count = 0
+    tool_calls_generated = ""
+    last_response = None
     async for response in responses:
         count += 1
+        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
+        last_response = response
+        assert response.choices[0].delta.content is None
 
+    assert (
+        tool_calls_generated
+        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
+    )
     assert count == 28
-    assert response == response_snapshot
+    assert last_response == response_snapshot
 
 
 @pytest.mark.asyncio
@@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         messages=[
             {
                 "role": "system",
-                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "You're a helpful assistant! Answer the users question best you can.",
             },
             {
                 "role": "user",
-                "content": "Tell me a story about 3 sea creatures",
+                "content": "Who are you?",
             },
         ],
         stream=False,
     )
 
-    assert responses.choices[0].message.content is None
-    assert (
-        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
-    )
+    assert responses.choices[0].message.tool_calls is None
+    assert responses.choices[0].message.content == "I am an AI assistant"
+
+    assert responses == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "Who are you?",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 5
+    assert content_generated == "I am an AI assistant"
+    assert last_response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_sea_creatures_stream(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=24,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
+            },
+            {
+                "role": "user",
+                "content": "Tell me a story about 3 sea creatures",
+            },
+        ],
+        stream=True,
+    )
+
+    count = 0
+    content_generated = ""
+    last_response = None
+    async for response in responses:
+        count += 1
+        content_generated += response.choices[0].delta.content
+        last_response = response
+        assert response.choices[0].delta.tool_calls is None
+
+    assert count == 62
+    assert (
+        content_generated
+        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
+    )
-    assert responses == response_snapshot
+    assert last_response == response_snapshot
@@ -355,6 +355,8 @@ pub enum InferError {
     MissingTemplateVariable(String),
     #[error("Tool error: {0}")]
     ToolError(String),
+    #[error("Stream event serialization error")]
+    StreamSerializationError(String),
 }
 
 impl InferError {
@@ -368,6 +370,7 @@ impl InferError {
             InferError::TemplateError(_) => "template_error",
             InferError::MissingTemplateVariable(_) => "missing_template_variable",
             InferError::ToolError(_) => "tool_error",
+            InferError::StreamSerializationError(_) => "stream_serialization_error",
         }
     }
 }
@@ -31,25 +31,25 @@ impl ToolGrammar {
         let mut tools = tools.clone();
 
-        // add the notify_error function to the tools
-        let notify_error = Tool {
+        // add the no_tool function to the tools
+        let no_tool = Tool {
             r#type: "function".to_string(),
             function: FunctionDefinition {
-                name: "notify_error".to_string(),
-                description: Some("Notify an error or issue".to_string()),
+                name: "no_tool".to_string(),
+                description: Some("Open ended response with no specific tool selected".to_string()),
                 arguments: json!({
                     "type": "object",
                     "properties": {
-                        "error": {
+                        "content": {
                             "type": "string",
-                            "description": "The error or issue to notify"
+                            "description": "The response content",
                         }
                     },
-                    "required": ["error"]
+                    "required": ["content"]
                 }),
             },
         };
-        tools.push(notify_error);
+        tools.push(no_tool);
 
         // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
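To make the `no_tool` mechanics concrete: the grammar still forces the model to emit a function call, but when the chosen function is `no_tool` the router can unwrap its `content` argument into a plain chat message instead of a tool call. A rough Python illustration of that post-processing, assuming the `{"function": {"_name": ..., ...}}` generation shape visible in the tests above; the actual implementation lives in the Rust router:

```python
import json

def postprocess_tool_output(generated: str) -> dict:
    """Unwrap a grammar-constrained generation: `no_tool` becomes plain content,
    anything else stays a tool call. Shapes assumed from the test fixtures."""
    call = json.loads(generated)["function"]
    name = call.pop("_name")
    if name == "no_tool":
        # The no_tool schema requires a single "content" string (see diff above).
        return {"role": "assistant", "content": call["content"], "tool_calls": None}
    return {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"id": 0, "type": "function",
             "function": {"name": name, "arguments": call}}
        ],
    }

# Example with the shape exercised by the tests:
msg = postprocess_tool_output(
    '{"function": {"_name": "no_tool", "content": "I am an AI assistant"}}'
)
assert msg["content"] == "I am an AI assistant"
assert msg["tool_calls"] is None
```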