tools.go 6.19 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState tracks where a Parser is in its scan of model output. The
// zero value is toolsState_LookingForTag, which is where a new Parser starts.
type toolsState int

const (
	toolsState_LookingForTag toolsState = 0 // scanning input for the tool calling tag
	toolsState_ToolCalling   toolsState = 1 // tag seen; extracting tool calls from the buffer
	toolsState_Done          toolsState = 2 // finished; further input is passed straight through
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

29
30
31
32
33
// NewParser returns a tool call parser for the given tools. The tag that
// marks the start of tool calling is derived from the model's chat
// template by parseTag.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
34

35
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
36
37
38
	return &Parser{
		tag:   tag,
		tools: tools,
39
	}
40
}
41

42
43
44
45
46
// Add processes a chunk of model output and returns any tool calls
// completed by it, together with content that should be sent back to
// the user.
//
// While looking for the tag, text before the tag (or before a possible
// partial tag at the end of the buffer) is returned as content. For the
// "{" and "[" tags, any non-whitespace text before the tag means the
// output is treated as plain content and parsing stops. After tool
// calling starts, complete tool calls are returned as they are parsed;
// for "{" and "[" tags, once done reports a balanced closing, the rest
// of the buffer is returned as content and parsing stops. Once the
// parser is done, input is returned unchanged.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			content = string(p.buffer)
			p.buffer = nil
		} else {
			// hold back everything from the (possibly partial) tag onward
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// for models where { or [ are used as tool calling
		// tags, we only support parsing tools if the first non-
		// whitespace character is { or [
		if p.tag == "{" || p.tag == "[" {
			if strings.TrimSpace(content) != "" {
				p.state = toolsState_Done
				content += string(p.buffer)
				// the held-back bytes are returned here, so drop them;
				// otherwise Content would hand them to the user again
				p.buffer = nil
				return nil, content
			}
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	for {
		call := p.parseToolCall()
		if call == nil {
			break
		}

		calls = append(calls, *call)
	}

	if p.done() {
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = nil
	}

	return calls, content
}

96
97
98
99
100
101
102
// findTag searches the buffer for the tool calling tag.
//
// If the complete tag is present, it returns the index where the tag
// starts and true. Otherwise, if the buffer ends with a prefix of the tag
// (which later input may complete), it returns the index where the
// longest such prefix starts and false. If neither is found it returns
// -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// First check for the complete tag anywhere in the buffer
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// Then check for the longest tag prefix overlapping the buffer's end
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
114

115
116
117
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
118
	var tool *api.Tool
119
120
	var end int = len(p.buffer)
	var i int
121

122
123
124
	// find tool name
	for _, t := range p.tools {
		n := t.Function.Name
125
126
		if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
			if i+len(n) < end {
127
				tool = &t
128
129
130
				end = i + len(n)
			}
		}
131
132
	}

133
	if tool == nil {
134
		return nil
135
136
	}

137
138
139
	// only look for arguments after the tool name if the tool has parameters
	// TODO (jmorganca): while probably uncommon, this doesn't support
	// parsing arguments before the tool name, which may be needed in the future
140
	args := map[string]any{}
141
	if len(tool.Function.Parameters.Properties) > 0 {
142
		if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
143
144
			return nil
		}
145

146
		end += i
147
148
	}

149
150
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
151
			Name:      tool.Function.Name,
152
153
154
			Arguments: args,
			Index:     p.n,
		},
155
156
	}

157
158
159
160
161
162
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

// findArguments returns the first object that appears to be
163
164
165
166
167
168
169
170
// arguments for the provided tool in the provided buffer,
// returning nil if no arguments are found.
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
171
172
173
		return nil, 0
	}

174
175
176
177
178
179
	var braces int
	var start int = -1
	var end int
	var object []byte

	// find any outer json object
180
	for i, c := range buffer {
181
182
183
184
185
		if c == '{' {
			braces++
			if start == -1 {
				start = i
			}
186
		}
187
188

		if c == '}' {
189
190
191
192
			if start != -1 {
				braces--
				if braces == 0 {
					end = i + 1
193
					object = buffer[start:end]
194
195
					break
				}
196
			}
197
		}
198
199
200
201
202
203
204
205
206
207
208
209
210
	}

	if braces > 0 {
		return nil, 0
	}

	var data map[string]any
	if err := json.Unmarshal(object, &data); err != nil {
		return nil, 0
	}

	var find func(obj any) map[string]any
	find = func(obj any) map[string]any {
211
		switch obj := obj.(type) {
212
		case map[string]any:
213
214
			valid := true
			// check if all keys in the object exist in the tool's parameters
215
216
			for key := range obj {
				if _, exists := tool.Function.Parameters.Properties[key]; !exists {
217
					valid = false
218
					break
219
220
221
				}
			}

222
223
224
225
226
227
228
229
230
231
232
233
			// check for required parameters
			// TODO (jmorganca): this should error instead of silently failing
			if valid {
				for _, required := range tool.Function.Parameters.Required {
					if _, exists := obj[required]; !exists {
						valid = false
						break
					}
				}
			}

			if valid {
234
235
236
237
				return obj
			}

			for _, value := range obj {
238
239
240
241
242
				if result := find(value); result != nil {
					return result
				}
			}
		case []any:
243
			for _, item := range obj {
244
245
246
247
				if result := find(item); result != nil {
					return result
				}
			}
248
		}
249
250

		return nil
251
252
	}

253
254
255
	result := find(data)
	if result != nil {
		return result, end
256
257
	}

258
	return nil, 0
259
260
}

261
262
263
264
265
266
267
268
269
270
271
272
273
// done reports whether the buffer contains a closing delimiter that
// balances the tool calling tag. Only the "{" and "[" tags are handled,
// because {} or [] pairs may not always represent tool calls and the
// content may need to be sent back to the user; any other tag yields
// false.
//
// The buffer is scanned left to right. Each opening delimiter raises a
// depth counter and each closing delimiter lowers it. done returns true
// as soon as a closing delimiter brings the depth to exactly zero.
func (p *Parser) done() bool {
	var opener, closer byte
	switch p.tag {
	case "{":
		opener, closer = '{', '}'
	case "[":
		opener, closer = '[', ']'
	default:
		return false
	}

	depth := 0
	for _, b := range p.buffer {
		switch b {
		case opener:
			depth++
		case closer:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining buffered content that should be sent to
// the user:
//   - while still looking for the tag, the buffer only holds text that was
//     withheld because it might begin a multi-character tag; it is plain
//     content, so it is returned;
//   - once a tool call has been parsed, it returns the empty string;
//   - otherwise, for the "{" and "[" tags, the unparsed buffer is returned
//     because it did not turn out to be a tool call.
func (p *Parser) Content() string {
	if p.state == toolsState_LookingForTag {
		return string(p.buffer)
	}

	if p.n > 0 {
		return ""
	}

	if p.tag == "{" || p.tag == "[" {
		return string(p.buffer)
	}

	return ""
}