tools_utils.go 5.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package tools

import (
	"bytes"
	"encoding/json"
	"errors"
	"log/slog"
	"slices"
	"strings"
	gotmpl "text/template"
	"text/template/parse"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

// extractToolCallsFormat searches a parsed template for the first {{ if ... }}
// action whose condition references ToolCalls (as decided by isToolCallsNode)
// and returns the literal text that opens that branch.
//
// The search is depth-first in document order. It descends into the bodies and
// else-branches of if, range and with actions, and into nested lists. Only the
// run of text nodes at the very start of the matching branch is returned.
// Collection stops at the first action or other non-text node.
//
// Returns:
//   - string: The leading text of the first matching branch ("" if the branch starts with an action)
//   - bool: Whether a matching {{ if }} was found at all
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
	if tmpl == nil || tmpl.Tree == nil {
		slog.Debug("template or tree is nil")
		return "", false
	}

	// leadingText concatenates the text nodes that open list.
	leadingText := func(list *parse.ListNode) string {
		var sb strings.Builder
		for _, child := range list.Nodes {
			text, isText := child.(*parse.TextNode)
			if !isText {
				break
			}
			sb.Write(text.Text)
		}
		return sb.String()
	}

	// branches lists what to descend into for a branching action: its body,
	// followed by its else-branch when one exists.
	branches := func(b *parse.BranchNode) []*parse.ListNode {
		if b.ElseList == nil {
			return []*parse.ListNode{b.List}
		}
		return []*parse.ListNode{b.List, b.ElseList}
	}

	var search func(nodes []parse.Node) (string, bool)
	search = func(nodes []parse.Node) (string, bool) {
		for _, node := range nodes {
			var children []*parse.ListNode
			switch n := node.(type) {
			case *parse.IfNode:
				if isToolCallsNode(n) {
					return leadingText(n.List), true
				}
				children = branches(&n.BranchNode)
			case *parse.RangeNode:
				children = branches(&n.BranchNode)
			case *parse.WithNode:
				children = branches(&n.BranchNode)
			case *parse.ListNode:
				children = []*parse.ListNode{n}
			}
			for _, list := range children {
				if text, ok := search(list.Nodes); ok {
					return text, true
				}
			}
		}
		return "", false
	}

	return search(tmpl.Tree.Root.Nodes)
}

// isToolCallsNode reports whether the condition of the if-action n references
// a ToolCalls field anywhere in its pipeline. All of these match:
//
//	{{ if .ToolCalls }}            {{ if .Message.ToolCalls }}
//	{{ if $.ToolCalls }}           {{ if $msg.ToolCalls }}
//	{{ if not .ToolCalls }}        {{ if and .Content (.ToolCalls) }}
//	{{ if (.Message).ToolCalls }}
//
// Field references (.X.ToolCalls), variable references ($.ToolCalls, $v.ToolCalls),
// field chains on parenthesized expressions, and nested parenthesized pipelines
// are all inspected.
func isToolCallsNode(n *parse.IfNode) bool {
	var mentions func(node parse.Node) bool
	mentions = func(node parse.Node) bool {
		switch x := node.(type) {
		case *parse.FieldNode:
			return slices.Contains(x.Ident, "ToolCalls")
		case *parse.VariableNode:
			// Ident[0] is the variable itself ("$", "$msg"), which always starts
			// with '$' and so can never equal "ToolCalls"; only trailing field
			// names can match.
			return slices.Contains(x.Ident, "ToolCalls")
		case *parse.ChainNode:
			// e.g. (.Message).ToolCalls: the fields follow a parenthesized term.
			return slices.Contains(x.Field, "ToolCalls") || mentions(x.Node)
		case *parse.PipeNode:
			if x == nil {
				return false
			}
			for _, cmd := range x.Cmds {
				for _, arg := range cmd.Args {
					if mentions(arg) {
						return true
					}
				}
			}
		}
		return false
	}
	return mentions(n.Pipe)
}

// toolPrefix returns the text found by extractToolCallsFormat, normalized to a
// single line. Surrounding whitespace is trimmed, carriage returns are dropped
// and newlines become spaces. It returns "" when extractToolCallsFormat reports
// no match.
func toolPrefix(tmpl *gotmpl.Template) string {
	raw, found := extractToolCallsFormat(tmpl)
	if !found {
		return ""
	}

	// Trim first, then flatten what remains in a single pass.
	flatten := strings.NewReplacer("\r", "", "\n", " ")
	return flatten.Replace(strings.TrimSpace(raw))
}

// toolTemplate extracts, via t.Subtree, the part of t selected by a predicate.
// The predicate matches a {{ range }} action whose pipeline identifiers
// (as reported by template.Identifiers) include ToolCalls.
//
// Returns:
//   - *gotmpl.Template: The subtree returned by t.Subtree for that range
//   - error: Non-nil when t.Subtree finds no matching node (returns nil)
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
	rangesOverToolCalls := func(n parse.Node) bool {
		rn, isRange := n.(*parse.RangeNode)
		return isRange && slices.Contains(template.Identifiers(rn.Pipe), "ToolCalls")
	}

	if sub := t.Subtree(rangesOverToolCalls); sub != nil {
		return sub, nil
	}
	return nil, errors.New("failed to find tool template")
}

// suffixOverlap reports where, in s, the longest suffix of s that is also a
// prefix of prefix begins. For example, suffixOverlap("text <too", "<tool_call>")
// returns 5, the index of "<too". Comparison is byte-wise.
//
// Returns:
//   - int: The starting index in s of the longest such overlap, or -1 if no
//     non-empty suffix of s is a prefix of prefix (always -1 when either
//     string is empty)
func suffixOverlap(s, prefix string) int {
	// Try the longest possible overlap first so the first hit is the longest.
	for n := min(len(prefix), len(s)); n > 0; n-- {
		if strings.HasSuffix(s, prefix[:n]) {
			return len(s) - n
		}
	}
	return -1
}

// extractToolArgs executes tmpl with a single synthetic tool call and inspects
// the rendered JSON to learn which keys the template uses for the tool name and
// the tool arguments.
//
// The synthetic call uses the sentinel name "@@name@@" and a single argument
// "@@argument@@". The keys are then identified by the values rendered into them
// rather than by guessing at field names.
//
// Returns:
//   - string: The JSON key whose value is the tool name
//   - string: The JSON key whose value is the arguments object
//   - error: Error if execution fails, no JSON object is rendered, the object
//     does not parse, or either key cannot be identified
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
	var b bytes.Buffer
	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
		"ToolCalls": {
			{
				Function: api.ToolCallFunction{
					Name: "@@name@@",
					Arguments: api.ToolCallFunctionArguments{
						"@@argument@@": 1,
					},
				},
			},
		},
	}); err != nil {
		return "", "", err
	}

	// Take the span from the first '{' to the last '}'. Only one tool call is
	// rendered, so even a template that wraps calls in a JSON array yields a
	// single object between these braces.
	output := b.String()
	start := strings.Index(output, "{")
	end := strings.LastIndex(output, "}")
	if start == -1 || end == -1 || start > end {
		return "", "", errors.New("no valid JSON object found in template output")
	}
	jsonStr := output[start : end+1]

	var obj map[string]any
	if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
		return "", "", err
	}

	// Identify fields by the sentinel values rendered into them. Map iteration
	// order is random, so the arguments key is taken from the object holding
	// the argument sentinel. Only if no object holds it do we fall back to the
	// smallest map-valued key, which keeps the result deterministic.
	var otherObjects []string
	for k, v := range obj {
		switch val := v.(type) {
		case string:
			if val == "@@name@@" {
				name = k
			}
		case map[string]any:
			if _, ok := val["@@argument@@"]; ok {
				arguments = k
			} else {
				otherObjects = append(otherObjects, k)
			}
		}
	}
	if arguments == "" && len(otherObjects) > 0 {
		arguments = slices.Min(otherObjects)
	}

	if name == "" || arguments == "" {
		slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
		return "", "", errors.New("missing required fields in tool call template")
	}

	return name, arguments, nil
}

// collect walks obj depth-first and gathers every map[string]any it contains,
// including obj itself when obj is such a map. Each map is listed before the
// maps nested inside it. Elements of []any are visited in slice order. Values
// inside a map are visited in Go's unspecified map iteration order. Values of
// any other type are ignored.
//
// Returns:
//   - []map[string]any: The maps found, or nil when there are none
func collect(obj any) []map[string]any {
	var found []map[string]any

	var visit func(v any)
	visit = func(v any) {
		switch node := v.(type) {
		case map[string]any:
			found = append(found, node)
			for _, child := range node {
				visit(child)
			}
		case []any:
			for _, child := range node {
				visit(child)
			}
		}
	}

	visit(obj)
	return found
}