tools.go 5.56 KB
Newer Older
1
2
3
package tools

import (
4
	"bytes"
5
6
	"encoding/json"
	"strings"
7
	"text/template"
8
9
10
11

	"github.com/ollama/ollama/api"
)

12
13
14
15
16
17
// toolsState identifies which phase of the stream a Parser is in.
type toolsState int

const (
	// toolsState_LookingForTag is the zero value: Add is still
	// searching the buffered input for the tool calling tag.
	toolsState_LookingForTag toolsState = iota

	// toolsState_ToolCalling means the tag has been seen and Add
	// tries to parse tool calls out of the buffer.
	toolsState_ToolCalling

	// toolsState_Done means parsing has finished; Add hands any
	// further input straight back as content.
	toolsState_Done
)

type Parser struct {
21
22
23
24
25
26
27
	tag        string
	names      []string
	properties []string

	state  toolsState
	buffer []byte
	n      int
28
29
}

30
31
32
33
34
// NewParser builds a tool call parser for the given tools, using
// the tool calling tag derived from the model's chat template.
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
	tag := parseTag(tmpl)
	return NewParserWithTag(tools, tag)
}
35

36
37
38
39
40
41
// NewParserWithTag builds a tool call parser that looks for the
// given tag. It records the function name of every tool and the
// parameter property names of all tools.
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
	parser := &Parser{tag: tag}
	for i := range tools {
		parser.names = append(parser.names, tools[i].Function.Name)
		for key := range tools[i].Function.Parameters.Properties {
			parser.properties = append(parser.properties, key)
		}
	}
	return parser
}
47

48
49
50
51
52
// Add appends s to the parser's buffer and returns any tool calls
// that can now be parsed, together with content that should be sent
// back to the user.
//
// Once the parser is done, s is returned unchanged as content. While
// the tag has not been seen yet, everything that cannot be part of
// the tag is returned as content and the remainder stays buffered.
// When the tag is { or [ and non-whitespace text precedes it, the
// input is treated as plain content and the parser is done.
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
	if p.state == toolsState_Done {
		return nil, s
	}

	p.buffer = append(p.buffer, s...)

	if p.state == toolsState_LookingForTag {
		i, found := p.findTag()
		if i == -1 {
			// no tag and no partial tag: everything is content
			content = string(p.buffer)
			p.buffer = p.buffer[:0]
		} else {
			// keep the (possibly partial) tag and what follows buffered
			content = string(p.buffer[:i])
			p.buffer = p.buffer[i:]
		}

		// for models where { or [ are used as tool calling
		// tags, we only support parsing tools if the first non-
		// whitespace character is { or [
		if p.tag == "{" || p.tag == "[" {
			if strings.TrimSpace(content) != "" {
				p.state = toolsState_Done
				content += string(p.buffer)
				// the buffered tail has just been handed back to the
				// caller; drop it so Content does not return it again
				p.buffer = p.buffer[:0]
				return nil, content
			}
		}

		if !found {
			return nil, content
		}

		p.state = toolsState_ToolCalling
	}

	// drain every complete tool call currently in the buffer
	for {
		call := p.parseToolCall()
		if call == nil {
			break
		}

		calls = append(calls, *call)
	}

	// a closed {} or [] pair that was not consumed as a tool call is
	// flushed back as content
	if p.done() {
		p.state = toolsState_Done
		content = string(p.buffer)
		p.buffer = p.buffer[:0]
	}

	return calls, content
}

102
103
104
105
106
107
108
// findTag looks for the tool calling tag in the buffer. If the whole
// tag is present it returns the index where it starts and true. If
// the buffer only ends with a prefix of the tag it returns the index
// where that prefix starts and false. Otherwise it returns -1 and
// false.
func (p *Parser) findTag() (int, bool) {
	tag := []byte(p.tag)

	// a complete tag anywhere in the buffer
	if at := bytes.Index(p.buffer, tag); at >= 0 {
		return at, true
	}

	// the longest tag prefix the buffer ends with, if any
	for n := min(len(p.buffer), len(tag)); n > 0; n-- {
		if bytes.HasSuffix(p.buffer, tag[:n]) {
			return len(p.buffer) - n, false
		}
	}

	return -1, false
}
120

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
	var name string
	var args map[string]any
	var end int = len(p.buffer)

	// find tool name
	var i int
	for _, n := range p.names {
		if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
			if i+len(n) < end {
				name = n
				end = i + len(n)
			}
		}
137
138
	}

139
140
	if name == "" {
		return nil
141
142
	}

143
144
145
	if args, i = p.findArguments(); args == nil {
		return nil
	}
146

147
148
	if i > end {
		end = i
149
150
	}

151
152
153
154
155
156
	tc := &api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      name,
			Arguments: args,
			Index:     p.n,
		},
157
158
	}

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
	p.n++
	p.buffer = p.buffer[end:]
	return tc
}

// findArguments locates the first complete top-level JSON object in
// the buffer and searches it, depth first, for an object that has at
// least one key listed in p.properties. It returns that object and
// the buffer offset just past the top-level object. It returns nil
// and 0 if the buffer holds no complete object, the object is not
// valid JSON, or no object with a known property key is found.
//
// Braces inside JSON string literals are not counted, and a closing
// brace that appears before any opening brace is ignored.
//
// NOTE: when several nested objects qualify, the one returned depends
// on Go's unspecified map iteration order.
func (p *Parser) findArguments() (map[string]any, int) {
	if len(p.buffer) == 0 {
		return nil, 0
	}

	var (
		depth    int
		start    = -1
		end      int
		inString bool
		escaped  bool
	)

	// find the first outer json object
scan:
	for i, c := range p.buffer {
		if inString {
			switch {
			case escaped:
				// this byte is escaped; it cannot end the string
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}

		switch c {
		case '"':
			// only track strings inside the object: a lone quote in
			// free text before it must not hide the object
			if start != -1 {
				inString = true
			}
		case '{':
			if start == -1 {
				start = i
			}
			depth++
		case '}':
			if start == -1 {
				// stray closer before any object has opened
				continue
			}
			depth--
			if depth == 0 {
				end = i + 1
				break scan
			}
		}
	}

	// no object, or the object is not complete yet
	if end == 0 {
		return nil, 0
	}

	var data map[string]any

	// not valid json
	if err := json.Unmarshal(p.buffer[start:end], &data); err != nil {
		return nil, 0
	}

	var find func(obj any) map[string]any
	find = func(obj any) map[string]any {
		switch v := obj.(type) {
		case map[string]any:
			// check if the object keys are valid tool properties
			// TODO (jmorganca): check only sets of properties that
			// go together instead of the entire set
			for _, prop := range p.properties {
				if _, exists := v[prop]; exists {
					return v
				}
			}

			for _, value := range v {
				if result := find(value); result != nil {
					return result
				}
			}
		case []any:
			for _, item := range v {
				if result := find(item); result != nil {
					return result
				}
			}
		}

		return nil
	}

	if result := find(data); result != nil {
		return result, end
	}

	return nil, 0
}

244
245
246
247
248
249
250
251
252
253
254
255
256
// done reports whether the buffer contains a closed pair of the
// tag's delimiters. Only the tags { and [ are supported; for any
// other tag it returns false. A closed {} or [] pair may not be a
// tool call, in which case the buffered text has to be sent back
// as content.
func (p *Parser) done() bool {
	var opener, closer byte
	switch p.tag {
	case "{":
		opener, closer = '{', '}'
	case "[":
		opener, closer = '[', ']'
	default:
		return false
	}

	depth := 0
	for _, b := range p.buffer {
		switch b {
		case opener:
			depth++
		case closer:
			depth--
			if depth == 0 {
				return true
			}
		}
	}

	return false
}

// Content returns any remaining buffered content that should be
// sent to the user. It is the empty string unless the tag is { or [
// and no tool call has been parsed.
func (p *Parser) Content() string {
	bracketTag := p.tag == "{" || p.tag == "["
	if p.n == 0 && bracketTag {
		return string(p.buffer)
	}

	return ""
}