Unverified Commit 9f8a18ec authored by Jeffrey Morgan's avatar Jeffrey Morgan Committed by GitHub
Browse files

tools: loosen tool parsing to allow for more formats (#11030)

parent 6b04cad7
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the user's location."
}
},
"required": [
"location",
"format"
]
}
}
}
]
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]
{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:
\ No newline at end of file
You are a knowledgeable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]
### Response:
\ No newline at end of file
package tools package tools
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors"
"log/slog"
"strings" "strings"
gotmpl "text/template" "text/template"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
var ( type toolsState int
errInvalidToolCall = errors.New("invalid tool call format")
errAccumulateMore = errors.New("need to accumulate more content") const (
toolsState_LookingForTag toolsState = iota
toolsState_ToolCalling
toolsState_Done
) )
type Parser struct { type Parser struct {
greedyParseJSON bool tag string
prefix string names []string
prefixFound bool properties []string
tmpl gotmpl.Template
sb strings.Builder state toolsState
index int buffer []byte
name string n int
arguments string
} }
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // NewParser creates a new tool call parser from a model's chat
// // template and a list of provided tools.
// Parameters: func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
// - s: The string to parse return NewParserWithTag(tools, parseTag(tmpl))
// - name: The field name from template that identifies the tool call name }
// - arguments: The field name from template that identifies the tool call arguments
//
// Returns:
// - []api.ToolCall: The parsed tool calls if successful
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
// Check for balanced braces before attempting to parse
braceCount := 0
squareCount := 0
startIndex := -1
var rawToolCalls []string
s = strings.TrimSpace(s)
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
for i, c := range s {
switch c {
case '{':
braceCount++
if startIndex == -1 {
startIndex = i
}
case '}':
braceCount--
if braceCount == 0 {
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
startIndex = -1
}
case '[':
if trackSquareBrackets {
squareCount++
}
case ']':
if trackSquareBrackets {
squareCount--
}
}
// Negative means we have an extra closing brace/bracket func NewParserWithTag(tools []api.Tool, tag string) *Parser {
if braceCount < 0 || squareCount < 0 { var p Parser
return nil, errInvalidToolCall for _, t := range tools {
p.names = append(p.names, t.Function.Name)
for r := range t.Function.Parameters.Properties {
p.properties = append(p.properties, r)
} }
} }
p.tag = tag
return &p
}
// If braces/brackets aren't balanced, need more input // Add processes a string input to parse tool calls and content that
if braceCount > 0 || squareCount > 0 { // should be sent back to the user.
return nil, errAccumulateMore func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
if p.state == toolsState_Done {
return nil, s
} }
t := strings.TrimSpace(s) p.buffer = append(p.buffer, s...)
if len(t) == 0 {
return nil, errAccumulateMore
}
// If the input is a single square bracket, it's not a valid tool call
if t[0] == '[' && len(t) == 1 {
return nil, errAccumulateMore
}
// Attempt full unmarshal of the JSON if p.state == toolsState_LookingForTag {
var toolCalls []api.ToolCall i, found := p.findTag()
for _, rawToolCall := range rawToolCalls { if i == -1 {
var resp map[string]any content = string(p.buffer)
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { p.buffer = []byte{}
continue } else {
content = string(p.buffer[:i])
p.buffer = p.buffer[i:]
} }
// Collect nested objects that could contain tool calls // for models where { or [ are used as tool calling
objs := collect(resp) // tags, we only support parsing tools if the first non-
if len(objs) == 0 { // whitespace character is { or [
continue if p.tag == "{" || p.tag == "[" {
if strings.TrimSpace(content) != "" {
p.state = toolsState_Done
return nil, content + string(p.buffer)
}
} }
// Extract tool calls from objects if !found {
for _, kv := range objs { return nil, content
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
} else {
slog.Debug("No valid tool call found in object.", "object", kv)
}
} }
p.state = toolsState_ToolCalling
} }
// Valid JSON, no tool calls found for {
if len(toolCalls) == 0 { call := p.parseToolCall()
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) if call == nil {
return nil, errInvalidToolCall break
}
calls = append(calls, *call)
} }
return toolCalls, nil if p.done() {
p.state = toolsState_Done
content = string(p.buffer)
p.buffer = []byte{}
}
return calls, content
} }
// checkPrefix processes a string to find and handle a prefix pattern. // findTag searches the buffer to find and handle a tool calling tag
// // returning true if the tag was found and false otherwise, and
// Returns: // a string content signaling any content that should be sent back to the user
// - The processed string with prefix removed if found func (p *Parser) findTag() (int, bool) {
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful // First check for complete substring anywhere in s
func (p *Parser) checkPrefix(s string) (string, error) { if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 {
if s == "" || p.prefix == "" { return i, true
return s, nil
} }
// Check for prefix at start of string // Then check for partial suffix overlap
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { max := min(len(p.buffer), len(p.tag))
// Found prefix at start - accumulate for potential tool for i := max; i > 0; i-- {
p.prefixFound = true if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) {
return cut, nil return len(p.buffer) - i, false
}
} }
return -1, false
}
// Check if prefix overlaps end of string // parseToolCall finds the next complete tool call in the buffer
if idx := suffixOverlap(s, p.prefix); idx != -1 { // incrementing n and advancing the buffer.
// Return everything except overlapping portion func (p *Parser) parseToolCall() *api.ToolCall {
p.sb.Reset() var name string
p.sb.WriteString(s[idx:]) var args map[string]any
return s[:idx], errAccumulateMore var end int = len(p.buffer)
// find tool name
var i int
for _, n := range p.names {
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end {
name = n
end = i + len(n)
}
}
} }
// Check if prefix appears in middle of string if name == "" {
if idx := strings.Index(s, p.prefix); idx != -1 { return nil
// Save remainder starting at prefix for next pass
p.sb.Reset()
p.sb.WriteString(strings.TrimSpace(s[idx:]))
// Return everything before prefix
return s[:idx], errAccumulateMore
} }
// No partial prefix found if args, i = p.findArguments(); args == nil {
return s, nil return nil
} }
// Add processes a string input to parse tool calls and content. if i > end {
// It handles prefix detection and JSON parsing to extract tool calls. end = i
//
// Returns:
// - tools: Any parsed tool calls
// - content: Non-tool call content
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
p.sb.WriteString(s)
s = p.sb.String()
// Check for prefix pattern in input
s, err := p.checkPrefix(s)
if err != nil {
// Need more input to complete prefix
return nil, s
} }
// Exit if prefix exists in template, greedy parsing is off, and prefix not found tc := &api.ToolCall{
if !p.greedyParseJSON && !p.prefixFound { Function: api.ToolCallFunction{
p.sb.Reset() Name: name,
return nil, s Arguments: args,
Index: p.n,
},
} }
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) p.n++
if err != nil { p.buffer = p.buffer[end:]
if errors.Is(err, errAccumulateMore) { return tc
return nil, "" }
}
p.sb.Reset() // findArguments returns the first object that appears to be
// Only do greedy JSON parsing if there is no prefix from template // arguments and the position where the arguments end, returning nil and 0 if
if p.prefix != "" { // an invalid JSON object or non-arguments object is found first
p.greedyParseJSON = false func (p *Parser) findArguments() (map[string]any, int) {
if len(p.buffer) == 0 {
return nil, 0
}
var braces int
var start int = -1
var end int
var object []byte
// find any outer json object
for i, c := range p.buffer {
if c == '{' {
braces++
if start == -1 {
start = i
}
} }
if p.index != 0 && p.prefix == "" {
return nil, "" if c == '}' {
braces--
if braces == 0 && start != -1 {
end = i + 1
object = p.buffer[start:end]
break
}
} }
if p.prefixFound { }
// Drop tokens since prefix was found
return nil, "" if braces > 0 {
return nil, 0
}
var data map[string]any
// not valid json
if err := json.Unmarshal(object, &data); err != nil {
return nil, 0
}
var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch v := obj.(type) {
case map[string]any:
// check if the object keys are valid tool properties
// TODO (jmorganca): check only sets of properties that
// go together instead of the entire set
for _, prop := range p.properties {
if _, exists := v[prop]; exists {
return v
}
}
for _, value := range v {
if result := find(value); result != nil {
return result
}
}
case []any:
for _, item := range v {
if result := find(item); result != nil {
return result
}
}
} }
return nil, s
return nil
} }
for _, tc := range toolCalls { result := find(data)
tc.Function.Index = p.index if result != nil {
p.index++ return result, end
} }
p.sb.Reset() return nil, 0
return toolCalls, ""
} }
// NewParser creates a new tool call parser from a template. It extracts the tool call format, // done checks if the parser is done parsing by looking
// prefix, and field names from the template to use for parsing tool calls from model output. // for closing tag. currently only } and ] are supported
// // for closing tags as {} or [] pairs may not always
// Returns an error if the template does not contain valid tool call formatting. // represent tool calls and we need to send the content back
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { func (p *Parser) done() bool {
parsed, err := template.Parse(templateToProcess.Root.String()) var open, close rune
if err != nil { switch p.tag {
return nil, err case "{":
open, close = '{', '}'
case "[":
open, close = '[', ']'
default:
return false
} }
tt, err := toolTemplate(parsed) var count int
if err != nil { for _, c := range p.buffer {
return nil, err if c == byte(open) {
count++
} else if c == byte(close) {
count--
if count == 0 {
return true
}
}
} }
tp := toolPrefix(templateToProcess) return false
}
// Content returns any remaining content that
// should be sent to the user. This should be the empty string
// string unless the tag is { or [ and a tool call was not found
func (p *Parser) Content() string {
if p.n > 0 {
return ""
}
name, arguments, err := extractToolArgs(tt) if p.tag == "{" || p.tag == "[" {
if err != nil { return string(p.buffer)
return nil, err
} }
return &Parser{ return ""
tmpl: *tt,
sb: strings.Builder{},
prefix: tp,
greedyParseJSON: true,
name: name,
arguments: arguments,
}, nil
} }
package tools package tools
import ( import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing" "testing"
"text/template"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func readFile(t *testing.T, base, name string) *bytes.Buffer { func TestParser(t *testing.T) {
t.Helper() qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
bts, err := os.ReadFile(filepath.Join(base, name)) deepseek, err := template.New("deepseek").Parse("{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>{{end}}<|tool▁calls▁end|><|end▁of▁sentence|>{{end}}")
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("Failed to parse template: %v", err)
} }
return bytes.NewBuffer(bts) json, err := template.New("json").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`)
} if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
mistral, err := template.New("mistral").Parse(`{{if .ToolCalls}}[TOOL_CALLS] [{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}][/TOOL_CALLS]{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
list, err := template.New("list").Parse(`{{if .ToolCalls}}[{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}]{{end}}`)
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"format": {
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"city": {
Type: api.PropertyType{"string"},
Description: "The city to get the temperature for",
},
},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_conditions",
Description: "Retrieve the current weather conditions for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the weather conditions for",
},
},
},
},
},
}
func TestParseJSONToolCalls(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
input string inputs []string
nameField string tmpl *template.Template
argsField string content string
wantToolCalls []api.ToolCall calls []api.ToolCall
wantErr error
prefix string
}{ }{
{ {
name: "valid single tool call", name: "no tool calls - just text",
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, inputs: []string{"Hello, how can I help you today?"},
nameField: "name", content: "Hello, how can I help you today?",
argsField: "arguments", tmpl: qwen,
wantToolCalls: []api.ToolCall{ calls: nil,
},
{
name: "empty input",
inputs: []string{""},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "tool call",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"location": "San Francisco"}}</tool_call>`},
content: "",
tmpl: qwen,
calls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test_tool", Index: 0,
Arguments: map[string]any{ Name: "get_conditions",
"arg1": "value1", Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco",
}, },
}, },
}, },
}, },
wantErr: nil, },
prefix: "", {
}, name: "text before tool call",
{ inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
name: "incomplete JSON", content: "Let me check the weather. ",
input: `{"name": "test_tool", "arguments": {"arg1": `, tmpl: qwen,
nameField: "name", calls: []api.ToolCall{
argsField: "arguments",
wantToolCalls: nil,
wantErr: errAccumulateMore,
prefix: "",
},
{
name: "invalid JSON",
input: `not json at all`,
nameField: "name",
argsField: "arguments",
wantToolCalls: nil,
wantErr: errInvalidToolCall,
prefix: "",
},
{
name: "missing required fields",
input: `{"other": "field"}`,
nameField: "name",
argsField: "arguments",
wantToolCalls: nil,
wantErr: errInvalidToolCall,
prefix: "",
},
{
name: "multiple tool calls in array",
input: `[
{"name": "tool1", "arguments": {"arg1": 1}},
{"name": "tool2", "arguments": {"arg2": "value"}}
]`,
nameField: "name",
argsField: "arguments",
wantToolCalls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool1", Index: 0,
Arguments: map[string]any{ Name: "get_temperature",
"arg1": float64(1), Arguments: api.ToolCallFunctionArguments{
"city": "New York",
},
},
},
},
},
{
name: "two tool calls in a list",
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
content: "",
tmpl: mistral,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
}, },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool2", Index: 1,
Arguments: map[string]any{ Name: "get_conditions",
"arg2": "value", Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
}, },
}, },
}, },
}, },
wantErr: nil, },
prefix: "", {
}, name: "two tool calls",
{ inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
name: "multiple tool calls without array", content: "Okay, let's call both tools! ",
input: ` tmpl: qwen,
{"name": "tool1", "arguments": {"arg1": 1}}, calls: []api.ToolCall{
{"name": "tool2", "arguments": {"arg2": "value"}}
`,
nameField: "name",
argsField: "arguments",
wantToolCalls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool1", Index: 0,
Arguments: map[string]any{ Name: "get_temperature",
"arg1": float64(1), Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
}, },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool2", Index: 1,
Arguments: map[string]any{ Name: "get_conditions",
"arg2": "value", Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
}, },
}, },
}, },
}, },
wantErr: nil, },
prefix: "", {
}, name: "deepseek",
{ inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
name: "multiple tool calls with text after", content: "<think>Wait, I need to call a tool</think>",
input: ` tmpl: deepseek,
{"name": "tool1", "arguments": {"arg1": 1}} text calls: []api.ToolCall{
{"name": "tool2", "arguments": {"arg2": "value"}} text
`,
nameField: "name",
argsField: "arguments",
wantToolCalls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool1", Index: 0,
Arguments: map[string]any{ Name: "get_temperature",
"arg1": float64(1), Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
}, },
}, },
}, },
},
},
{
name: "deepseek incremental",
inputs: []string{
"<think>Wait",
", I need",
" to call",
" a tool</think><|too",
"l▁calls▁begin",
"|>",
"<|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n",
"```json\n",
"{\"city\": \"Tokyo\"}\n",
"```",
"<|tool▁c", "all▁end|>",
"<|tool▁calls▁end|>",
"<|end▁of▁sentence|>",
},
content: "<think>Wait, I need to call a tool</think>",
tmpl: deepseek,
calls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool2", Index: 0,
Arguments: map[string]any{ Name: "get_temperature",
"arg2": "value", Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
}, },
}, },
}, },
}, },
wantErr: nil,
prefix: "",
}, },
{ {
name: "second tool call in array", name: "json",
input: ` inputs: []string{
, {"name": "tool2", "arguments": {"arg2": "value"}} "{",
`, "\"name\": \"get_temperature\",",
nameField: "name", "\"arguments\": {",
argsField: "arguments", "\"city\": \"Tokyo\"",
wantToolCalls: []api.ToolCall{ "}",
"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool2", Index: 0,
Arguments: map[string]any{ Name: "get_temperature",
"arg2": "value", Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
}, },
}, },
}, },
}, },
wantErr: nil, },
prefix: "", {
}, name: "json maybe a tool call",
// a bad JSON would not return any tool calls or content as it would always accumulate more inputs: []string{
{ "{",
name: "unbalanced square brackets", "\"name\": \"get_temperature\",",
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, "\"arguments\": {",
nameField: "name", },
argsField: "arguments", content: "",
wantToolCalls: nil, tmpl: json,
wantErr: errAccumulateMore, calls: nil,
prefix: "", },
}, {
{ name: "json not a tool call",
name: "incomplete square brackets", inputs: []string{
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, "{",
nameField: "name", "\"name\": \"search\", ",
argsField: "arguments", "\"arguments\": {",
wantToolCalls: nil, "\"query\": \"What is the capital of Canada?\"",
wantErr: errAccumulateMore, "}",
prefix: "", "}",
}, },
{ content: "{\"name\": \"search\", \"arguments\": {\"query\": \"What is the capital of Canada?\"}}",
name: "nested arrays in arguments", tmpl: json,
input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, calls: nil,
nameField: "name", },
argsField: "arguments", {
wantToolCalls: []api.ToolCall{ name: "json object followed by tool call",
inputs: []string{
"{\"name\": \"jeff\"}",
"{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
},
content: "{\"name\": \"jeff\"}{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
tmpl: json,
},
{
name: "json object followed by tool call split",
inputs: []string{
"{\"name\": \"jeff\"} {",
"\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
},
content: "{\"name\": \"jeff\"} {\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}",
tmpl: json,
},
{
name: "json code",
inputs: []string{
"for { fmt.Println(\"hello\") }",
},
content: "for { fmt.Println(\"hello\") }",
tmpl: json,
},
{
name: "list multiple",
inputs: []string{
"[",
"{",
"\"name\": \"get_temperature\", ",
"\"arguments\": {",
"\"city\": \"London\"",
"}",
"},",
"{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}]",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "tool1", Index: 0,
Arguments: map[string]any{ Name: "get_temperature",
"arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, Arguments: api.ToolCallFunctionArguments{
"city": "London",
}, },
}, },
}, },
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list partial",
inputs: []string{
"[",
"{",
"\"name\": \"search\", ",
"\"arguments\": {",
"\"query\": \"What is the capital of Canada?\"",
"}",
"}",
},
content: "",
tmpl: list,
calls: nil,
},
{
name: "list not a tool call",
inputs: []string{
"[special",
" del",
"ivery]",
}, },
wantErr: nil, content: "[special delivery]",
prefix: "", tmpl: list,
calls: nil,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) parser := NewParser(tt.tmpl, tools)
if err != tt.wantErr { var calls []api.ToolCall
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) var content string
for _, input := range tt.inputs {
tcs, c := parser.Add(input)
calls = append(calls, tcs...)
content += c
} }
if len(gotCalls) != 0 && tt.wantErr != nil { if content != tt.content {
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) t.Errorf("Expected content %q, got %q", tt.content, content)
} }
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { if len(calls) != len(tt.calls) {
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) t.Fatalf("Expected %d tool calls, got %d", len(tt.calls), len(calls))
}
for i, want := range tt.calls {
if diff := cmp.Diff(calls[i], want); diff != "" {
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
}
} }
}) })
} }
} }
func TestParseToolCalls(t *testing.T) { func TestDone(t *testing.T) {
p := filepath.Join("testdata") tests := []struct {
t1 := api.ToolCall{ name string
Function: api.ToolCallFunction{ tag string
Name: "get_current_weather", buffer []byte
Arguments: api.ToolCallFunctionArguments{ want bool
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
}
t2 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
}
cases := []struct {
name string
model string
output string
expectedToolCall []api.ToolCall
expectedTokens string
}{ }{
{ {
name: "mistral malformed json with tool calls prefix", name: "empty",
model: "mistral", tag: "<tool_call>",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, buffer: []byte{},
expectedToolCall: []api.ToolCall{t1}, want: false,
expectedTokens: "",
}, },
{ {
name: "mistral multiple tool calls without prefix", name: "empty",
model: "mistral", tag: "<tool_call>",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, buffer: []byte{},
expectedToolCall: []api.ToolCall{t1, t2}, want: false,
expectedTokens: "",
}, },
{ {
name: "mistral tool calls with text between no prefix", name: "json open",
model: "mistral", tag: "{",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] buffer: []byte("{\"name\": \"get_weather\""),
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, want: false,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
}, },
{ {
name: "mistral valid json with tool calls prefix", name: "json closed",
model: "mistral", tag: "{",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, buffer: []byte("{\"name\": \"get_weather\"}"),
expectedToolCall: []api.ToolCall{t1, t2}, want: true,
expectedTokens: "",
}, },
{ {
name: "mistral multiple tool calls with text between and prefix", name: "json empty",
model: "mistral", tag: "{",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] buffer: []byte("{}"),
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, want: true,
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
expectedTokens: "",
}, },
{ {
name: "mistral incomplete json with tool calls prefix", name: "list open",
model: "mistral", tag: "[",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, buffer: []byte("[{\"name\": \"get_weather\""),
expectedToolCall: []api.ToolCall{}, want: false,
expectedTokens: "",
}, },
{ {
name: "mistral invalid tool call with explanatory text no prefix", name: "list closed",
model: "mistral", tag: "[",
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: buffer: []byte("[{\"name\": \"get_weather\"}]"),
want: true,
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "mistral tool calls without prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "command r plus tool calls with json block format",
model: "command-r-plus",
output: "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```",
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
}, },
{ {
name: "firefunction tool calls with functools prefix", name: "list empty",
model: "firefunction", tag: "[",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, buffer: []byte("[]"),
expectedToolCall: []api.ToolCall{t1, t2}, want: true,
expectedTokens: "",
},
{
name: "llama3 groq single tool call with xml tags",
model: "llama3-groq-tool-use",
output: `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
{
name: "xlam tool calls with wrapper object",
model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
}, },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Parser{
tag: tt.tag,
buffer: tt.buffer,
}
got := parser.done()
if got != tt.want {
t.Errorf("done() = %t, want %t", got, tt.want)
}
})
}
}
func TestContent(t *testing.T) {
tests := []struct {
name string
tag string
content []byte
want string
n int
}{
{ {
name: "qwen2.5 single tool call with prefix", name: "empty",
model: "qwen2.5", content: []byte{},
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, tag: "{",
expectedToolCall: []api.ToolCall{t1}, want: "",
expectedTokens: "", n: 0,
}, },
{ {
name: "qwen2.5 multiple tool calls with and without prefix", name: "tag",
model: "qwen2.5", tag: "<tool_call>",
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`, content: []byte("<tool_call>{\"name\": \"get_temperature\""),
expectedToolCall: []api.ToolCall{t1, t1, t2}, want: "",
expectedTokens: "", n: 0,
}, },
{ {
name: "qwen2.5 plain text response no tool calls", name: "json object",
model: "qwen2.5", tag: "{",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", content: []byte("{\"name\": \"get_temperature\"}"),
expectedToolCall: []api.ToolCall{}, want: "{\"name\": \"get_temperature\"}",
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", n: 0,
}, },
{ {
name: "qwen2.5 tool calls with trailing text", name: "json object after called",
model: "qwen2.5", tag: "{",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, content: []byte("{\"hello\": \"world\"}"),
expectedToolCall: []api.ToolCall{t1, t2}, want: "{\"hello\": \"world\"}",
expectedTokens: "some tokens after call", n: 0,
}, },
{ {
name: "qwen2.5 tool calls with initial text", name: "json object after called",
model: "qwen2.5", tag: "{",
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, content: []byte("{\"hello\": \"world\"}"),
expectedToolCall: []api.ToolCall{}, want: "",
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, n: 1,
}, },
{ {
name: "qwen2.5 tool calls with prefix and trailing text", name: "list",
model: "qwen2.5", tag: "[",
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`, content: []byte("[{\"name\": \"get_temperature\"}]"),
expectedToolCall: []api.ToolCall{t1, t2}, want: "[{\"name\": \"get_temperature\"}]",
expectedTokens: "", n: 0,
}, },
{ {
name: "qwen2.5 tool calls with prefix and initial text", name: "code",
model: "qwen2.5", tag: "{",
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`, content: []byte("{ fmt.Println(\"hello\")"),
expectedToolCall: []api.ToolCall{t1, t2}, want: "{ fmt.Println(\"hello\")",
expectedTokens: "some tokens before call", n: 0,
}, },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Parser{
tag: tt.tag,
buffer: tt.content,
n: tt.n,
}
got := parser.Content()
if got != tt.want {
t.Errorf("Content() = %q, want %q", got, tt.want)
}
})
}
}
func TestFindTag(t *testing.T) {
cases := []struct {
name string
buffer []byte
tag string
i int
found bool
}{
{ {
name: "qwen2.5 tool calls without and with prefix", name: "no overlap",
model: "qwen2.5", buffer: []byte("hello world"),
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{t1, t2}, i: -1,
expectedTokens: "", found: false,
}, },
{ {
name: "qwen2.5 tool calls without and with prefix and text between", name: "full overlap",
model: "qwen2.5", buffer: []byte("<tool_call>"),
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call> some tokens after call`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{t1, t2}, i: 0,
expectedTokens: "some tokens between", found: true,
}, },
{ {
name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", name: "whitespace",
model: "qwen2.5", buffer: []byte(" <tool_call>\n {\"name\": \"bob\"}"),
output: `hi [{"options": "foo"}]`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{}, i: 4,
expectedTokens: `hi [{"options": "foo"}]`, found: true,
}, },
{ {
name: "qwen2.5 tool calls with prefix and invalid tool call", name: "over",
model: "qwen2.5", buffer: []byte("<tool_call>{\"name\""),
output: `<tool_call> [{"options": "foo"}] </tool_call> `, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{}, i: 0,
expectedTokens: ``, found: true,
}, },
{ {
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", name: "partial overlap",
model: "qwen3", buffer: []byte("text <tool_call>"),
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{t1}, i: 5,
expectedTokens: "<think>Okay, let me think what tool we should use...</think>", found: true,
}, },
{ {
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", name: "overlap with extra",
model: "qwen3", buffer: []byte("<tool_calls><tool_call>"),
output: `<think>Okay, let me think what tool we should use...</think> <tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, tag: "<tool_calls>",
expectedToolCall: []api.ToolCall{t1}, i: 0,
expectedTokens: "<think>Okay, let me think what tool we should use...</think>", found: true,
}, },
{ {
name: "qwen3 empty think prefix without tool prefix and invalid tool call", name: "delimiter longer than string",
model: "qwen3", buffer: []byte("<tool>"),
output: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{}, i: -1,
expectedTokens: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, found: false,
}, },
{ {
name: "qwen3 empty think prefix with tool prefix and valid tool call", name: "empty string",
model: "qwen3", buffer: []byte{},
output: `<think></think><tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{t1}, i: -1,
expectedTokens: `<think></think>`, found: false,
}, },
{ {
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", name: "single char overlap",
model: "qwen3", buffer: []byte("test<"),
output: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{}, i: 4,
expectedTokens: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, found: false,
}, },
{ {
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", name: "partial tool call",
model: "qwen3", buffer: []byte("hello <tool_"),
output: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, tag: "<tool_call>",
expectedToolCall: []api.ToolCall{}, i: 6,
expectedTokens: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, found: false,
}, },
{ {
name: "qwen3 invalid tool call with malformed tool prefix", name: "square bracket",
model: "qwen3", buffer: []byte("calling tools: ["),
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, tag: "[",
expectedToolCall: []api.ToolCall{}, i: 15,
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`, found: true,
}, },
{ {
name: "model with prefix in template, no prefix in output", name: "bracket",
model: "qwen2.5", buffer: []byte("{\"name\": \"bob\""),
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, tag: "{",
expectedToolCall: []api.ToolCall{t1, t2}, i: 0,
expectedTokens: "", found: true,
}, },
{ {
name: "model with prefix in template, prefix in output", name: "bracket with whitespace",
model: "qwen2.5", buffer: []byte("\n\n{\n\"name\": \"bob\""),
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`, tag: "{",
expectedToolCall: []api.ToolCall{t1, t2}, i: 2,
expectedTokens: "", found: true,
}, },
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
parser := &Parser{
tag: tt.tag,
buffer: tt.buffer,
n: 0,
}
i, found := parser.findTag()
if i != tt.i {
t.Errorf("findTag(%q, %q) = %d; want %d", tt.buffer, tt.tag, i, tt.i)
}
if found != tt.found {
t.Errorf("findTag(%q, %q) = %t; want %t", tt.buffer, tt.tag, found, tt.found)
}
})
}
}
func TestFindArguments(t *testing.T) {
tests := []struct {
name string
buffer []byte
want map[string]any
}{
{ {
name: "model without prefix in template, no prefix in output", name: "empty string",
model: "llama3.2", buffer: []byte{},
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, want: nil,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
}, },
{ {
name: "model without prefix in template, no prefix in output, single tool call", name: "whitespace only",
model: "llama3.2", buffer: []byte(" \n\t "),
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, want: nil,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
}, },
{ {
name: "model without prefix in template, prefix in output, multiple tool calls in list", name: "unbalanced braces - missing closing",
model: "llama3.2", buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`, want: nil,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `<tool_call>`,
}, },
{ {
name: "model without prefix in template, prefix in output, individual tool calls", name: "unbalanced braces - extra closing",
model: "llama3.2", buffer: []byte(`{"format": "fahrenheit"}}`),
output: `<tool_call> {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`, want: map[string]any{
expectedToolCall: []api.ToolCall{t1, t2}, "format": "fahrenheit",
expectedTokens: `<tool_call>`, },
}, },
{ {
name: "model with prefix in template, no prefix in output, tokens before", name: "invalid JSON",
model: "qwen2.5", buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, want: nil,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
}, },
{ {
name: "model with prefix in template, prefix in output, tokens after", name: "valid json",
model: "qwen2.5", buffer: []byte(`{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, want: map[string]any{
expectedToolCall: []api.ToolCall{t1, t2}, "format": "fahrenheit",
expectedTokens: "", "location": "San Francisco, CA",
},
}, },
{ {
name: "model without prefix in template, no prefix in output, tokens after", name: "valid arguments with special tokens",
model: "llama3.2", buffer: []byte(`[tool]get_temperature[args]{"format": "fahrenheit", "location": "San Francisco, CA"}[end]`),
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, want: map[string]any{
expectedToolCall: []api.ToolCall{t1, t2}, "format": "fahrenheit",
expectedTokens: "", "location": "San Francisco, CA",
},
}, },
{ {
name: "model without prefix in template, no prefix in output, tokens before", name: "valid arguments in array",
model: "llama3.2", buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, want: map[string]any{
expectedToolCall: []api.ToolCall{t1, t2}, "format": "fahrenheit",
expectedTokens: `some tokens before`, "location": "San Francisco, CA",
},
}, },
{ {
name: "model without prefix in template, prefix in output, tokens after", name: "nested deep",
model: "llama3.2", buffer: []byte(`{"function": {"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}}`),
output: `<tool_call> want: map[string]any{
[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, "format": "fahrenheit",
expectedToolCall: []api.ToolCall{t1, t2}, "location": "San Francisco, CA",
expectedTokens: `<tool_call>`, },
}, },
{ {
name: "model without without prefix, match all jsons", name: "one arg",
model: "llama3.2", buffer: []byte(`get_weather({"location": "San Francisco, CA"})`),
output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`, want: map[string]any{
expectedToolCall: []api.ToolCall{t1, t2}, "location": "San Francisco, CA",
expectedTokens: "model outputs some text", },
}, },
{ {
name: "model flushes tokens if tool call doesn't match", name: "two args",
model: "llama3.2", buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, want: map[string]any{
expectedToolCall: []api.ToolCall{}, "location": "San Francisco, CA",
expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, "format": "fahrenheit",
},
}, },
{ {
name: "model flushes tokens if tool call doesn't match array", name: "deepseek",
model: "llama3.2", buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, want: map[string]any{
expectedToolCall: []api.ToolCall{}, "location": "Tokyo",
expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, },
}, },
} }
var tools []api.Tool for _, tt := range tests {
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { parser := &Parser{
t.Fatal(err) buffer: tt.buffer,
} properties: []string{"format", "location"},
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) got, _ := parser.findArguments()
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) { if diff := cmp.Diff(got, tt.want); diff != "" {
actual := &bytes.Buffer{} // Create new buffer for each test t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { }
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
tp, err := NewParser(tmpl.Template)
if err != nil {
t.Fatal(err)
}
got := []api.ToolCall{}
var gotTokens strings.Builder
tokens := strings.Fields(tt.output)
for _, tok := range tokens {
s := " " + tok
toolCalls, content := tp.Add(s)
if len(content) > 0 {
gotTokens.WriteString(content)
} else if len(toolCalls) > 0 {
got = append(got, toolCalls...)
}
}
// Compare tool calls if we expect any
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
}
// Compare tokens if we expect any
stripped := strings.TrimSpace(gotTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
}
})
}) })
} }
} }
package tools
import (
"bytes"
"encoding/json"
"errors"
"log/slog"
"slices"
"strings"
gotmpl "text/template"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
//
// Returns:
// - string: The extracted text following the first ".ToolCalls" condition found
// - bool: Whether a ".ToolCalls" condition was found in the template
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("template or tree is nil")
return "", false
}
var result string
var found bool
var walk func(nodes []parse.Node)
walk = func(nodes []parse.Node) {
for _, node := range nodes {
if found {
return
}
switch n := node.(type) {
case *parse.IfNode:
if isToolCallsNode(n) {
// Collect immediate TextNode(s) at start of IfNode's list
var sb strings.Builder
for _, innerNode := range n.List.Nodes {
if tn, ok := innerNode.(*parse.TextNode); ok {
sb.Write(tn.Text)
} else {
// Stop at first non-text node
break
}
}
result = sb.String()
found = true
return
}
// Recurse into child nodes
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.ListNode:
walk(n.Nodes)
case *parse.RangeNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.WithNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
default:
// Continue to next node
continue
}
}
}
walk(tmpl.Tree.Root.Nodes)
return result, found
}
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
func isToolCallsNode(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
if slices.Contains(field.Ident, "ToolCalls") {
return true
}
}
}
}
return false
}
func toolPrefix(tmpl *gotmpl.Template) string {
tokenText, ok := extractToolCallsFormat(tmpl)
if !ok {
return ""
}
tokenText = strings.TrimSpace(tokenText)
tokenText = strings.ReplaceAll(tokenText, "\r", "")
tokenText = strings.ReplaceAll(tokenText, "\n", " ")
return tokenText
}
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
//
// Returns:
// - *gotmpl.Template: The subtree containing the .ToolCalls range
// - error: Error if parsing failed
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
tmpl := t.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, errors.New("failed to find tool template")
}
return tmpl, nil
}
// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
//
// Returns:
// - int: The starting index in s where the suffix overlap begins
func suffixOverlap(s, prefix string) int {
max := min(len(prefix), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, prefix[:i]) {
return len(s) - i
}
}
return -1
}
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
//
// Returns:
// - string: The name of the tool call
// - string: The arguments of the tool call
// - error: Error if parsing failed
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return "", "", err
}
// Extract JSON object between curly braces
// JSON arrays are also valid as they will not be repeated in the template
output := b.String()
start := strings.Index(output, "{")
end := strings.LastIndex(output, "}")
if start == -1 || end == -1 || start > end {
return "", "", errors.New("no valid JSON object found in template output")
}
jsonStr := output[start : end+1]
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return "", "", err
}
// Find name and arguments fields
for k, v := range obj {
if str, ok := v.(string); ok && str == "@@name@@" {
name = k
} else if _, ok := v.(map[string]any); ok {
arguments = k
}
}
if name == "" || arguments == "" {
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
return "", "", errors.New("missing required fields in tool call template")
}
return name, arguments, nil
}
// collect recursively traverses an object to collect all nested maps
//
// Returns:
// - []map[string]any: A slice of all nested maps found in the object
func collect(obj any) []map[string]any {
var all []map[string]any
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
default:
return nil
}
return all
}
package tools
import (
"testing"
gotmpl "text/template"
"github.com/ollama/ollama/template"
)
func TestExtractToolCallsFormat(t *testing.T) {
cases := []struct {
name string
template string
want string
found bool
}{
{
name: "nil template",
template: "",
want: "",
found: false,
},
{
name: "basic tool call with text",
template: "{{if .ToolCalls}}Hello world{{end}}",
want: "Hello world",
found: true,
},
{
name: "tool call with json format",
template: "{{if .ToolCalls}}```json\n{{end}}",
want: "```json\n",
found: true,
},
{
name: "tool call in range",
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
want: "",
found: false,
},
{
name: "tool call with multiple text nodes",
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
want: "First text",
found: true,
},
{
name: "nested if without tool calls",
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
want: "",
found: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tc.template)
if err != nil && tc.template != "" {
t.Fatalf("failed to parse template: %v", err)
}
got, found := extractToolCallsFormat(tmpl)
if got != tc.want {
t.Errorf("got text %q, want %q", got, tc.want)
}
if found != tc.found {
t.Errorf("got found %v, want %v", found, tc.found)
}
})
}
}
func TestToolPrefix(t *testing.T) {
cases := []struct {
name string
template string
want string
}{
{
name: "basic tool call with action prefix",
template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action: ```json",
},
{
name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools[",
},
{
name: "tool call with angle brackets",
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
want: "Hello, world! <tool_call>",
},
{
name: "multiple tool call formats",
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
want: "[tool_call] <tool_call>",
},
{
name: "single angle bracket tool call",
template: "{{if .ToolCalls}}<tool_call>{{end}}",
want: "<tool_call>",
},
{
name: "incomplete angle bracket after tool call",
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
want: "[tool_call] <",
},
{
name: "angle bracket prefix with tool call",
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
want: "> <tool_call>",
},
{
name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL] [",
},
{
name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL][",
},
{
name: "tool call with pipe delimiters",
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
want: "<|tool_call|>",
},
{
name: "tool with no prefix",
template: "{{if .ToolCalls}}{{end}}",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got := toolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}
})
}
}
func TestToolTemplate(t *testing.T) {
cases := []struct {
name string
template string
want bool
}{
{
name: "basic tool call range",
template: "{{range .ToolCalls}}test{{end}}",
want: true,
},
{
name: "no tool calls",
template: "{{range .Other}}test{{end}}",
want: false,
},
{
name: "nested tool calls",
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
want: true,
},
{
name: "empty template",
template: "",
want: false,
},
{
name: "tool calls in if statement",
template: "{{if .ToolCalls}}test{{end}}",
want: false,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
parsed, err := template.Parse(tmpl.Root.String())
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
_, err = toolTemplate(parsed)
if err != nil && tt.want {
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
}
})
}
}
func TestSuffixOverlap(t *testing.T) {
cases := []struct {
name string
s string
d string
want int
}{
{
name: "no overlap",
s: "hello world",
d: "<tool_call>",
want: -1,
},
{
name: "full overlap",
s: "<tool_call>",
d: "<tool_call>",
want: 0,
},
{
name: "partial overlap",
s: "text <tool_call>",
d: "<tool_call>",
want: 5,
},
{
name: "delimiter longer than string",
s: "<tool>",
d: "<tool_call>",
want: -1,
},
{
name: "empty string",
s: "",
d: "<tool_call>",
want: -1,
},
{
name: "empty delimiter",
s: "<tool_call>",
d: "",
want: -1,
},
{
name: "single char overlap",
s: "test<",
d: "<tool_call>",
want: 4,
},
{
name: "partial tool call",
s: "hello <tool_",
d: "<tool_call>",
want: 6,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := suffixOverlap(tt.s, tt.d)
if got != tt.want {
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
}
})
}
}
func TestExtractToolArgs(t *testing.T) {
cases := []struct {
name string
template string
wantName string
wantArgs string
wantErr bool
}{
{
name: "basic tool call",
template: `{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`,
wantName: "name",
wantArgs: "parameters",
wantErr: false,
},
{
name: "tool call with whitespace",
template: `{{range .ToolCalls}}
{"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}}
{{end}}`,
wantName: "name",
wantArgs: "parameters",
wantErr: false,
},
{
name: "tool call with extra content",
template: `Before {{range .ToolCalls}}
{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`,
wantName: "name",
wantArgs: "arguments",
wantErr: false,
},
{
name: "no tool calls",
template: `{{if .Something}}no tools here{{end}}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "empty template",
template: ``,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "prefix within tool call",
template: `{{- if .ToolCalls }}
{{ range .ToolCalls }}
<tool_call>
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
</tool_call>{{ end }}{{- end }}`,
wantName: "name",
wantArgs: "arguments",
wantErr: false,
},
{
name: "JSON array",
template: `{{ range .ToolCalls }}
[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`,
wantName: "name",
wantArgs: "arguments",
wantErr: false,
},
{
name: "invalid JSON",
template: `{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "missing name field",
template: `{{ range .ToolCalls }}
{"parameters": {{ .Function.Arguments }}}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "missing arguments field",
template: `{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}"}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "malformed JSON",
template: `{{ range .ToolCalls }}
{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
gotName, gotArgs, err := extractToolArgs(tmpl)
if (err != nil) != tt.wantErr {
t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
return
}
if gotName != tt.wantName {
t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName)
}
if gotArgs != tt.wantArgs {
t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs)
}
})
}
}
func TestCollect(t *testing.T) {
cases := []struct {
name string
obj any
want []map[string]any
}{
{
name: "simple map",
obj: map[string]any{
"key": "value",
},
want: []map[string]any{
{"key": "value"},
},
},
{
name: "nested map",
obj: map[string]any{
"outer": map[string]any{
"inner": "value",
},
},
want: []map[string]any{
{"outer": map[string]any{"inner": "value"}},
{"inner": "value"},
},
},
{
name: "array of maps",
obj: []any{
map[string]any{"key1": "val1"},
map[string]any{"key2": "val2"},
},
want: []map[string]any{
{"key1": "val1"},
{"key2": "val2"},
},
},
{
name: "deeply nested",
obj: map[string]any{
"l1": map[string]any{
"l2": map[string]any{
"l3": "value",
},
},
},
want: []map[string]any{
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
{"l2": map[string]any{"l3": "value"}},
{"l3": "value"},
},
},
{
name: "non-map value",
obj: "string",
want: nil,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := collect(tt.obj)
if len(got) != len(tt.want) {
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
return
}
// Compare each map in the result
for i := range tt.want {
if !mapsEqual(got[i], tt.want[i]) {
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
}
}
})
}
}
// mapsEqual compares two maps for deep equality
func mapsEqual(m1, m2 map[string]any) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
v2, ok := m2[k]
if !ok {
return false
}
switch val1 := v1.(type) {
case map[string]any:
val2, ok := v2.(map[string]any)
if !ok || !mapsEqual(val1, val2) {
return false
}
default:
if v1 != v2 {
return false
}
}
}
return true
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment