Unverified Commit d22b0c1f authored by drbh, committed by GitHub

Unroll notify error into generate response (#2597)

* feat: unroll notify_error if no tool is chosen

* fix: expect simple message when no tool is selected

* fix: improve test to avoid notify_error

* fix: improve docs and indicate change in expected response

* fix: adjust linting in test file
parent 23354595
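
To illustrate the behavior this commit introduces, here is a minimal client-side sketch (it is not part of the diff below; the `get_current_weather` tool definition, the `http://localhost:8080/v1` endpoint, and the `model="tgi"` placeholder are assumptions for illustration). With `tool_choice="auto"`, a prompt unrelated to the declared tools now yields a plain assistant message instead of a `notify_error` tool call:

```python
from openai import OpenAI

# Assumed local TGI server exposing the OpenAI-compatible API.
client = OpenAI(
    base_url="http://localhost:8080/v1",
    api_key="-",  # TGI does not validate the key by default
)

# Illustrative tool definition; any tool the prompt does not need works here.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and country"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]

# Ask something unrelated to the weather tool: with tool_choice="auto" the model
# can now answer directly instead of emitting a notify_error tool call.
chat = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "Tell me a story about 3 sea creatures"}],
    tools=tools,
    tool_choice="auto",
)

message = chat.choices[0].message
if message.tool_calls:
    print("Tool call:", message.tool_calls[0].function.name)
else:
    print("Plain reply:", message.content)
```
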
@@ -311,11 +311,13 @@ print(chat.choices[0].message.tool_calls)
 ```
-### OpenAI integration
+### OpenAI Integration
-TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
+Text Generation Inference (TGI) offers seamless integration with OpenAI's client libraries, allowing developers to interact with TGI's Messages API and Tool functions in a familiar way. This compatibility simplifies the implementation of advanced features, such as tools and grammar, within your applications using OpenAI's client.
-However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.
+Previously, TGI handled tool selection differently than OpenAI's API: `tool_choice="auto"` would always pick a tool for you. However, as of the latest version, TGI now mimics OpenAI's behavior more closely: `tool_choice="auto"` selects a tool only when the model deems it necessary, aligning with how OpenAI's API works. This enhancement ensures a smoother and more predictable integration experience.
+Additionally, error notifications like `notify_error`, which previously indicated that no tool was chosen, are no longer returned. Instead, TGI will proceed with generating a response as if no tool was selected, further improving consistency with OpenAI's API.
 ```python
 from openai import OpenAI
...
 {
   "choices": [
     {
-      "finish_reason": "eos_token",
+      "finish_reason": "stop",
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": null,
+        "content": "There is a huge storm in the ocean",
         "name": null,
         "role": "assistant",
-        "tool_calls": [
-          {
-            "function": {
-              "arguments": {
-                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-              },
-              "description": null,
-              "name": "notify_error"
-            },
-            "id": 0,
-            "type": "function"
-          }
-        ]
+        "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1712852597,
+  "created": 1727796440,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
-  "object": "text_completion",
+  "object": "chat.completion",
-  "system_fingerprint": "1.4.5-native",
+  "system_fingerprint": "2.3.1-dev0-native",
   "usage": {
-    "completion_tokens": 39,
+    "completion_tokens": 25,
-    "prompt_tokens": 496,
+    "prompt_tokens": 600,
-    "total_tokens": 535
+    "total_tokens": 625
   }
 }
@@ -225,10 +225,6 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         tools=tools,
         tool_choice="auto",
         messages=[
-            {
-                "role": "system",
-                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
-            },
             {
                 "role": "user",
                 "content": "Tell me a story about 3 sea creatures",
@@ -237,8 +233,5 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         stream=False,
     )
-    assert responses.choices[0].message.content is None
-    assert (
-        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
-    )
+    assert responses.choices[0].message.content == "There is a huge storm in the ocean"
     assert responses == response_snapshot
@@ -1246,7 +1246,21 @@ async fn chat_completions(
                 if let Value::Object(ref mut props) = arguments {
                     props.remove("_name");
                 }
+                match name.as_str() {
+                    "notify_error" => {
+                        // parse the error message
+                        let error_message = arguments
+                            .get("error")
+                            .and_then(Value::as_str)
+                            .ok_or_else(|| {
+                                InferError::ToolError(
+                                    "No error message found in generated text".to_string(),
+                                )
+                            })?
+                            .to_string();
+                        (None, Some(error_message))
+                    }
+                    _ => {
                 let tool_calls = vec![ToolCall {
                     id: "0".to_string(),
                     r#type: "function".to_string(),
@@ -1257,6 +1271,8 @@ async fn chat_completions(
                     },
                 }];
                 (Some(tool_calls), None)
+                    }
+                }
             } else {
                 (None, Some(generation.generated_text))
             };
...