tools.go 6.55 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package tools

import (
	"encoding/json"
	"errors"
	"log/slog"
	"strings"
	gotmpl "text/template"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

// Sentinel errors used by the tool-call parsing helpers in this file.
var (
	// errAccumulateMore reports that the buffered input is incomplete and the
	// caller should wait for more content before deciding what it is.
	errAccumulateMore = errors.New("need to accumulate more content")

	// errInvalidToolCall reports that the buffered input is not a tool call.
	errInvalidToolCall = errors.New("invalid tool call format")
)

// Parser extracts tool calls from model output that is fed to it piece by
// piece through Add. It is configured from a chat template by NewParser.
type Parser struct {
	// parseLeadingJSON lets Add try to parse JSON at the start of the output
	// even when the prefix has not been seen; Add clears it after a failed parse.
	parseLeadingJSON bool
	// prefix is the text that introduces tool calls, as extracted from the
	// template by toolPrefix; "" means no prefix is configured.
	prefix           string
	// prefixFound is set by checkPrefix when the buffered input starts with prefix.
	prefixFound      bool
	// tmpl is the tool-call template returned by toolTemplate.
	tmpl             gotmpl.Template
	// sb buffers input across Add calls until it can be classified as tool
	// calls or plain content.
	sb               strings.Builder
	// index counts the tool calls Add has returned so far.
	index            int
	// name and arguments are the JSON keys (taken from the template) that hold
	// a tool call's function name and its arguments object.
	name             string
	arguments        string
	// done stops parsing: once set, Add passes input through unchanged if no
	// tool call has been returned yet, and returns empty content otherwise.
	done             bool
}

// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
//
// The input is scanned for balanced top-level JSON objects; each one is
// unmarshaled and searched (via collect) for objects carrying both the name and
// arguments fields. Braces and brackets inside JSON string values are treated
// as data while scanning, so an argument such as {"code": "if x { }"} does not
// unbalance the count.
//
// Parameters:
//   - s: The string to parse
//   - name: The field name from template that identifies the tool call name
//   - arguments: The field name from template that identifies the tool call arguments
//   - prefix: The tool call prefix from the template; when it ends in "[" the
//     opening bracket has already been cut off with the prefix, so square
//     brackets are only counted if s itself starts with "["
//
// Returns:
//   - []api.ToolCall: The parsed tool calls if successful
//   - error: errAccumulateMore if s is empty, is a lone "[", or has unbalanced
//     braces/brackets (including an unterminated string inside an object);
//     errInvalidToolCall if there is an extra closing brace/bracket or no valid
//     tool call was found; nil if successful
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
	s = strings.TrimSpace(s)

	// Only track square brackets if we don't have a prefix, as the opening
	// bracket will have been cut off along with the prefix. Also track them in
	// the parseLeadingJSON case, where s itself starts with "[".
	trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")

	// Check for balanced braces before attempting to parse
	braceCount := 0
	squareCount := 0
	startIndex := -1
	inString := false // inside a JSON string literal within an object
	escaped := false  // previous character inside the string was a backslash
	var rawToolCalls []string
	for i, c := range s {
		if inString {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			// Structural characters inside a string are data, not JSON syntax.
			continue
		}

		switch c {
		case '"':
			// Quotes only open a JSON string inside an object; outside one
			// they are ordinary text and must not hide later braces.
			if braceCount > 0 {
				inString = true
			}
		case '{':
			braceCount++
			if startIndex == -1 {
				startIndex = i
			}
		case '}':
			braceCount--
			if braceCount == 0 {
				rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
				startIndex = -1
			}
		case '[':
			if trackSquareBrackets {
				squareCount++
			}
		case ']':
			if trackSquareBrackets {
				squareCount--
			}
		}

		// Negative means we have an extra closing brace/bracket
		if braceCount < 0 || squareCount < 0 {
			return nil, errInvalidToolCall
		}
	}

	// If braces/brackets aren't balanced, need more input. An unterminated
	// string is always inside an open object, so it is covered by braceCount.
	if braceCount > 0 || squareCount > 0 {
		return nil, errAccumulateMore
	}

	// Nothing to parse yet, or only the opening bracket of a tool call list
	if s == "" || s == "[" {
		return nil, errAccumulateMore
	}

	// Attempt full unmarshal of the JSON
	var toolCalls []api.ToolCall
	for _, rawToolCall := range rawToolCalls {
		var resp map[string]any
		if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
			continue
		}

		// Collect nested objects that could contain tool calls
		objs := collect(resp)
		if len(objs) == 0 {
			continue
		}

		// Extract tool calls from objects
		for _, kv := range objs {
			n, nok := kv[name].(string)
			a, aok := kv[arguments].(map[string]any)
			if nok && aok {
				toolCalls = append(toolCalls, api.ToolCall{
					Function: api.ToolCallFunction{
						Name:      n,
						Arguments: a,
					},
				})
			} else {
				slog.Debug("No valid tool call found in object.", "object", kv)
			}
		}
	}

	// Valid JSON, no tool calls found
	if len(toolCalls) == 0 {
		slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
		return nil, errInvalidToolCall
	}

	return toolCalls, nil
}

// checkPrefix processes a string to find and handle a prefix pattern.
//
// Returns:
//   - s unchanged and a nil error when s is empty, no prefix is configured, or
//     neither the prefix nor a partial prefix occurs in s
//   - s with the prefix removed and a nil error when s starts with the prefix;
//     p.prefixFound is set
//   - the text preceding the prefix and errAccumulateMore when a complete
//     prefix occurs later in s, or a partial prefix ends s; the remainder
//     (from the prefix onward) is stored in p.sb for the next pass
func (p *Parser) checkPrefix(s string) (string, error) {
	if s == "" || p.prefix == "" {
		return s, nil
	}

	// Check for prefix at start of string
	if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
		// Found prefix at start - accumulate for potential tool
		p.prefixFound = true
		return cut, nil
	}

	full := strings.Index(s, p.prefix)
	partial := suffixOverlap(s, p.prefix)

	// A complete prefix that starts before the trailing overlap takes
	// precedence; otherwise everything up to the overlap, including the
	// complete prefix, would be emitted as plain content.
	if full != -1 && (partial == -1 || full < partial) {
		// Save remainder starting at prefix for next pass
		p.sb.Reset()
		p.sb.WriteString(strings.TrimSpace(s[full:]))
		// Return everything before prefix
		return s[:full], errAccumulateMore
	}

	// Prefix (possibly only partially) overlaps the end of the string
	if partial != -1 {
		// Return everything except overlapping portion
		p.sb.Reset()
		p.sb.WriteString(s[partial:])
		return s[:partial], errAccumulateMore
	}

	// No partial prefix found
	return s, nil
}

// Add processes a string input to parse tool calls and content.
// It handles prefix detection and JSON parsing to extract tool calls.
//
// Input is buffered in p.sb across calls, so a call may return neither tool
// calls nor content while the parser waits for more of a potential tool call.
//
// Returns:
//   - tools: Any parsed tool calls, each with Function.Index taken from a
//     counter that keeps increasing across calls on the same Parser
//   - content: Non-tool call content
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
	if p.done {
		if p.index == 0 {
			// Return original string if no tool calls found at start
			return nil, s
		}
		// Return empty if no tool calls found after start
		return nil, ""
	}
	p.sb.WriteString(s)
	s = p.sb.String()

	// Check for prefix pattern in input
	s, err := p.checkPrefix(s)
	if err != nil {
		// Need more input to complete prefix; emit only what precedes it
		return nil, s
	}

	// Exit if prefix exists in template, greedy parsing is off, and prefix not found
	if !p.parseLeadingJSON && !p.prefixFound {
		p.sb.Reset()
		return nil, s
	}

	toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
	if err != nil {
		if errors.Is(err, errAccumulateMore) {
			// Keep buffering: p.sb still holds the accumulated input
			return nil, ""
		}
		p.sb.Reset()
		// Do not try parsing leading JSON if JSON not found
		p.parseLeadingJSON = false
		if p.prefix == "" {
			// With no prefix configured, stop parsing after the first failure
			p.done = true
		}
		if p.index != 0 && p.prefix == "" {
			return nil, ""
		}
		if p.prefixFound {
			// Drop tokens since prefix was found
			return nil, ""
		}
		return nil, s
	}

	// Assign through the slice: a range value variable is a copy, so setting
	// its Index would leave the returned tool calls unchanged.
	for i := range toolCalls {
		toolCalls[i].Function.Index = p.index
		p.index++
	}

	p.sb.Reset()
	return toolCalls, ""
}

// NewParser creates a new tool call parser from a template. It extracts the tool call format,
// prefix, and field names from the template to use for parsing tool calls from model output.
//
// Returns an error if the template is nil or has not been parsed, or if it does not
// contain valid tool call formatting.
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
	// Root is reached through the embedded *parse.Tree, which is nil until the
	// template has been parsed; report that as an error instead of panicking.
	if templateToProcess == nil || templateToProcess.Tree == nil || templateToProcess.Root == nil {
		return nil, errors.New("template has no parse tree")
	}

	parsed, err := template.Parse(templateToProcess.Root.String())
	if err != nil {
		return nil, err
	}

	tt, err := toolTemplate(parsed)
	if err != nil {
		return nil, err
	}

	tp := toolPrefix(templateToProcess)

	name, arguments, err := extractToolArgs(tt)
	if err != nil {
		return nil, err
	}

	// sb is left at its zero value, which is an empty, ready-to-use builder.
	return &Parser{
		tmpl:             *tt,
		prefix:           tp,
		parseLeadingJSON: true,
		name:             name,
		arguments:        arguments,
	}, nil
}