tools.go 8 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState is the phase a Parser is in while consuming model output.
type toolsState int

// Parser phases, in the order they are normally visited.
const (
	// toolsState_LookingForTag: scanning output for the tool-calling tag.
	toolsState_LookingForTag toolsState = 0

	// toolsState_ToolCalling: the tag was seen; tool calls are being parsed.
	toolsState_ToolCalling toolsState = 1

	// toolsState_Done: parsing finished; later input is returned as content.
	toolsState_Done toolsState = 2
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

Michael Yang's avatar
Michael Yang committed
29
30
31
32
// GetBuffer returns the bytes the parser is currently holding back: input
// that has been received but neither emitted as content nor consumed by a
// tool call. The returned slice aliases the parser's internal buffer.
func (p *Parser) GetBuffer() []byte {
	held := p.buffer
	return held
}

33
34
35
36
37
// NewParser builds a tool call parser for the given tools. The tag that
// introduces a tool call is derived from the model's chat template via
// parseTag.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
38

39
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
40
41
42
	return &Parser{
		tag:   tag,
		tools: tools,
43
	}
44
}
45

46
47
48
49
50
// Add feeds the next chunk s of streamed model output to the parser. It
// returns the tool calls completed so far together with any text that
// should be passed through to the user.
//
// While looking for the tag, text before the tag, or before a trailing
// partial match of it, is returned as content. For "{" and "[" tags, any
// non-whitespace text before the tag means the output is not a tool call:
// the parser finishes and returns everything as content. Once parsing has
// finished, every later chunk is returned unchanged as content.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		idx, found := p.findTag()

		// Split off the text preceding the tag (or a trailing partial tag);
		// only the tag and whatever follows it stay buffered.
		if idx == -1 {
			content, p.buffer = string(p.buffer), []byte{}
		} else {
			content, p.buffer = string(p.buffer[:idx]), p.buffer[idx:]
		}

		// For models whose tag is a bare { or [, tool calls are only parsed
		// when that character is the first non-whitespace output.
		bracketTag := p.tag == "{" || p.tag == "["
		if bracketTag && strings.TrimSpace(content) != "" {
			p.state = toolsState_Done
			return nil, content + string(p.buffer)
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	// Pull out every complete tool call currently in the buffer.
	for call := p.parseToolCall(); call != nil; call = p.parseToolCall() {
		calls = append(calls, *call)
	}

	// Once done reports a complete bracket group, parsing ends and whatever
	// is still buffered is returned as content.
	if p.done() {
		p.state = toolsState_Done
		content, p.buffer = string(p.buffer), []byte{}
	}

	return calls, content
}

100
101
102
103
104
105
106
// findTag looks for the parser's tag in the buffer.
//
// If the complete tag occurs anywhere, it returns the index of its first
// occurrence and true. Otherwise, if the buffer ends with a non-empty
// prefix of the tag (the tag may be split across chunks), it returns the
// index where the longest such prefix starts and false. If neither
// applies, it returns -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// A complete tag anywhere in the buffer wins.
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// Otherwise report the longest tag prefix the buffer ends with, so the
	// caller can keep it buffered until more input arrives.
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
118

119
120
121
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
122
	tool, end := findTool(p.tools, p.buffer)
123
	if tool == nil {
124
		return nil
125
126
	}

127
	var argsMap map[string]any
128
	if found, i := findArguments(tool, p.buffer); found == nil {
129
130
		return nil
	} else {
131
		argsMap = found
132
133
		if i > end {
			end = i
134
		}
135
136
	}

137
138
139
140
141
	args := api.NewToolCallFunctionArguments()
	for k, v := range argsMap {
		args.Set(k, v)
	}

142
143
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
144
			Name:      tool.Function.Name,
145
146
147
			Arguments: args,
			Index:     p.n,
		},
148
149
	}

150
151
152
153
154
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
// findTool returns the tool whose name occurs earliest anywhere in buf,
// preferring the longest name when several start at the same position,
// together with the index just past that name.
//
// It returns nil, 0 when:
//   - buf is empty,
//   - no tool name occurs in buf, or
//   - buf ends with a proper prefix of some tool name, in which case more
//     data is needed to disambiguate (e.g. "get" versus "get_weather").
//
// Tools with an empty name are ignored, since an empty name would "match"
// at index 0 of every buffer.
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
	if len(buf) == 0 {
		return nil, 0
	}

	// Wait for more data if buf ends with a proper prefix of any tool name.
	for i := range tools {
		name := []byte(tools[i].Function.Name)
		for n := min(len(buf), len(name)-1); n > 0; n-- {
			if bytes.HasSuffix(buf, name[:n]) {
				return nil, 0
			}
		}
	}

	// Pick the earliest occurrence; break ties by the longer name.
	var found *api.Tool
	start, end := -1, 0
	for i := range tools {
		name := []byte(tools[i].Function.Name)
		if len(name) == 0 {
			continue
		}

		pos := bytes.Index(buf, name)
		if pos == -1 {
			continue
		}

		better := found == nil || pos < start ||
			(pos == start && len(name) > len(found.Function.Name))
		if !better {
			continue
		}

		found = &tools[i]
		start = pos
		end = pos + len(name)
	}

	if found == nil {
		return nil, 0
	}
	return found, end
}

220
// findArguments returns the first object that appears to be
221
// arguments for the provided tool in the provided buffer,
222
// returning nil if no arguments are found and the end position
223
224
225
226
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
227
func findArguments(tool *api.Tool, buffer []byte) (map[string]any, int) {
228
	if len(buffer) == 0 {
229
230
231
		return nil, 0
	}

232
	start := -1
233
	var braces int
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
	var inString, escaped bool

	for i := range buffer {
		c := buffer[i]

		if escaped {
			escaped = false
			continue
		}

		if c == '\\' {
			escaped = true
			continue
		}

		if c == '"' {
			inString = !inString
			continue
		}

		if inString {
			continue
		}
257
258

		if c == '{' {
259
			if braces == 0 {
260
261
				start = i
			}
262
			braces++
263
		} else if c == '}' {
264
265
266
267
268
269
			braces--
			if braces == 0 && start != -1 {
				object := buffer[start : i+1]

				var data map[string]any
				if err := json.Unmarshal(object, &data); err != nil {
270
					// not a valid object, keep looking
271
272
					start = -1
					continue
273
				}
274

275
276
				var findObject func(obj map[string]any) (map[string]any, bool)
				findObject = func(obj map[string]any) (map[string]any, bool) {
277
278
					findMap := func(name string, obj map[string]any) (map[string]any, bool) {
						if args, ok := obj[name].(map[string]any); ok {
279
280
							return args, true
						}
281
						if argsStr, ok := obj[name].(string); ok {
282
283
284
285
286
							var argsData map[string]interface{}
							if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
								return argsData, ok
							}
						}
287
288
289
290
						return nil, false
					}
					if _, hasName := obj["name"]; hasName {
						if args, ok := findMap("arguments", obj); ok {
291
292
							return args, true
						}
293
294
						if args, ok := findMap("parameters", obj); ok {
							return args, true
295
						}
296
297
						return nil, true
					}
298
299
300
					if args, ok := findMap(tool.Function.Name, obj); ok {
						return args, true
					}
301

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
					for _, v := range obj {
						switch child := v.(type) {
						case map[string]any:
							if result, found := findObject(child); found {
								return result, true
							}
						case []any:
							for _, item := range child {
								if childObj, ok := item.(map[string]any); ok {
									if result, found := findObject(childObj); found {
										return result, true
									}
								}
							}
						}
					}
318

319
					return nil, false
320
321
				}

322
323
				if args, found := findObject(data); found {
					return args, i
324
325
				}

326
				return data, i
327
			}
328
329
330
331

			if braces < 0 {
				braces = 0
			}
332
333
334
		}
	}

335
	return nil, 0
336
337
}

338
339
340
341
342
343
344
345
346
347
348
349
350
// done reports whether parsing should stop. This applies only to parsers
// whose tag is "{" or "[".
//
// Starting from zero, it counts the tag's opening bracket up and the
// matching closing bracket down across the buffer. It returns true as soon
// as a closing bracket brings the count back to zero. For any other tag it
// returns false. Add uses this to end parsing and return the remaining
// buffer as content, because a balanced {} or [] pair is not necessarily a
// tool call.
//
// Brackets inside JSON string literals (honoring backslash escapes) are
// ignored. Otherwise a partially streamed argument value such as "a]"
// would balance the count early and turn a tool call into plain content.
func (p *Parser) done() bool {
	var open, closing byte
	switch p.tag {
	case "{":
		open, closing = '{', '}'
	case "[":
		open, closing = '[', ']'
	default:
		return false
	}

	var count int
	var inString, escaped bool
	for _, c := range p.buffer {
		switch {
		case escaped:
			escaped = false
		case inString && c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// Brackets inside string literals are not structural.
		case c == open:
			count++
		case c == closing:
			count--
			if count == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining buffered content that should be sent to
// the user.
//
// It is the empty string unless no tool call has been parsed and the tag
// is "{" or "[". In that case the buffered text never became a tool call
// and is returned as plain content.
func (p *Parser) Content() string {
	if p.n == 0 && (p.tag == "{" || p.tag == "[") {
		return string(p.buffer)
	}
	return ""
}