"vscode:/vscode.git/clone" did not exist on "e205527cb11148b19ba4061d8503e7866c3f25dd"
tool_grammar.rs 4.47 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
use crate::infer::InferError;
drbh's avatar
drbh committed
2
3
4
use crate::{
    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
};
Nicolas Patry's avatar
Nicolas Patry committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
use serde_json::{json, Map, Value};
use std::collections::HashMap;

pub(crate) struct ToolGrammar {}

impl ToolGrammar {
    // find a tool by name
    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
        tools
            .iter()
            .find(|tool| tool.function.name == name)
            .cloned()
            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
    }

    pub fn apply(
drbh's avatar
drbh committed
21
        tools: Vec<Tool>,
Nicolas Patry's avatar
Nicolas Patry committed
22
        tool_choice: ToolChoice,
23
    ) -> Result<Option<(Vec<Tool>, JsonSchemaTool)>, InferError> {
Nicolas Patry's avatar
Nicolas Patry committed
24
        let tools_to_use = match tool_choice {
25
            ToolChoice::Function(function) => {
Nicolas Patry's avatar
Nicolas Patry committed
26
27
                vec![Self::find_tool_by_name(&tools, &function.name)?]
            }
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
            ToolChoice::Required => tools,
            ToolChoice::Auto => {
                // only add the no_tool function if the user has selected the auto option
                tools
                    .iter()
                    .cloned()
                    .chain(std::iter::once(Tool {
                        r#type: "function".to_string(),
                        function: FunctionDefinition {
                            name: "no_tool".to_string(),
                            description: Some(
                                "Open ended response with no specific tool selected".to_string(),
                            ),
                            arguments: json!({
                                "type": "object",
                                "properties": {
                                    "content": {
                                        "type": "string",
                                        "description": "The response content",
                                    }
                                },
                                "required": ["content"]
                            }),
                        },
                    }))
                    .collect::<Vec<_>>()
            }
            ToolChoice::NoTool => vec![],
Nicolas Patry's avatar
Nicolas Patry committed
56
57
        };

58
59
60
61
62
        // if no tools are provided or if the user has selected the no_tool option, return None
        if tools_to_use.is_empty() {
            return Ok(None);
        }

Nicolas Patry's avatar
Nicolas Patry committed
63
64
65
66
67
        let functions: HashMap<String, serde_json::Value> = tools_to_use
            .iter()
            .map(|tool| {
                let func = tool.function.clone();

drbh's avatar
drbh committed
68
                let mut params = Map::new();
Nicolas Patry's avatar
Nicolas Patry committed
69
70
71

                params.insert(
                    "description".to_string(),
drbh's avatar
drbh committed
72
                    Value::String(func.description.unwrap_or_default()),
Nicolas Patry's avatar
Nicolas Patry committed
73
74
                );

drbh's avatar
drbh committed
75
76
                let mut properties = Map::new();
                let mut required = vec![Value::String("_name".to_string())];
Nicolas Patry's avatar
Nicolas Patry committed
77
78
79
80
81
82
83
84
85

                properties.insert(
                    "_name".to_string(),
                    json!({
                        "type": "string",
                        "const": func.name.clone(),
                    }),
                );

drbh's avatar
drbh committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
                if let Value::Object(args) = func.arguments {
                    if let Some(Value::Object(props)) = args.get("properties") {
                        properties.extend(props.clone());
                    }
                    if let Some(Value::Array(reqs)) = args.get("required") {
                        required.extend(reqs.clone());
                    }
                    params.insert(
                        "additionalProperties".to_string(),
                        Value::Bool(
                            args.get("additionalProperties").and_then(|v| v.as_str())
                                == Some("true"),
                        ),
                    );
Nicolas Patry's avatar
Nicolas Patry committed
100
101
                }

drbh's avatar
drbh committed
102
103
104
                params.insert("properties".to_string(), Value::Object(properties));
                params.insert("required".to_string(), Value::Array(required));

Nicolas Patry's avatar
Nicolas Patry committed
105
106
107
108
                (func.name, Value::Object(params))
            })
            .collect();

drbh's avatar
drbh committed
109
        let tool_schema = JsonSchemaTool {
Nicolas Patry's avatar
Nicolas Patry committed
110
111
112
113
114
115
116
117
118
119
120
            functions_map: FunctionsMap { functions },
            properties: Properties {
                function: tools_to_use
                    .iter()
                    .map(|tool| FunctionRef {
                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                    })
                    .collect(),
            },
        };

121
        Ok(Some((tools_to_use, tool_schema)))
Nicolas Patry's avatar
Nicolas Patry committed
122
123
    }
}