// tool_grammar.rs
use crate::infer::InferError;
use crate::{
    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
    ToolType,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;

pub(crate) struct ToolGrammar {}

impl ToolGrammar {
    // find a tool by name
    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
        tools
            .iter()
            .find(|tool| tool.function.name == name)
            .cloned()
            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
    }

    pub fn apply(
drbh's avatar
drbh committed
22
        tools: Vec<Tool>,
Nicolas Patry's avatar
Nicolas Patry committed
23
        tool_choice: ToolChoice,
drbh's avatar
drbh committed
24
    ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
Nicolas Patry's avatar
Nicolas Patry committed
25
        // if no tools are provided, we return None
drbh's avatar
drbh committed
26
27
28
        if tools.is_empty() {
            return Ok((tools, None));
        }
Nicolas Patry's avatar
Nicolas Patry committed
29
30
31

        let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);

drbh's avatar
drbh committed
32
33
        let mut tools = tools.clone();

34
35
        // add the no_tool function to the tools
        let no_tool = Tool {
drbh's avatar
drbh committed
36
37
            r#type: "function".to_string(),
            function: FunctionDefinition {
38
39
                name: "no_tool".to_string(),
                description: Some("Open ened response with no specific tool selected".to_string()),
drbh's avatar
drbh committed
40
41
42
                arguments: json!({
                    "type": "object",
                    "properties": {
43
                        "content": {
drbh's avatar
drbh committed
44
                            "type": "string",
45
                            "description": "The response content",
drbh's avatar
drbh committed
46
47
                        }
                    },
48
                    "required": ["content"]
drbh's avatar
drbh committed
49
50
51
                }),
            },
        };
52
        tools.push(no_tool);
drbh's avatar
drbh committed
53

Nicolas Patry's avatar
Nicolas Patry committed
54
55
        // if tools are provided and no tool_choice we default to the OneOf
        let tools_to_use = match tool_choice {
56
            ToolType::Function(function) => {
Nicolas Patry's avatar
Nicolas Patry committed
57
58
                vec![Self::find_tool_by_name(&tools, &function.name)?]
            }
drbh's avatar
drbh committed
59
60
            ToolType::OneOf => tools.clone(),
            ToolType::NoTool => return Ok((tools, None)),
Nicolas Patry's avatar
Nicolas Patry committed
61
62
63
64
65
66
67
        };

        let functions: HashMap<String, serde_json::Value> = tools_to_use
            .iter()
            .map(|tool| {
                let func = tool.function.clone();

drbh's avatar
drbh committed
68
                let mut params = Map::new();
Nicolas Patry's avatar
Nicolas Patry committed
69
70
71

                params.insert(
                    "description".to_string(),
drbh's avatar
drbh committed
72
                    Value::String(func.description.unwrap_or_default()),
Nicolas Patry's avatar
Nicolas Patry committed
73
74
                );

drbh's avatar
drbh committed
75
76
                let mut properties = Map::new();
                let mut required = vec![Value::String("_name".to_string())];
Nicolas Patry's avatar
Nicolas Patry committed
77
78
79
80
81
82
83
84
85

                properties.insert(
                    "_name".to_string(),
                    json!({
                        "type": "string",
                        "const": func.name.clone(),
                    }),
                );

drbh's avatar
drbh committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
                if let Value::Object(args) = func.arguments {
                    if let Some(Value::Object(props)) = args.get("properties") {
                        properties.extend(props.clone());
                    }
                    if let Some(Value::Array(reqs)) = args.get("required") {
                        required.extend(reqs.clone());
                    }
                    params.insert(
                        "additionalProperties".to_string(),
                        Value::Bool(
                            args.get("additionalProperties").and_then(|v| v.as_str())
                                == Some("true"),
                        ),
                    );
Nicolas Patry's avatar
Nicolas Patry committed
100
101
                }

drbh's avatar
drbh committed
102
103
104
                params.insert("properties".to_string(), Value::Object(properties));
                params.insert("required".to_string(), Value::Array(required));

Nicolas Patry's avatar
Nicolas Patry committed
105
106
107
108
                (func.name, Value::Object(params))
            })
            .collect();

drbh's avatar
drbh committed
109
        let tool_schema = JsonSchemaTool {
Nicolas Patry's avatar
Nicolas Patry committed
110
111
112
113
114
115
116
117
118
119
120
            functions_map: FunctionsMap { functions },
            properties: Properties {
                function: tools_to_use
                    .iter()
                    .map(|tool| FunctionRef {
                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
                    })
                    .collect(),
            },
        };

drbh's avatar
drbh committed
121
        Ok((tools, Some(tool_schema)))
Nicolas Patry's avatar
Nicolas Patry committed
122
123
    }
}