tools.go 7.84 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState is the stage a Parser has reached while consuming output.
type toolsState int

const (
	// toolsState_LookingForTag: the tool calling tag has not been seen yet,
	// so input is passed through to the user as content.
	toolsState_LookingForTag toolsState = iota
	// toolsState_ToolCalling: the tag was found and tool calls are parsed
	// from the buffered input.
	toolsState_ToolCalling
	// toolsState_Done: parsing has finished; all further input is returned
	// verbatim as content.
	toolsState_Done
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

Michael Yang's avatar
Michael Yang committed
29
30
31
32
// GetBuffer returns the input the parser is currently holding back: bytes
// passed to Add that have been neither returned as content nor consumed as
// part of a tool call. The returned slice is the parser's internal buffer,
// not a copy, so callers should not modify it.
func (p *Parser) GetBuffer() []byte {
	return p.buffer
}

33
34
35
36
37
// NewParser creates a tool call parser for the given tools. The tool calling
// tag is detected from the model's chat template with parseTag.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
38

39
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
40
41
42
	return &Parser{
		tag:   tag,
		tools: tools,
43
	}
44
}
45

46
47
48
49
50
// Add feeds the next chunk s of model output into the parser. It returns the
// tool calls completed by this chunk together with the text that should be
// sent back to the user. Bytes that may belong to a tag or to an in-progress
// tool call are held in the buffer until more input arrives.
//
// Once the parser is done, or once it has decided that a "{" or "[" tag is
// not starting a tool call, every later chunk is returned verbatim as content.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			// Neither a full nor a partial tag: everything is plain content.
			content = string(p.buffer)
			p.buffer = []byte{}
		} else {
			// Emit what precedes the (possibly partial) tag; keep the rest.
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// For models where { or [ are used as tool calling tags, we only
		// support parsing tools if the first non-whitespace character is
		// { or [.
		if p.tag == "{" || p.tag == "[" {
			if strings.TrimSpace(content) != "" {
				p.state = toolsState_Done
				return nil, content + string(p.buffer)
			}
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	for {
		call := p.parseToolCall()
		if call == nil {
			break
		}
		calls = append(calls, *call)
	}

	if p.done() {
		p.state = toolsState_Done
		// Append rather than overwrite: content may already hold the text
		// that preceded the tag in this same chunk, and it must not be lost.
		content += string(p.buffer)
		p.buffer = []byte{}
	}

	return calls, content
}

100
101
102
103
104
105
106
// findTag looks for p.tag in the buffer.
//
// If the complete tag is present anywhere, it returns the index where the
// first occurrence starts and true. If only a prefix of the tag appears at
// the very end of the buffer (the tag may be split across chunks), it returns
// the index where that partial tag starts and false, so the caller can hold
// those bytes back. Otherwise it returns -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// A complete tag anywhere in the buffer wins.
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// Otherwise check whether the buffer ends with a prefix of the tag,
	// trying the longest candidate prefix first.
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
118

119
120
121
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
122
	tool, end := findTool(p.tools, p.buffer)
123
	if tool == nil {
124
		return nil
125
126
	}

127
128
129
130
131
132
133
	var args map[string]any
	if found, i := findArguments(p.buffer); found == nil {
		return nil
	} else {
		args = found
		if i > end {
			end = i
134
		}
135
136
	}

137
138
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
139
			Name:      tool.Function.Name,
140
141
142
			Arguments: args,
			Index:     p.n,
		},
143
144
	}

145
146
147
148
149
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
// findTool locates the tool whose name occurs earliest in buf. When several
// names start at the same position, the longest one wins (so "get_weather"
// beats "get"). The name may appear anywhere in buf, not only at its start.
//
// It returns nil, 0 in three cases:
//   - buf is empty;
//   - no tool name occurs in buf;
//   - buf ends with a proper prefix of any tool name. That name may still be
//     streaming in, so the caller should wait for more data to disambiguate.
//
// On success the second result is the index just past the matched name.
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
	if len(buf) == 0 {
		return nil, 0
	}

	// Convert each name to bytes once; both passes below reuse the result.
	names := make([][]byte, len(tools))
	longest := 0
	for i := range tools {
		names[i] = []byte(tools[i].Function.Name)
		longest = max(longest, len(names[i]))
	}

	// If buf ends with a proper prefix of some tool name, the rest of that
	// name may arrive in the next chunk. This prevents matching "get" when
	// "get_weather" is still being streamed.
	for n := 1; n <= min(len(buf), longest); n++ {
		tail := buf[len(buf)-n:]
		for _, name := range names {
			if len(tail) < len(name) && bytes.HasPrefix(name, tail) {
				return nil, 0
			}
		}
	}

	// Pick the earliest occurrence; break ties in favor of the longer name.
	best, start := -1, -1
	for i, name := range names {
		pos := bytes.Index(buf, name)
		switch {
		case pos == -1:
			continue
		case best != -1 && pos > start:
			continue
		case best != -1 && pos == start && len(name) <= len(names[best]):
			continue
		}
		best, start = i, pos
	}

	if best == -1 {
		return nil, 0
	}
	return &tools[best], start + len(names[best])
}

215
// findArguments scans buffer for the first complete top-level JSON object that
// decodes successfully, and extracts tool-call arguments from it.
//
// Scanning rules:
//   - Brace matching skips over the contents of double-quoted strings and
//     honors backslash escapes, so braces inside string values do not affect
//     nesting.
//   - Candidate objects that fail to decode are skipped and scanning
//     continues.
//   - A stray '}' before any object is ignored.
//
// For the first object that decodes, the arguments come from
// argumentsFromObject: the "arguments" (or else "parameters") value of an
// object carrying a "name" key, searched in the object itself and then in
// nested objects and arrays. If no such object exists, the whole decoded
// object is returned as the arguments.
//
// The returned int is the index of that object's closing '}' (inclusive, not
// one past it). When no complete valid object is present yet, findArguments
// returns nil, 0. It returns nil with a non-zero index when the object has a
// "name" but no usable "arguments"/"parameters" value.
//
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
		return nil, 0
	}

	start := -1
	var braces int
	var inString, escaped bool

	for i, c := range buffer {
		if escaped {
			escaped = false
			continue
		}

		switch {
		case c == '\\':
			// NOTE(review): escapes are tracked outside strings too; a bare
			// backslash is not valid JSON there, so this rarely matters.
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// Braces inside string values do not count toward nesting.
		case c == '{':
			if braces == 0 {
				start = i
			}
			braces++
		case c == '}':
			braces--
			if braces == 0 && start != -1 {
				var data map[string]any
				if err := json.Unmarshal(buffer[start:i+1], &data); err != nil {
					// Not a valid object; keep looking for the next one.
					start = -1
					continue
				}

				if args, found := argumentsFromObject(data); found {
					return args, i
				}
				return data, i
			}

			// A '}' with no matching '{' (e.g. left over from a previous
			// tool call) must not push the depth negative.
			if braces < 0 {
				braces = 0
			}
		}
	}

	return nil, 0
}

// argumentsFromObject searches obj for a tool-call-shaped object, i.e. one
// with a "name" key: obj itself first, then its values recursively (nested
// objects and objects inside arrays). The bool reports whether such an object
// was found; the map holds its decoded arguments, or is nil when that object
// has no usable "arguments" or "parameters" value.
//
// NOTE(review): Go map iteration order is random, so if several nested values
// contain objects with a "name" key, which one wins is not deterministic.
func argumentsFromObject(obj map[string]any) (map[string]any, bool) {
	if _, hasName := obj["name"]; hasName {
		for _, key := range []string{"arguments", "parameters"} {
			if args, ok := decodeArguments(obj[key]); ok {
				return args, true
			}
		}
		return nil, true
	}

	for _, v := range obj {
		switch child := v.(type) {
		case map[string]any:
			if args, found := argumentsFromObject(child); found {
				return args, true
			}
		case []any:
			for _, item := range child {
				if childObj, ok := item.(map[string]any); ok {
					if args, found := argumentsFromObject(childObj); found {
						return args, true
					}
				}
			}
		}
	}

	return nil, false
}

// decodeArguments interprets v as a tool-call arguments value. v may be an
// already-decoded JSON object, or a string containing JSON-encoded arguments
// (some models emit arguments that way). ok is false when v is neither, or
// when the string does not decode into an object.
func decodeArguments(v any) (map[string]any, bool) {
	switch a := v.(type) {
	case map[string]any:
		return a, true
	case string:
		var args map[string]any
		if err := json.Unmarshal([]byte(a), &args); err == nil {
			return args, true
		}
	}
	return nil, false
}

330
331
332
333
334
335
336
337
338
339
340
341
342
// done reports whether the buffer contains a complete, balanced {...} or
// [...] group. When it does, Add stops looking for tool calls and returns the
// buffer as content. This only applies when the tag is "{" or "[", because
// those pairs do not always represent tool calls. For any other tag it
// returns false.
//
// Quoted strings (with backslash escapes) are skipped while counting. A
// bracket inside a JSON string value such as "a]b" therefore cannot close the
// group early and cause a partially streamed tool call to be flushed as
// content.
func (p *Parser) done() bool {
	var open, closer byte
	switch p.tag {
	case "{":
		open, closer = '{', '}'
	case "[":
		open, closer = '[', ']'
	default:
		return false
	}

	var count int
	var inString, escaped bool
	for _, c := range p.buffer {
		switch {
		case escaped:
			escaped = false
		case inString && c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case inString:
			// Brackets inside string values do not affect nesting.
		case c == open:
			count++
		case c == closer:
			count--
			if count == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining buffered input that should be sent to the
// user. It is the empty string unless the tag is "{" or "[" and no tool call
// has been parsed; in that case the buffered bytes were not a tool call and
// are returned as-is.
func (p *Parser) Content() string {
	toolCallsSeen := p.n > 0
	bracketTag := p.tag == "{" || p.tag == "["
	if bracketTag && !toolCallsSeen {
		return string(p.buffer)
	}
	return ""
}