Unverified Commit 7276d434 authored by drbh, committed by GitHub

feat: improve tools to include name and add tests (#1693)

This PR makes tool calling aware of the name of the function selected. 

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1657

Thank you @puppetm4st3r for the helpful snippets; large parts of this PR are simply refactors of the code you shared 🙏

**Note:** opening as a draft PR because small tweaks are needed before merging.
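For illustration, here is a minimal sketch of what the change means for clients, assuming a TGI server on `localhost:8080` and its OpenAI-compatible chat endpoint. The tool definition and the `name`/`arguments` fields mirror the integration tests and snapshots in this PR; the `model` value and the exact tool schema are placeholder assumptions:

```python
import requests

# Minimal sketch: send one tool to a local TGI server and read back the
# function name, which tool calls now include (previously the name was
# always the placeholder "tools").
payload = {
    "model": "tgi",  # placeholder; TGI serves a single model
    "max_tokens": 100,
    "seed": 42,
    "tool_choice": "auto",
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location", "format"],
                },
            },
        }
    ],
    "messages": [{"role": "user", "content": "What is the weather like in Brooklyn?"}],
}

response = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
tool_call = response.json()["choices"][0]["message"]["tool_calls"][0]
print(tool_call["function"]["name"])       # e.g. "get_current_weather"
print(tool_call["function"]["arguments"])  # e.g. {"format": "celsius", "location": "Brooklyn"}
```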
parent 88702d87
@@ -13,7 +13,7 @@
       "usage": null
     }
   ],
-  "created": 1710795556,
+  "created": 1712874856,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
@@ -11,13 +11,12 @@
         "tool_calls": [
           {
             "function": {
-              "description": null,
-              "name": "tools",
-              "parameters": {
+              "arguments": {
                 "format": "celsius",
-                "location": "New York, NY",
-                "num_days": 14
-              }
+                "location": "Brooklyn"
+              },
+              "description": null,
+              "name": "get_current_weather"
             },
             "id": 0,
             "type": "function"
@@ -27,14 +26,14 @@
       "usage": null
     }
   ],
-  "created": 1710795556,
+  "created": 1712782670,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.0.0-native",
   "usage": {
-    "completion_tokens": 29,
-    "prompt_tokens": 316,
-    "total_tokens": 345
+    "completion_tokens": 37,
+    "prompt_tokens": 524,
+    "total_tokens": 561
   }
 }
@@ -11,13 +11,12 @@
         "tool_calls": [
           {
             "function": {
-              "description": null,
-              "name": "tools",
-              "parameters": {
+              "arguments": {
                 "format": "celsius",
-                "location": "New York, NY",
-                "num_days": 14
-              }
+                "location": "Brooklyn"
+              },
+              "description": null,
+              "name": "get_current_weather"
             },
             "id": 0,
             "type": "function"
@@ -27,14 +26,14 @@
       "usage": null
     }
   ],
-  "created": 1710795557,
+  "created": 1712787937,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.0.0-native",
   "usage": {
-    "completion_tokens": 29,
-    "prompt_tokens": 316,
-    "total_tokens": 345
+    "completion_tokens": 37,
+    "prompt_tokens": 524,
+    "total_tokens": 561
   }
 }
@@ -11,12 +11,12 @@
         "tool_calls": [
           {
             "function": {
-              "description": null,
-              "name": "tools",
-              "parameters": {
+              "arguments": {
                 "format": "celsius",
                 "location": "New York, NY"
-              }
+              },
+              "description": null,
+              "name": "get_current_weather"
             },
             "id": 0,
             "type": "function"
@@ -26,14 +26,14 @@
       "usage": null
     }
   ],
-  "created": 1710795557,
+  "created": 1712852394,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
   "system_fingerprint": "2.0.0-native",
   "usage": {
-    "completion_tokens": 21,
-    "prompt_tokens": 187,
-    "total_tokens": 208
+    "completion_tokens": 48,
+    "prompt_tokens": 320,
+    "total_tokens": 368
   }
 }
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "name": null,
        "role": "assistant",
        "tool_calls": [
          {
            "function": {
              "arguments": {
                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
              },
              "description": null,
              "name": "notify_error"
            },
            "id": 0,
            "type": "function"
          }
        ]
      },
      "usage": null
    }
  ],
  "created": 1712852597,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.5-native",
  "usage": {
    "completion_tokens": 39,
    "prompt_tokens": 496,
    "total_tokens": 535
  }
}
@@ -19,7 +19,7 @@
       "logprobs": null
     }
   ],
-  "created": 1710795499,
+  "created": 1712788218,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def flash_llama_chat_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_chat(flash_llama_chat_handle):
    await flash_llama_chat_handle.health(300)
    return flash_llama_chat_handle.client


@pytest.mark.private
async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
    response = await flash_llama_chat.chat(
        max_tokens=100,
        seed=1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
    )
    assert response == response_snapshot
@@ -71,34 +71,7 @@ tools = [
 ]
 
 
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_grammar_no_tools(
-    flash_llama_grammar_tools, response_snapshot
-):
-    response = await flash_llama_grammar_tools.chat(
-        max_tokens=100,
-        seed=1,
-        messages=[
-            {
-                "role": "system",
-                "content": "Youre a helpful assistant! Answer the users question best you can.",
-            },
-            {
-                "role": "user",
-                "content": "What is the weather like in Brooklyn, New York?",
-            },
-        ],
-    )
-
-    assert (
-        response.choices[0].message.content
-        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
-    )
-    assert response == response_snapshot
-
-
-@pytest.mark.skip
+@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@@ -121,23 +94,19 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
     assert response.choices[0].message.content == None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
-            "type": "function",
             "function": {
                 "description": None,
-                "name": "tools",
-                "parameters": {
-                    "format": "celsius",
-                    "location": "New York, NY",
-                    "num_days": 14,
-                },
+                "name": "get_current_weather",
+                "arguments": {"format": "celsius", "location": "New York, NY"},
             },
+            "id": 0,
+            "type": "function",
         }
     ]
     assert response == response_snapshot
 
 
-@pytest.mark.skip
+@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@@ -163,23 +132,20 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content == None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
-            "type": "function",
             "function": {
                 "description": None,
-                "name": "tools",
-                "parameters": {
-                    "format": "celsius",
-                    "location": "New York, NY",
-                    "num_days": 14,
-                },
+                "name": "get_current_weather",
+                "arguments": {"format": "celsius", "location": "New York, NY"},
             },
+            "id": 0,
+            "type": "function",
         }
     ]
     assert response == response_snapshot
 
 
-@pytest.mark.skip
+@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@@ -209,15 +175,16 @@ async def test_flash_llama_grammar_tools_choice(
             "type": "function",
             "function": {
                 "description": None,
-                "name": "tools",
-                "parameters": {"format": "celsius", "location": "New York, NY"},
+                "name": "get_current_weather",
+                "arguments": {"format": "celsius", "location": "New York, NY"},
             },
         }
     ]
     assert response == response_snapshot
 
 
-@pytest.mark.skip
+@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_stream(
@@ -246,5 +213,47 @@ async def test_flash_llama_grammar_tools_stream(
     async for response in responses:
         count += 1
 
-    assert count == 20
+    assert count == 38
     assert response == response_snapshot
+
+
+@pytest.mark.skip(reason="Takes too long to run")
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_insufficient_information(
+    flash_llama_grammar_tools, response_snapshot
+):
+    responses = await flash_llama_grammar_tools.chat(
+        max_tokens=100,
+        seed=8,
+        tools=tools,
+        tool_choice="auto",
+        messages=[
+            {
+                "role": "system",
+                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+            },
+            {
+                "role": "user",
+                "content": "Tell me a story about 3 sea creatures",
+            },
+        ],
+        stream=False,
+    )
+
+    assert responses.choices[0].message.content == None
+    assert responses.choices[0].message.tool_calls == [
+        {
+            "function": {
+                "arguments": {
+                    "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
+                },
+                "description": None,
+                "name": "notify_error",
+            },
+            "id": 0,
+            "type": "function",
+        }
+    ]
+    assert responses == response_snapshot
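The new default tool prompt steers the model to report an error instead of inventing tool properties when a request matches no tool, which is what the `notify_error` call in the test above captures. A client then needs a small dispatch step. The sketch below is illustrative only: `call_weather_api` is a hypothetical helper, and only the `notify_error` name and its `error` argument come from this PR's tests:

```python
def call_weather_api(location: str, format: str) -> str:
    # Hypothetical stand-in for a real weather lookup.
    return f"(weather for {location} in {format})"


def dispatch_tool_call(tool_call: dict) -> str:
    # Route a tool call from the chat response to application code.
    name = tool_call["function"]["name"]
    arguments = tool_call["function"]["arguments"]

    if name == "notify_error":
        # The model judged the request unanswerable with the available tools.
        return f"Tool selection failed: {arguments['error']}"
    if name == "get_current_weather":
        return call_weather_api(**arguments)
    raise ValueError(f"Unknown tool: {name}")
```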
This diff is collapsed.
@@ -79,7 +79,7 @@ impl HubTokenizerConfig {
     }
 }
 
-#[derive(Clone, Debug, Deserialize, ToSchema)]
+#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
     /// A string that represents a [JSON Schema](https://json-schema.org/).
@@ -669,7 +669,7 @@ pub(crate) struct ChatRequest {
     #[serde(default = "default_tool_prompt")]
     #[schema(
         nullable = true,
-        example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
+        example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
     )]
     pub tool_prompt: Option<String>,
@@ -682,7 +682,7 @@ pub(crate) struct ChatRequest {
 fn default_tool_prompt() -> Option<String> {
     Some(
-        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
     )
 }
 
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
@@ -727,26 +727,26 @@ mod deserialize_tool_choice {
     }
 }
 
-#[derive(Debug, Deserialize, Serialize, ToSchema)]
+#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
 pub struct Tools {
     #[serde(flatten)]
     functions_map: FunctionsMap,
     properties: Properties,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 struct FunctionsMap {
     #[serde(rename = "$functions")]
     functions: std::collections::HashMap<String, serde_json::Value>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 struct FunctionRef {
     #[serde(rename = "$ref")]
     ref_path: String,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, PartialEq)]
 struct Properties {
     #[serde(serialize_with = "serialize_function")]
     function: Vec<FunctionRef>,
@@ -767,7 +767,8 @@ pub(crate) struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
-    pub parameters: serde_json::Value,
+    #[serde(alias = "parameters")]
+    pub arguments: serde_json::Value,
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
@@ -779,12 +780,14 @@ pub(crate) struct Tool {
     pub function: FunctionDefinition,
 }
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<Message>,
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
+    tools: Option<&'a str>,
+    tools_prompt: Option<&'a str>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
...
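The `#[serde(alias = "parameters")]` change above means the router now emits `arguments` in tool-call responses (matching the OpenAI shape) while still deserializing the `parameters` key used in OpenAI-style tool definitions. A quick illustrative check of the same idea in plain Python; the field names come from the diff above, and the normalization helper itself is hypothetical:

```python
# Mimics the serde alias on FunctionDefinition: "parameters" and "arguments"
# land in the same field. Hypothetical helper, for illustration only.
def normalize_function(function: dict) -> dict:
    arguments = function.get("arguments", function.get("parameters"))
    return {
        "description": function.get("description"),
        "name": function["name"],
        "arguments": arguments,
    }


# Both spellings normalize identically:
a = normalize_function({"name": "get_current_weather", "parameters": {"location": "Brooklyn"}})
b = normalize_function({"name": "get_current_weather", "arguments": {"location": "Brooklyn"}})
assert a == b
```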
 use crate::config::Config;
 /// HTTP Server logic
 use crate::health::Health;
-use crate::infer::{InferError, InferResponse, InferStreamResponse};
+use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
@@ -15,7 +15,7 @@ use crate::{
     ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
     CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
 };
-use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
+use crate::{FunctionDefinition, ToolCall, ToolType};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -29,7 +29,6 @@ use futures::Stream;
 use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use serde_json::Value;
-use std::collections::HashMap;
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
@@ -757,19 +756,29 @@ async fn chat_completions(
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     metrics::increment_counter!("tgi_request_count");
 
-    let stream = req.stream;
-    let max_new_tokens = req.max_tokens.or(Some(100));
-    let repetition_penalty = req
-        .presence_penalty
-        // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
-        .map(|x| x + 2.0);
-    let logprobs = req.logprobs.unwrap_or(false);
-    let seed = req.seed;
-    let stop = req.stop.unwrap_or_default();
-
-    // apply chat template to flatten the request into a single input
-    let mut inputs = match infer.apply_chat_template(req.messages) {
-        Ok(inputs) => inputs,
+    let ChatRequest {
+        logprobs,
+        max_tokens,
+        messages,
+        presence_penalty,
+        seed,
+        stop,
+        stream,
+        tools,
+        tool_choice,
+        tool_prompt,
+        ..
+    } = req;
+
+    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+    let max_new_tokens = max_tokens.or(Some(100));
+    let logprobs = logprobs.unwrap_or(false);
+    let tool_prompt = tool_prompt.unwrap_or_default();
+    let stop = stop.unwrap_or_default();
+
+    // extract tool grammar if present
+    let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
+        Ok(grammar) => grammar,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
             tracing::error!("{err}");
@@ -783,60 +792,28 @@ async fn chat_completions(
         }
     };
 
-    let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
-        let tool_prompt = req.tool_prompt.unwrap_or_default();
-        let tools_to_use = match tool_choice {
-            ToolType::FunctionName(name) => {
-                vec![req_tools
-                    .iter()
-                    .find(|tool| tool.function.name == *name)
-                    .ok_or_else(|| {
-                        (
-                            StatusCode::UNPROCESSABLE_ENTITY,
-                            Json(ErrorResponse {
-                                error: "Tool choice not found in tool names".to_string(),
-                                error_type: "Tool not found".to_string(),
-                            }),
-                        )
-                    })?
-                    .clone()]
-            }
-            ToolType::OneOf => req_tools.to_owned(),
-        };
-
-        let functions: HashMap<String, Value> = tools_to_use
-            .iter()
-            .map(|tool| {
-                let func = tool.function.clone();
-                (func.name, func.parameters)
-            })
-            .collect();
-
-        let tools = Tools {
-            functions_map: FunctionsMap { functions },
-            properties: Properties {
-                function: tools_to_use
-                    .iter()
-                    .map(|tool| FunctionRef {
-                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
-                    })
-                    .collect(),
-            },
-        };
-
-        let tools_str = serde_json::to_string(&tools).map_err(|e| {
-            (
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: e.to_string(),
-                    error_type: "Input validation error".to_string(),
-                }),
-            )
-        })?;
-        inputs = format!("{inputs}{tool_prompt}{tools_str}");
-        Some(GrammarType::Json(serde_json::json!(tools)))
-    } else {
-        None
-    };
+    let grammar_with_prompt = tool_grammar
+        .as_ref()
+        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
+
+    let typed_grammar = grammar_with_prompt
+        .as_ref()
+        .map(|(grammar, _)| grammar.clone());
+
+    // apply chat template to flatten the request into a single input
+    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
+        Ok(inputs) => inputs,
+        Err(err) => {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            return Err((
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: err.to_string(),
+                    error_type: err.error_type().to_string(),
+                }),
+            ));
+        }
+    };
 
     // build the request passing some parameters
@@ -860,7 +837,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: req.top_logprobs,
-            grammar: tool_grammar.clone(),
+            grammar: typed_grammar,
         },
     };
@@ -943,27 +920,28 @@ async fn chat_completions(
                     }),
                 )
             })?;
 
             let tool_calls = vec![ToolCall {
                 id: 0,
                 r#type: "function".to_string(),
                 function: FunctionDefinition {
                     description: None,
-                    name: "tools".to_string(),
-                    parameters: gen_text_value.get("function").map_or_else(
-                        || {
-                            serde_json::from_str(&generation.generated_text).map_err(|e| {
-                                (
-                                    StatusCode::UNPROCESSABLE_ENTITY,
-                                    Json(ErrorResponse {
-                                        error: e.to_string(),
-                                        error_type: "Input validation error".to_string(),
-                                    }),
-                                )
-                            })
-                        },
-                        |f| Ok(f.clone()),
-                    )?,
+                    name: gen_text_value
+                        .get("function")
+                        .and_then(|f| f.get("_name"))
+                        .and_then(|name| name.as_str())
+                        .unwrap_or("default_function_name")
+                        .to_string(),
+                    // Serialize the JSON object obtained from "function" to an escaped JSON string
+                    arguments: gen_text_value
+                        .get("function")
+                        .map(|f| {
+                            let mut f_cloned = f.clone();
+                            if let Value::Object(ref mut props) = f_cloned {
+                                props.remove("_name");
+                            }
+                            f_cloned
+                        })
+                        .unwrap_or_default(),
                 },
             }];
             (Some(tool_calls), None)
@@ -1539,6 +1517,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
             InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
             InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
             InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
+            InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
         };
         (
...
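The extraction logic above expects the grammar-constrained generation to be a JSON object whose `function` member carries a reserved `_name` property alongside the actual arguments (`ToolGrammar` comes from the `infer` module, whose diff is collapsed above; the `_name` shape is inferred from this extraction code). A Python sketch of the same extraction, assuming that generated shape:

```python
import json

# Sketch of the extraction performed by chat_completions above, assuming the
# grammar-constrained output shape {"function": {"_name": ..., <args>...}}.
generated_text = (
    '{"function": {"_name": "get_current_weather",'
    ' "format": "celsius", "location": "Brooklyn"}}'
)

gen_value = json.loads(generated_text)
function = dict(gen_value.get("function", {}))

# "_name" identifies the selected tool; everything else is its arguments.
name = function.pop("_name", "default_function_name")
arguments = function

tool_call = {
    "id": 0,
    "type": "function",
    "function": {"description": None, "name": name, "arguments": arguments},
}
print(tool_call)
```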