tool_grammar.rs 4.3 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
use crate::infer::InferError;
drbh's avatar
drbh committed
2
3
4
5
use crate::{
    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
    ToolType,
};
Nicolas Patry's avatar
Nicolas Patry committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
use serde_json::{json, Map, Value};
use std::collections::HashMap;

/// Field-less type used as a namespace: its associated functions turn a list
/// of tools plus a tool choice into an optional JSON-schema tool description.
pub(crate) struct ToolGrammar {}

impl ToolGrammar {
    // find a tool by name
    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
        tools
            .iter()
            .find(|tool| tool.function.name == name)
            .cloned()
            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
    }

    pub fn apply(
drbh's avatar
drbh committed
22
        tools: Vec<Tool>,
Nicolas Patry's avatar
Nicolas Patry committed
23
        tool_choice: ToolChoice,
drbh's avatar
drbh committed
24
    ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
Nicolas Patry's avatar
Nicolas Patry committed
25
        // if no tools are provided, we return None
drbh's avatar
drbh committed
26
27
28
        if tools.is_empty() {
            return Ok((tools, None));
        }
Nicolas Patry's avatar
Nicolas Patry committed
29
30
31

        let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);

drbh's avatar
drbh committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
        let mut tools = tools.clone();

        // add the notify_error function to the tools
        let notify_error = Tool {
            r#type: "function".to_string(),
            function: FunctionDefinition {
                name: "notify_error".to_string(),
                description: Some("Notify an error or issue".to_string()),
                arguments: json!({
                    "type": "object",
                    "properties": {
                        "error": {
                            "type": "string",
                            "description": "The error or issue to notify"
                        }
                    },
                    "required": ["error"]
                }),
            },
        };
        tools.push(notify_error);

Nicolas Patry's avatar
Nicolas Patry committed
54
55
56
57
58
59
60
61
        // if tools are provided and no tool_choice we default to the OneOf
        let tools_to_use = match tool_choice {
            ToolType::FunctionName(name) => {
                vec![Self::find_tool_by_name(&tools, &name)?]
            }
            ToolType::Function { function } => {
                vec![Self::find_tool_by_name(&tools, &function.name)?]
            }
drbh's avatar
drbh committed
62
63
            ToolType::OneOf => tools.clone(),
            ToolType::NoTool => return Ok((tools, None)),
Nicolas Patry's avatar
Nicolas Patry committed
64
65
66
67
68
69
70
        };

        let functions: HashMap<String, serde_json::Value> = tools_to_use
            .iter()
            .map(|tool| {
                let func = tool.function.clone();

drbh's avatar
drbh committed
71
                let mut params = Map::new();
Nicolas Patry's avatar
Nicolas Patry committed
72
73
74

                params.insert(
                    "description".to_string(),
drbh's avatar
drbh committed
75
                    Value::String(func.description.unwrap_or_default()),
Nicolas Patry's avatar
Nicolas Patry committed
76
77
                );

drbh's avatar
drbh committed
78
79
                let mut properties = Map::new();
                let mut required = vec![Value::String("_name".to_string())];
Nicolas Patry's avatar
Nicolas Patry committed
80
81
82
83
84
85
86
87
88

                properties.insert(
                    "_name".to_string(),
                    json!({
                        "type": "string",
                        "const": func.name.clone(),
                    }),
                );

drbh's avatar
drbh committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
                if let Value::Object(args) = func.arguments {
                    if let Some(Value::Object(props)) = args.get("properties") {
                        properties.extend(props.clone());
                    }
                    if let Some(Value::Array(reqs)) = args.get("required") {
                        required.extend(reqs.clone());
                    }
                    params.insert(
                        "additionalProperties".to_string(),
                        Value::Bool(
                            args.get("additionalProperties").and_then(|v| v.as_str())
                                == Some("true"),
                        ),
                    );
Nicolas Patry's avatar
Nicolas Patry committed
103
104
                }

drbh's avatar
drbh committed
105
106
107
                params.insert("properties".to_string(), Value::Object(properties));
                params.insert("required".to_string(), Value::Array(required));

Nicolas Patry's avatar
Nicolas Patry committed
108
109
110
111
                (func.name, Value::Object(params))
            })
            .collect();

drbh's avatar
drbh committed
112
        let tool_schema = JsonSchemaTool {
Nicolas Patry's avatar
Nicolas Patry committed
113
114
115
116
117
118
119
120
121
122
123
            functions_map: FunctionsMap { functions },
            properties: Properties {
                function: tools_to_use
                    .iter()
                    .map(|tool| FunctionRef {
                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                    })
                    .collect(),
            },
        };

drbh's avatar
drbh committed
124
        Ok((tools, Some(tool_schema)))
Nicolas Patry's avatar
Nicolas Patry committed
125
126
    }
}