tools.go 7.41 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState is the phase a Parser is currently in.
type toolsState int

// Possible Parser phases. The zero value, toolsState_LookingForTag, is the
// initial phase of a freshly constructed Parser.
const (
	toolsState_LookingForTag toolsState = 0 // scanning output for the tool calling tag
	toolsState_ToolCalling   toolsState = 1 // tag seen; extracting tool calls
	toolsState_Done          toolsState = 2 // finished; input is passed through
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

Michael Yang's avatar
Michael Yang committed
29
30
31
32
// GetBuffer returns the input the parser is still holding on to. The slice
// is the parser's own buffer (not a copy), so callers should not modify it.
func (p *Parser) GetBuffer() []byte { return p.buffer }

33
34
35
36
37
// NewParser builds a tool call parser for the given tools, deriving the
// tool calling tag from the model's chat template.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
38

39
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
40
41
42
	return &Parser{
		tag:   tag,
		tools: tools,
43
	}
44
}
45

46
47
48
49
50
// Add feeds the next chunk s of model output into the parser. It returns
// the tool calls that could be completed with the data seen so far, plus any
// text that should be forwarded to the user. Text that may belong to a tool
// call is held back in the buffer; once the parser is done, every later
// chunk is returned unchanged.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		idx, found := p.findTag()
		switch idx {
		case -1:
			// Neither the tag nor a prefix of it is present: flush it all.
			content = string(p.buffer)
			p.buffer = []byte{}
		default:
			// Emit what precedes the (possibly partial) tag, keep the rest.
			content = string(p.buffer[:idx])
			p.buffer = p.buffer[idx:]
		}

		// With { or [ as the tag, a tool call is only recognized when it is
		// the first non-whitespace output; anything else is plain content.
		jsonTag := p.tag == "{" || p.tag == "["
		if jsonTag && strings.TrimSpace(content) != "" {
			p.state = toolsState_Done
			return nil, content + string(p.buffer)
		}

		if !found {
			return nil, content
		}
		p.state = toolsState_ToolCalling
	}

	for call := p.parseToolCall(); call != nil; call = p.parseToolCall() {
		calls = append(calls, *call)
	}

	if p.done() {
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = []byte{}
	}

	return calls, content
}

100
101
102
103
104
105
106
// findTag locates the tool calling tag in the buffer.
//
// If the complete tag occurs anywhere in the buffer, it returns the index
// where the tag starts and true. Otherwise, if the buffer ends with a proper
// prefix of the tag (the tag may be split across chunks), it returns the
// index where that prefix starts and false. If neither applies, it returns
// -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	if i := bytes.Index(p.buffer, tag); i >= 0 {
		return i, true
	}

	// The whole tag was ruled out above, so only proper prefixes (at most
	// len(tag)-1 bytes) can still be sitting at the end of the buffer.
	// Prefer the longest one so no part of the tag is flushed as content.
	for n := min(len(p.buffer), len(tag)-1); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
118

119
120
121
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
122
	tool, end := findTool(p.tools, p.buffer)
123
	if tool == nil {
124
		return nil
125
126
	}

127
128
129
130
131
132
133
	var args map[string]any
	if found, i := findArguments(p.buffer); found == nil {
		return nil
	} else {
		args = found
		if i > end {
			end = i
134
		}
135
136
	}

137
138
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
139
			Name:      tool.Function.Name,
140
141
142
			Arguments: args,
			Index:     p.n,
		},
143
144
	}

145
146
147
148
149
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
// findTool finds the first tool name in the list that matches the
// beginning of the buffer, returning nil if no tool is found
// or if the buffer ends with a partial tool name since we need
// to wait for more data to disambiguate.
// The second return value is the end position of the tool name
// if one is found, otherwise 0.
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
	if len(buf) == 0 {
		return nil, 0
	}

	// check if buffer ends with a partial tool name
	// this prevents matching "get" when seeing "get_weather"
	var longest string
	for _, t := range tools {
		if len(t.Function.Name) > len(longest) {
			longest = t.Function.Name
		}
	}

	// Only check up to longest characters from the end
	for i := 1; i <= min(len(buf), len(longest)); i++ {
		tail := buf[len(buf)-i:]
		for _, t := range tools {
			name := []byte(t.Function.Name)
			if len(tail) < len(name) && bytes.HasPrefix(name, tail) {
				return nil, 0
			}
		}
	}

	// find first occurrence of the longest tool name
	var found *api.Tool
	start := -1
	end := -1

	for i := range tools {
		name := []byte(tools[i].Function.Name)
		pos := bytes.Index(buf, name)
		if pos == -1 {
			continue
		}

		// Skip if we have a better match already
		if start != -1 {
			if pos > start {
				continue
			}
			if pos == start && len(name) <= len(found.Function.Name) {
				continue
			}
		}

		found = &tools[i]
		start = pos
		end = pos + len(name)
	}

	if found != nil {
		return found, end
	}

	return nil, 0
}

215
// findArguments scans buffer for the first complete JSON object that
// decodes successfully and derives tool call arguments from it:
//
//   - if the object, or an object nested inside it (as a value or as an
//     array element), has a "name" key, that object's "arguments" value is
//     returned, falling back to "parameters"; if neither is a JSON object
//     the map result is nil. When several nested objects qualify, which one
//     is used depends on Go's map iteration order.
//   - otherwise the whole decoded object is returned as the arguments.
//
// The int result is the index in buffer of the object's closing '}'
// (inclusive, not one past it). When no complete object is found, it returns
// (nil, 0). Braces inside string literals are ignored, a backslash skips the
// byte after it, unmatched closing braces are skipped, and balanced spans
// that fail to decode are passed over.
//
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
		return nil, 0
	}

	// findObject looks for an object carrying a "name" key in obj or its
	// descendants. found reports whether one exists; args is its
	// "arguments" or "parameters" value when that value is an object.
	// Defined once here rather than for every decoded candidate object.
	var findObject func(obj map[string]any) (args map[string]any, found bool)
	findObject = func(obj map[string]any) (map[string]any, bool) {
		if _, hasName := obj["name"]; hasName {
			if args, ok := obj["arguments"].(map[string]any); ok {
				return args, true
			}
			if args, ok := obj["parameters"].(map[string]any); ok {
				return args, true
			}
			return nil, true
		}

		for _, v := range obj {
			switch child := v.(type) {
			case map[string]any:
				if result, found := findObject(child); found {
					return result, true
				}
			case []any:
				for _, item := range child {
					if childObj, ok := item.(map[string]any); ok {
						if result, found := findObject(childObj); found {
							return result, true
						}
					}
				}
			}
		}

		return nil, false
	}

	start := -1
	var braces int
	var inString, escaped bool

	for i, c := range buffer {
		if escaped {
			escaped = false
			continue
		}

		switch {
		case c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// braces inside string literals are not structural
		case c == '{':
			if braces == 0 {
				start = i
			}
			braces++
		case c == '}':
			braces--
			if braces == 0 && start != -1 {
				var data map[string]any
				if err := json.Unmarshal(buffer[start:i+1], &data); err != nil {
					// not a valid object, keep looking
					start = -1
					continue
				}

				if args, found := findObject(data); found {
					return args, i
				}
				return data, i
			}

			// tolerate unmatched closing braces that precede an object
			if braces < 0 {
				braces = 0
			}
		}
	}

	return nil, 0
}

318
319
320
321
322
323
324
325
326
327
328
329
330
// done reports whether the buffer contains a closing bracket that balances
// the opening bracket of a { or [ tag. Only these two tags are supported,
// because a {} or [] pair may turn out not to be a tool call at all, in
// which case its content needs to go back to the user. For any other tag
// done always returns false.
//
// Brackets inside JSON string literals are ignored, using the same quote
// and backslash rules as findArguments. Without this, a value such as "}}"
// in a partially streamed tool call would balance the count early and end
// parsing before the call is complete.
func (p *Parser) done() bool {
	var open, closing byte
	switch p.tag {
	case "{":
		open, closing = '{', '}'
	case "[":
		open, closing = '[', ']'
	default:
		return false
	}

	depth := 0
	inString, escaped := false, false
	for _, c := range p.buffer {
		switch {
		case escaped:
			escaped = false
		case c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// brackets inside string literals are not structural
		case c == open:
			depth++
		case c == closing:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns buffered output that should still be sent to the user.
// This is the remaining buffer only when the tag is { or [ and no tool
// call has been parsed; in every other case it is the empty string.
func (p *Parser) Content() string {
	jsonTag := p.tag == "{" || p.tag == "["
	if p.n == 0 && jsonTag {
		return string(p.buffer)
	}
	return ""
}