// Package tools parses tool calls out of model output that is supplied
// incrementally.
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState identifies the stage a Parser has reached while consuming
// model output.
//
// NOTE: the constant names use underscores, which is not idiomatic Go, but
// they are referenced throughout the package and are kept as-is.
type toolsState int

const (
	// toolsState_LookingForTag is the initial state: the tool calling tag
	// has not been seen in the input yet.
	toolsState_LookingForTag toolsState = iota

	// toolsState_ToolCalling means the tag was found and buffered input is
	// being scanned for tool calls.
	toolsState_ToolCalling

	// toolsState_Done means parsing has finished; further input is handed
	// back to the caller unchanged.
	toolsState_Done
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

29
30
31
32
33
// NewParser builds a tool call parser for the given tools, deriving the
// tool calling tag from the model's chat template.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
34

35
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
36
37
38
	return &Parser{
		tag:   tag,
		tools: tools,
39
	}
40
}
41

42
43
44
45
46
// Add processes a string input to parse tool calls and content that
// should be sent back to the user.
//
// The input is appended to an internal buffer. Until the tool calling tag
// is found, everything that cannot be part of the tag is returned as
// content. After the tag is found, every tool call that can be completely
// parsed from the buffer is returned in calls. Once the parser is done,
// Add returns s unchanged as content.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			// neither the tag nor a partial tag is buffered:
			// everything is content
			content = string(p.buffer)
			p.buffer = []byte{}
		} else {
			// release what precedes the (possibly partial) tag and
			// keep the rest buffered
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// for models where { or [ are used as tool calling
		// tags, we only support parsing tools if the first non-
		// whitespace character is { or [
		if (p.tag == "{" || p.tag == "[") && strings.TrimSpace(content) != "" {
			p.state = toolsState_Done
			content += string(p.buffer)
			// the buffered bytes have just been handed to the caller;
			// drop them so that Content does not return them again
			p.buffer = []byte{}
			return nil, content
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	// drain every tool call that is already complete in the buffer
	for call := p.parseToolCall(); call != nil; call = p.parseToolCall() {
		calls = append(calls, *call)
	}

	if p.done() {
		// whatever is left in the buffer goes back to the user
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = []byte{}
	}

	return calls, content
}

96
97
98
99
100
101
102
// findTag searches the buffer for the tool calling tag.
//
// If the complete tag is present, it returns the index where the tag
// starts and true. If the buffer only ends with a prefix of the tag, it
// returns the index where that partial tag starts and false, so the caller
// can hold those bytes back until more input arrives. If neither is the
// case it returns -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// First check for the complete tag anywhere in the buffer
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// Then check for a partial tag at the end of the buffer, longest
	// candidate first
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
114

115
116
117
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
118
	tool, end := findTool(p.tools, p.buffer)
119
	if tool == nil {
120
		return nil
121
122
	}

123
124
125
	// only look for arguments after the tool name if the tool has parameters
	// TODO (jmorganca): while probably uncommon, this doesn't support
	// parsing arguments before the tool name, which may be needed in the future
126
	args := map[string]any{}
127
	if len(tool.Function.Parameters.Properties) > 0 {
128
		var i int
129
		if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
130
131
			return nil
		}
132
		end += i
133
134
	}

135
136
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
137
			Name:      tool.Function.Name,
138
139
140
			Arguments: args,
			Index:     p.n,
		},
141
142
	}

143
144
145
146
147
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// findTool looks for a tool name inside buf and reports the matching tool
// together with the offset just past the name.
//
// When the end of buf could be the beginning of a tool name that has not
// fully arrived yet, nothing is reported so the caller can wait for more
// data; this prevents matching "get" when "get_weather" is still coming
// in. Among complete matches, the one that starts earliest in buf wins and
// ties go to the longer name. Without a match the result is (nil, 0).
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
	if len(buf) == 0 {
		return nil, 0
	}

	// a tail longer than the longest name cannot be a partial name
	var maxLen int
	for i := range tools {
		maxLen = max(maxLen, len(tools[i].Function.Name))
	}

	// bail out if any tail of buf is a proper prefix of some tool name
	for n := min(len(buf), maxLen); n > 0; n-- {
		tail := string(buf[len(buf)-n:])
		for i := range tools {
			name := tools[i].Function.Name
			if n < len(name) && name[:n] == tail {
				return nil, 0
			}
		}
	}

	// pick the earliest match, preferring the longer name on ties
	var best *api.Tool
	bestPos := -1
	for i := range tools {
		name := tools[i].Function.Name
		pos := bytes.Index(buf, []byte(name))
		if pos < 0 {
			continue
		}

		switch {
		case best == nil,
			pos < bestPos,
			pos == bestPos && len(name) > len(best.Function.Name):
			best, bestPos = &tools[i], pos
		}
	}

	if best == nil {
		return nil, 0
	}

	return best, bestPos + len(best.Function.Name)
}

213
// findArguments returns the first object that appears to be
214
// arguments for the provided tool in the provided buffer,
215
// returning nil if no arguments are found and the end position
216
217
218
219
220
221
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
222
223
224
		return nil, 0
	}

225
226
227
228
229
230
	var braces int
	var start int = -1
	var end int
	var object []byte

	// find any outer json object
231
	for i, c := range buffer {
232
233
234
235
236
		if c == '{' {
			braces++
			if start == -1 {
				start = i
			}
237
		}
238
239

		if c == '}' {
240
241
242
243
			if start != -1 {
				braces--
				if braces == 0 {
					end = i + 1
244
					object = buffer[start:end]
245
246
					break
				}
247
			}
248
		}
249
250
251
252
253
254
255
256
257
258
259
260
261
	}

	if braces > 0 {
		return nil, 0
	}

	var data map[string]any
	if err := json.Unmarshal(object, &data); err != nil {
		return nil, 0
	}

	var find func(obj any) map[string]any
	find = func(obj any) map[string]any {
262
		switch obj := obj.(type) {
263
		case map[string]any:
264
265
			valid := true
			// check if all keys in the object exist in the tool's parameters
266
267
			for key := range obj {
				if _, exists := tool.Function.Parameters.Properties[key]; !exists {
268
					valid = false
269
					break
270
271
272
				}
			}

273
274
275
276
277
278
279
280
281
282
283
284
			// check for required parameters
			// TODO (jmorganca): this should error instead of silently failing
			if valid {
				for _, required := range tool.Function.Parameters.Required {
					if _, exists := obj[required]; !exists {
						valid = false
						break
					}
				}
			}

			if valid {
285
286
287
288
				return obj
			}

			for _, value := range obj {
289
290
291
292
293
				if result := find(value); result != nil {
					return result
				}
			}
		case []any:
294
			for _, item := range obj {
295
296
297
298
				if result := find(item); result != nil {
					return result
				}
			}
299
		}
300
301

		return nil
302
303
	}

304
305
306
	result := find(data)
	if result != nil {
		return result, end
307
308
	}

309
	return nil, 0
310
311
}

312
313
314
315
316
317
318
319
320
321
322
323
324
// done checks if the parser is done parsing by looking
// for a closing tag. Currently only } and ] are supported
// as closing tags, since {} or [] pairs may not always
// represent tool calls and we need to send the content back.
//
// It reports true as soon as a closing character brings the count of
// opening minus closing characters in the buffer to zero. For any other
// tag it always reports false.
func (p *Parser) done() bool {
	var opening, closing byte
	switch p.tag {
	case "{":
		opening, closing = '{', '}'
	case "[":
		opening, closing = '[', ']'
	default:
		return false
	}

	var depth int
	for _, c := range p.buffer {
		switch c {
		case opening:
			depth++
		case closing:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining content that should be sent to the
// user. This is the empty string unless the tag is { or [ and no tool
// call was found, in which case it is whatever is still buffered.
func (p *Parser) Content() string {
	if p.n > 0 {
		return ""
	}

	switch p.tag {
	case "{", "[":
		return string(p.buffer)
	default:
		return ""
	}
}