tool_grammar.rs 4.76 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use serde_json::{json, Map, Value};
use std::collections::HashMap;

/// Converts a request's tools and tool choice into the [`Tools`] grammar used to
/// constrain generation.
///
/// The struct has no fields; all functionality lives in associated functions.
pub(crate) struct ToolGrammar {}

impl ToolGrammar {
    // find a tool by name
    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
        tools
            .iter()
            .find(|tool| tool.function.name == name)
            .cloned()
            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
    }

    pub fn apply(
        tools: Option<Vec<Tool>>,
        tool_choice: ToolChoice,
    ) -> Result<Option<Tools>, InferError> {
        // if no tools are provided, we return None
        let tools = match tools {
            Some(tools) if !tools.is_empty() => tools,
            _ => return Ok(None),
        };

        let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);

        // if tools are provided and no tool_choice we default to the OneOf
        let tools_to_use = match tool_choice {
            ToolType::FunctionName(name) => {
                vec![Self::find_tool_by_name(&tools, &name)?]
            }
            ToolType::Function { function } => {
                vec![Self::find_tool_by_name(&tools, &function.name)?]
            }
            ToolType::OneOf => tools,
            ToolType::NoTool => return Ok(None),
        };

        // adds the error notification function for LLM feedback if required
        let mut text_response_properties = Map::new();
        text_response_properties.insert(
            "error".to_string(),
            serde_json::json!({
                "type": "string",
                "description": "The error or issue to notify"
            }),
        );
        text_response_properties.insert(
            "_name".to_string(),
            serde_json::json!({
                "type": "string",
                "const": "notify_error"
            }),
        );

        let functions: HashMap<String, serde_json::Value> = tools_to_use
            .iter()
            .map(|tool| {
                let func = tool.function.clone();

                // Clone the existing parameters, which are expected to be a JSON object
                let mut params = if let Value::Object(params) = &func.arguments {
                    params.clone()
                } else {
                    Map::new()
                };

                // Insert the function's description at the top level, outside of properties
                params.insert(
                    "description".to_string(),
                    Value::String(func.description.clone().unwrap_or_default()),
                );

                // Ensure 'properties' exists and is an object
                let properties = params
                    .entry("properties".to_string())
                    .or_insert_with(|| json!({}))
                    .as_object_mut()
                    .unwrap();

                // Insert the constant for the function name inside 'properties'
                properties.insert(
                    "_name".to_string(),
                    json!({
                        "type": "string",
                        "const": func.name.clone(),
                        // "description": "The name of the function"
                    }),
                );

                // Check if 'required' exists, and it is an array. If not, create an empty array.
                let required = params
                    .entry("required".to_string())
                    .or_insert_with(|| json!([]))
                    .as_array_mut()
                    .unwrap();

                // Add 'name' to the 'required' array if it is not already present
                if !required.iter().any(|r| r == "_name") {
                    required.push(json!("_name"));
                }

                (func.name, Value::Object(params))
            })
            .chain([(
                "notify_error".to_string(),
                serde_json::json!({
                    "properties": text_response_properties,
                    "required": ["error", "_name"],
                    "type": "object"
                }),
            )])
            .collect();

        let tools = Tools {
            functions_map: FunctionsMap { functions },
            properties: Properties {
                function: tools_to_use
                    .iter()
                    .map(|tool| FunctionRef {
                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                    })
                    .chain(std::iter::once(FunctionRef {
                        ref_path: "#/$functions/notify_error".to_string(),
                    }))
                    .collect(),
            },
        };

        Ok(Some(tools))
    }
}