tools.go 7.07 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState describes where a Parser is in the process of
// extracting tool calls from model output.
type toolsState int

const (
	// toolsState_LookingForTag is the initial (zero value) state: the
	// tool calling tag has not been seen in the input yet.
	toolsState_LookingForTag toolsState = iota
	// toolsState_ToolCalling means the tag was found and buffered
	// input is being parsed for tool calls.
	toolsState_ToolCalling
	// toolsState_Done means parsing has finished and any further
	// input is passed through unchanged as content.
	toolsState_Done
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

29
30
31
32
33
// NewParser builds a tool call parser for the provided tools, using
// the tool calling tag that parseTag derives from the model's chat
// template.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
34

35
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
36
37
38
	return &Parser{
		tag:   tag,
		tools: tools,
39
	}
40
}
41

42
43
44
45
46
// Add processes a string input to parse tool calls and content that
// should be sent back to the user.
//
// Input is appended to an internal buffer. While the tool calling
// tag has not been seen, everything before the tag (or before a
// possible partial tag at the end of the buffer) is returned as
// content. Once the tag is found, every complete tool call in the
// buffer is parsed and returned. After the parser reaches the done
// state, input is returned unchanged as content.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			// no tag and no partial tag: everything is content
			content = string(p.buffer)
			p.buffer = nil
		} else {
			// hold back the (possibly partial) tag, flush what precedes it
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// for models where { or [ are used as tool calling
		// tags, we only support parsing tools if the first non-
		// whitespace character is { or [
		if (p.tag == "{" || p.tag == "[") && strings.TrimSpace(content) != "" {
			p.state = toolsState_Done
			content += string(p.buffer)
			// the buffered text has just been handed back to the
			// caller; clear it so Content does not return it again
			p.buffer = nil
			return nil, content
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	// drain every complete tool call currently in the buffer
	for call := p.parseToolCall(); call != nil; call = p.parseToolCall() {
		calls = append(calls, *call)
	}

	if p.done() {
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = nil
	}

	return calls, content
}

96
97
98
99
100
101
102
// findTag searches the buffer for the tool calling tag.
//
// If the complete tag occurs in the buffer, it returns the index of
// its first occurrence and true. Otherwise, if the buffer ends with
// a prefix of the tag (the rest may still arrive in later input), it
// returns the index where that partial match starts and false. If
// neither is the case it returns -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// First check for the complete tag anywhere in the buffer
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// Then check for a partial overlap at the end of the buffer,
	// trying the longest possible overlap first
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
114

115
116
117
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
118
	tool, end := findTool(p.tools, p.buffer)
119
	if tool == nil {
120
		return nil
121
122
	}

123
124
125
126
127
128
129
	var args map[string]any
	if found, i := findArguments(p.buffer); found == nil {
		return nil
	} else {
		args = found
		if i > end {
			end = i
130
		}
131
132
	}

133
134
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
135
			Name:      tool.Function.Name,
136
137
138
			Arguments: args,
			Index:     p.n,
		},
139
140
	}

141
142
143
144
145
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// findTool looks for a tool name from the list inside the buffer.
// It returns nil if the buffer is empty, if no tool name occurs in
// it, or if the buffer ends with the beginning of a tool name, in
// which case more data is needed to tell names such as "get" and
// "get_weather" apart.
// When several names occur, the one starting earliest wins, and
// among names starting at the same position the longest wins.
// The second return value is the position just past the matched
// name, or 0 when nothing is returned.
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
	if len(buf) == 0 {
		return nil, 0
	}

	// the longest tool name bounds how much of the tail can be a
	// partial name
	longest := 0
	for i := range tools {
		if l := len(tools[i].Function.Name); l > longest {
			longest = l
		}
	}

	// wait for more data while the tail of the buffer is a proper
	// prefix of any tool name
	limit := min(len(buf), longest)
	for n := 1; n <= limit; n++ {
		suffix := string(buf[len(buf)-n:])
		for i := range tools {
			name := tools[i].Function.Name
			if n < len(name) && strings.HasPrefix(name, suffix) {
				return nil, 0
			}
		}
	}

	// pick the earliest occurrence, preferring longer names on ties
	var best *api.Tool
	bestPos := -1
	for i := range tools {
		name := tools[i].Function.Name
		pos := bytes.Index(buf, []byte(name))
		if pos < 0 {
			continue
		}

		if best != nil {
			later := pos > bestPos
			noLonger := pos == bestPos && len(name) <= len(best.Function.Name)
			if later || noLonger {
				continue
			}
		}

		best = &tools[i]
		bestPos = pos
	}

	if best == nil {
		return nil, 0
	}

	return best, bestPos + len(best.Function.Name)
}

211
// findArguments returns the first object that appears to be
// arguments for a tool call in the provided buffer, together with
// the index of the closing brace of the top-level JSON object it was
// found in. It returns nil and 0 if the buffer holds no complete,
// valid JSON object.
//
// The first top-level {...} span that parses as JSON is searched
// (including nested objects and arrays) for an object with a "name"
// key; that object's "arguments" or "parameters" object is returned.
// If no object has a "name" key, the parsed object itself is
// returned. Braces inside JSON string literals are ignored when
// matching braces, so argument values may contain { and }.
//
// TODO (jmorganca): this does not support parsing omitted arguments
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
		return nil, 0
	}

	// findObject walks obj depth-first looking for an object with a
	// "name" key. The bool reports whether such an object was found;
	// the map is nil if it has no "arguments" or "parameters" object.
	var findObject func(obj map[string]any) (map[string]any, bool)
	findObject = func(obj map[string]any) (map[string]any, bool) {
		if _, hasName := obj["name"]; hasName {
			if args, ok := obj["arguments"].(map[string]any); ok {
				return args, true
			}
			if args, ok := obj["parameters"].(map[string]any); ok {
				return args, true
			}
			return nil, true
		}

		for _, v := range obj {
			switch child := v.(type) {
			case map[string]any:
				if result, found := findObject(child); found {
					return result, true
				}
			case []any:
				for _, item := range child {
					if childObj, ok := item.(map[string]any); ok {
						if result, found := findObject(childObj); found {
							return result, true
						}
					}
				}
			}
		}

		return nil, false
	}

	start := -1
	var braces int
	var inString, escaped bool

	for i, c := range buffer {
		// inside a string literal braces are data, not structure;
		// only an unescaped quote ends the literal
		if inString {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}

		switch {
		case c == '"' && braces > 0:
			// quotes are only tracked inside an object so that stray
			// quotes in surrounding text cannot hide the object
			inString = true
		case c == '{':
			if braces == 0 {
				start = i
			}
			braces++
		case c == '}' && braces > 0:
			braces--
			if braces != 0 || start == -1 {
				continue
			}

			var data map[string]any
			if err := json.Unmarshal(buffer[start:i+1], &data); err != nil {
				// not valid JSON: discard this span and keep scanning
				start = -1
				continue
			}

			if args, found := findObject(data); found {
				return args, i
			}

			return data, i
		}
	}

	return nil, 0
}

287
288
289
290
291
292
293
294
295
296
297
298
299
// done checks if the parser is done parsing by looking
// for a closing tag. currently only } and ] are supported
// for closing tags as {} or [] pairs may not always
// represent tool calls and we need to send the content back.
//
// It returns true once a closing character in the buffer brings the
// count of open characters back to zero, and false for any tag
// other than { or [.
func (p *Parser) done() bool {
	var opening, closing byte
	switch p.tag {
	case "{":
		opening, closing = '{', '}'
	case "[":
		opening, closing = '[', ']'
	default:
		return false
	}

	var depth int
	for _, c := range p.buffer {
		switch c {
		case opening:
			depth++
		case closing:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining buffered content that should be
// sent to the user. This is the empty string unless the tag is
// { or [ and no tool call has been parsed.
func (p *Parser) Content() string {
	if p.n == 0 && (p.tag == "{" || p.tag == "[") {
		return string(p.buffer)
	}

	return ""
}