package tools

import (
	"bytes"
	"encoding/json"
	"strings"
	"text/template"

	"github.com/ollama/ollama/api"
)

// toolsState records which phase of tool call extraction a Parser is in.
type toolsState int

const (
	// toolsState_LookingForTag means the tool calling tag has not been
	// found yet; text ahead of it is passed through as content.
	toolsState_LookingForTag toolsState = iota
	// toolsState_ToolCalling means the tag was found and the buffer is
	// being parsed for tool calls.
	toolsState_ToolCalling
	// toolsState_Done means parsing has stopped; Add returns its input
	// unchanged as content.
	toolsState_Done
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

// GetBuffer returns the parser's internal buffer. The slice is
// returned as-is, without copying.
func (p *Parser) GetBuffer() []byte {
	return p.buffer
}

// NewParser builds a tool call parser for the given tools, using the
// tool calling tag that parseTag derives from the model's chat
// template.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
40
41
42
	return &Parser{
		tag:   tag,
		tools: tools,
43
	}
44
}
// Add appends s to the parser's buffer and returns the tool calls that
// can now be parsed from it, along with any content that should be
// sent back to the user.
//
// Once the parser is done, s is returned unchanged as content and
// nothing is buffered.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			// neither the tag nor the start of one is buffered:
			// everything is content
			content = string(p.buffer)
			p.buffer = []byte{}
		} else {
			// text ahead of the (possibly partial) tag is content,
			// the rest stays buffered
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// for models where { or [ are used as tool calling
		// tags, we only support parsing tools if the first non-
		// whitespace character is { or [
		if p.tag == "{" || p.tag == "[" {
			if strings.TrimSpace(content) != "" {
				p.state = toolsState_Done
				content += string(p.buffer)
				// the buffered text has just been handed back to the
				// caller; drop it so that Content does not return
				// the same text a second time
				p.buffer = []byte{}
				return nil, content
			}
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	// collect every tool call that is complete in the buffer
	for {
		call := p.parseToolCall()
		if call == nil {
			break
		}

		calls = append(calls, *call)
	}

	// when the closing } or ] has been seen, whatever is left in the
	// buffer goes back to the user as content
	if p.done() {
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = []byte{}
	}

	return calls, content
}

// findTag searches the buffer for the tool calling tag.
//
// If the complete tag is present, it returns the index at which the
// tag starts and true. If the buffer only ends with the beginning of
// the tag, it returns the index at which that partial tag starts and
// false, so the caller can hold it back until more input arrives.
// Otherwise it returns -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// First check for the complete tag anywhere in the buffer
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// Then check whether the buffer ends with the start of the tag,
	// trying the longest possible overlap first
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
122
	tool, end := findTool(p.tools, p.buffer)
123
	if tool == nil {
124
		return nil
125
126
	}

127
128
129
130
131
132
133
	var args map[string]any
	if found, i := findArguments(p.buffer); found == nil {
		return nil
	} else {
		args = found
		if i > end {
			end = i
134
		}
135
136
	}

137
138
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
139
			Name:      tool.Function.Name,
140
141
142
			Arguments: args,
			Index:     p.n,
		},
143
144
	}

145
146
147
148
149
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

// findTool looks for a tool name in buf and returns the matching tool
// together with the position just past the name.
//
// The match that starts earliest in buf wins; when several names start
// at the same position the longest one wins. If buf ends with the
// beginning of any tool name, nothing is matched, since more data is
// needed to tell which name it is (e.g. "get" versus "get_weather").
// When there is no match the result is nil and 0.
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
	if len(buf) == 0 {
		return nil, 0
	}

	// a partial name at the end of buf can be no longer than the
	// longest tool name
	maxLen := 0
	for i := range tools {
		if l := len(tools[i].Function.Name); l > maxLen {
			maxLen = l
		}
	}

	// bail out if any suffix of buf is a proper prefix of a tool name
	limit := min(len(buf), maxLen)
	for n := 1; n <= limit; n++ {
		suffix := string(buf[len(buf)-n:])
		for i := range tools {
			name := tools[i].Function.Name
			if n < len(name) && name[:n] == suffix {
				return nil, 0
			}
		}
	}

	// pick the earliest match, preferring the longer name on a tie
	best, bestPos := -1, -1
	for i := range tools {
		name := tools[i].Function.Name
		pos := bytes.Index(buf, []byte(name))
		if pos == -1 {
			continue
		}

		switch {
		case best == -1,
			pos < bestPos,
			pos == bestPos && len(name) > len(tools[best].Function.Name):
			best, bestPos = i, pos
		}
	}

	if best == -1 {
		return nil, 0
	}

	return &tools[best], bestPos + len(tools[best].Function.Name)
}

// findArguments returns the arguments of the first tool call object
// found in buffer, together with the index of that object's closing
// brace. It returns nil and 0 if buffer holds no complete JSON object.
//
// The buffer is scanned for balanced top-level {...} spans and each
// one is decoded as JSON; spans that fail to decode are skipped. The
// first span that decodes is searched, including nested objects and
// objects inside arrays, for an object with a "name" key, and that
// object's "arguments" or "parameters" object is returned. If no
// object has a "name" key, the decoded object itself is returned.
//
// Braces that appear inside JSON strings are not counted, so argument
// values such as "a}b" do not break the matching. String state is only
// tracked inside an object, so stray quotes in surrounding text are
// ignored.
//
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
		return nil, 0
	}

	var (
		braces   int
		start    = -1
		inString bool
		escaped  bool
	)

	for i, c := range buffer {
		// inside a JSON string braces are data, not structure; skip
		// ahead until the closing, unescaped quote
		if inString {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}

		switch {
		case c == '"' && braces > 0:
			inString = true
		case c == '{':
			if braces == 0 {
				start = i
			}
			braces++
		case c == '}' && braces > 0:
			braces--
			if braces != 0 || start == -1 {
				continue
			}

			object := buffer[start : i+1]

			var data map[string]any
			if err := json.Unmarshal(object, &data); err != nil {
				// not valid JSON: keep scanning for the next object
				start = -1
				continue
			}

			// findObject looks through obj and everything nested in
			// it for an object with a "name" key and returns that
			// object's arguments; the bool reports whether such an
			// object was found at all.
			var findObject func(obj map[string]any) (map[string]any, bool)
			findObject = func(obj map[string]any) (map[string]any, bool) {
				if _, hasName := obj["name"]; hasName {
					if args, ok := obj["arguments"].(map[string]any); ok {
						return args, true
					}
					if args, ok := obj["parameters"].(map[string]any); ok {
						return args, true
					}
					return nil, true
				}

				for _, v := range obj {
					switch child := v.(type) {
					case map[string]any:
						if result, found := findObject(child); found {
							return result, true
						}
					case []any:
						for _, item := range child {
							if childObj, ok := item.(map[string]any); ok {
								if result, found := findObject(childObj); found {
									return result, true
								}
							}
						}
					}
				}

				return nil, false
			}

			if args, found := findObject(data); found {
				return args, i
			}

			// no "name" anywhere: treat the object itself as the
			// arguments
			return data, i
		}
	}

	return nil, 0
}

// done checks if the parser is done parsing by looking
// for the closing tag. Currently only } and ] are supported
// as closing tags, since {} or [] pairs may not always
// represent tool calls and we need to send the content back.
// For any other tag it always returns false.
//
// It returns true as soon as a closing character brings the count of
// unclosed opening characters in the buffer back to zero. A closing
// character that appears before any opening one drives the count
// negative, and it can only come back to zero after two opening
// characters in a row.
func (p *Parser) done() bool {
	var opening, closing byte
	switch p.tag {
	case "{":
		opening, closing = '{', '}'
	case "[":
		opening, closing = '[', ']'
	default:
		return false
	}

	var depth int
	for _, c := range p.buffer {
		switch c {
		case opening:
			depth++
		case closing:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining content that should be sent to the
// user. This is the empty string unless the tag is { or [ and no
// tool call has been parsed, in which case it is the current buffer.
func (p *Parser) Content() string {
	if p.n == 0 && (p.tag == "{" || p.tag == "[") {
		return string(p.buffer)
	}

	return ""
}