tools.go 5.51 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState identifies which phase of the stream a Parser is in.
// The zero value is toolsState_LookingForTag, so a freshly
// constructed Parser starts out looking for the tool calling tag.
type toolsState int

const (
	// toolsState_LookingForTag is the initial state: input is scanned
	// for the tool calling tag, and text in front of it is released
	// as content.
	toolsState_LookingForTag toolsState = iota
	// toolsState_ToolCalling means the tag has been found and the
	// buffered input is being parsed for tool calls.
	toolsState_ToolCalling
	// toolsState_Done means parsing has finished; any further input
	// is returned unchanged as content.
	toolsState_Done
)

type Parser struct {
21
22
	tag   string
	tools []api.Tool
23
24
25
26

	state  toolsState
	buffer []byte
	n      int
27
28
}

29
30
31
32
33
// NewParser builds a tool call parser for the provided tools,
// deriving the tool calling tag from the model's chat template.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
34

35
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
36
37
38
	return &Parser{
		tag:   tag,
		tools: tools,
39
	}
40
}
41

42
43
44
45
46
// Add appends s to the parser's buffer and parses as much of it as
// possible. It returns the tool calls completed by this input and any
// content that should be sent back to the user.
//
// Once the parser is done, Add returns s unchanged as content.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			// neither the tag nor a partial tag is buffered:
			// everything is content
			content = string(p.buffer)
			p.buffer = nil
		} else {
			// release the text in front of the (possibly partial)
			// tag and keep the rest buffered
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// for models where { or [ are used as tool calling
		// tags, we only support parsing tools if the first non-
		// whitespace character is { or [
		if p.tag == "{" || p.tag == "[" {
			if strings.TrimSpace(content) != "" {
				content += string(p.buffer)

				// the buffered text is returned right here, so drop
				// it; otherwise Content would report it a second time
				p.buffer = nil
				p.state = toolsState_Done
				return nil, content
			}
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	// parse every complete tool call that is currently buffered
	for {
		call := p.parseToolCall()
		if call == nil {
			break
		}

		calls = append(calls, *call)
	}

	// a closed {} or [] pair is left in the buffer: hand the
	// remaining buffer back as content and stop parsing
	if p.done() {
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = nil
	}

	return calls, content
}

96
97
98
99
100
101
102
// findTag searches the buffer for the tool calling tag.
//
// If the complete tag is present, it returns the index at which the
// tag starts and true. Otherwise, if the buffer ends with a prefix of
// the tag, it returns the index at which that prefix starts and
// false. If neither is present it returns -1 and false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// first check for the complete tag anywhere in the buffer
	if i := bytes.Index(p.buffer, tag); i > -1 {
		return i, true
	}

	// then check whether the buffer ends with the start of the tag,
	// trying the longest possible overlap first
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
114

115
116
117
118
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
	var args map[string]any
119
	var tool *api.Tool
120
121
122
	var end int = len(p.buffer)

	var i int
123
124
125
	// find tool name
	for _, t := range p.tools {
		n := t.Function.Name
126
127
		if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
			if i+len(n) < end {
128
				tool = &t
129
130
131
				end = i + len(n)
			}
		}
132
133
	}

134
	if tool == nil {
135
		return nil
136
137
	}

138
139
140
141
142
	// only look for arguments if the tool has parameters
	if len(tool.Function.Parameters.Properties) > 0 {
		if args, i = p.findArguments(*tool); args == nil {
			return nil
		}
143

144
145
146
		if i > end {
			end = i
		}
147
148
	}

149
150
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
151
			Name:      tool.Function.Name,
152
153
154
			Arguments: args,
			Index:     p.n,
		},
155
156
	}

157
158
159
160
161
162
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

// findArguments returns the first object that appears to be
163
164
// arguments for the provided tool, returning nil
func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
165
166
167
168
	if len(p.buffer) == 0 {
		return nil, 0
	}

169
170
171
172
173
	// no arguments to parse
	if len(tool.Function.Parameters.Properties) == 0 {
		return nil, 0
	}

174
175
176
177
178
179
180
181
182
183
184
185
	var braces int
	var start int = -1
	var end int
	var object []byte

	// find any outer json object
	for i, c := range p.buffer {
		if c == '{' {
			braces++
			if start == -1 {
				start = i
			}
186
		}
187
188

		if c == '}' {
189
190
191
192
193
194
195
			if start != -1 {
				braces--
				if braces == 0 {
					end = i + 1
					object = p.buffer[start:end]
					break
				}
196
			}
197
		}
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
	}

	if braces > 0 {
		return nil, 0
	}

	var data map[string]any

	// not valid json
	if err := json.Unmarshal(object, &data); err != nil {
		return nil, 0
	}

	var find func(obj any) map[string]any
	find = func(obj any) map[string]any {
213
		switch obj := obj.(type) {
214
		case map[string]any:
215
216
217
218
219
			found := true
			for key := range obj {
				if _, exists := tool.Function.Parameters.Properties[key]; !exists {
					found = false
					break
220
221
222
				}
			}

223
224
225
226
227
			if found {
				return obj
			}

			for _, value := range obj {
228
229
230
231
232
				if result := find(value); result != nil {
					return result
				}
			}
		case []any:
233
			for _, item := range obj {
234
235
236
237
				if result := find(item); result != nil {
					return result
				}
			}
238
		}
239
240

		return nil
241
242
	}

243
244
245
	result := find(data)
	if result != nil {
		return result, end
246
247
	}

248
	return nil, 0
249
250
}

251
252
253
254
255
256
257
258
259
260
261
262
263
// done reports whether the parser is done parsing, which it detects
// by looking for the closing counterpart of the tag in the buffer.
// Only { and [ tags are supported, as {} or [] pairs may not always
// represent tool calls and the content then needs to be sent back.
// For any other tag done returns false.
func (p *Parser) done() bool {
	var opening, closing byte
	switch p.tag {
	case "{":
		opening, closing = '{', '}'
	case "[":
		opening, closing = '[', ']'
	default:
		return false
	}

	// done as soon as a closing character brings the nesting depth
	// back to zero
	var depth int
	for _, c := range p.buffer {
		switch c {
		case opening:
			depth++
		case closing:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining content that should be sent to the
// user. This is the empty string unless the tag is { or [ and no
// tool call was found, in which case it is whatever is left in the
// buffer.
func (p *Parser) Content() string {
	// tool calls were parsed, so the buffer is not user content
	if p.n > 0 {
		return ""
	}

	if p.tag == "{" || p.tag == "[" {
		return string(p.buffer)
	}

	return ""
}