Unverified commit e36dfaa8, authored by drbh, committed by GitHub
Browse files

feat: allow tool calling to respond without a tool (#2614)



* feat: process token stream before returning to client

* fix: expect content in test

* fix: improve comparison via ruff lint

* fix: return event in all cases

* fix: always send event on error, avoid unwraps, refactor and improve tests

* fix: prefer no_tool over notify_error to improve response

* fix: adjust chat input test for no_tool

* fix: adjust test expected content

---------
Co-authored-by: System administrator <root@ip-10-90-0-186.ec2.internal>
parent 43f39f68
{
"choices": [
{
"finish_reason": "eos_token",
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"content": "I am an AI assistant",
"name": null,
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": null,
"name": "notify_error"
},
"id": 0,
"type": "function"
}
]
"tool_calls": null
},
"usage": null
}
],
"created": 1712852597,
"created": 1728497062,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.5-native",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.2-dev0-native",
"usage": {
"completion_tokens": 39,
"prompt_tokens": 496,
"total_tokens": 535
"completion_tokens": 23,
"prompt_tokens": 604,
"total_tokens": 627
}
}
{
"choices": [
{
"delta": {
"content": " assistant",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497531,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
{
"choices": [
{
"delta": {
"content": " fans",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1728497461,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.3.2-dev0-native",
"usage": null
}
......@@ -207,11 +207,20 @@ async def test_flash_llama_grammar_tools_stream(
)
count = 0
tool_calls_generated = ""
last_response = None
async for response in responses:
count += 1
tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
last_response = response
assert response.choices[0].delta.content is None
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
)
assert count == 28
assert response == response_snapshot
assert last_response == response_snapshot
@pytest.mark.asyncio
......@@ -227,18 +236,94 @@ async def test_flash_llama_grammar_tools_insufficient_information(
messages=[
{
"role": "system",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
"content": "Who are you?",
},
],
stream=False,
)
assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls is None
assert responses.choices[0].message.content == "I am an AI assistant"
assert responses == response_snapshot
# Streaming counterpart of the "insufficient information" tool test: tools are
# offered with tool_choice="auto", but the user question ("Who are you?") is
# unrelated to them, so the test expects plain text content and no tool call.
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
flash_llama_grammar_tools, response_snapshot
):
# Request a streamed chat completion; the fixed seed is presumably what keeps
# the chunk count and generated text asserted below stable -- TODO confirm.
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "Who are you?",
},
],
stream=True,
)
# Accumulators: number of chunks seen, concatenated delta text, and the most
# recent chunk (compared against the stored snapshot at the end).
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
# String concatenation: this raises TypeError if a chunk's delta.content is None.
content_generated += response.choices[0].delta.content
last_response = response
# NOTE(review): indentation is lost in this view; this assert presumably sits
# inside the loop so that every chunk is checked for the absence of tool calls.
assert response.choices[0].delta.tool_calls is None
# Exact expectations for the streamed reply: five chunks forming the sentence below.
assert count == 5
assert content_generated == "I am an AI assistant"
# Only the last streamed chunk is compared against the snapshot.
assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
},
{
"role": "user",
"content": "Tell me a story about 3 sea creatures",
},
],
stream=True,
)
count = 0
content_generated = ""
last_response = None
async for response in responses:
count += 1
content_generated += response.choices[0].delta.content
last_response = response
assert response.choices[0].delta.tool_calls is None
assert count == 62
assert (
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
content_generated
== "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
)
assert responses == response_snapshot
assert last_response == response_snapshot
......@@ -355,6 +355,8 @@ pub enum InferError {
MissingTemplateVariable(String),
#[error("Tool error: {0}")]
ToolError(String),
#[error("Stream event serialization error")]
StreamSerializationError(String),
}
impl InferError {
......@@ -368,6 +370,7 @@ impl InferError {
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",
InferError::StreamSerializationError(_) => "stream_serialization_error",
}
}
}
......@@ -31,25 +31,25 @@ impl ToolGrammar {
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
// add the no_tool function to the tools
let no_tool = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
name: "no_tool".to_string(),
description: Some("Open ened response with no specific tool selected".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"content": {
"type": "string",
"description": "The error or issue to notify"
"description": "The response content",
}
},
"required": ["error"]
"required": ["content"]
}),
},
};
tools.push(notify_error);
tools.push(no_tool);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment