Unverified Commit 9f8a18ec authored by Jeffrey Morgan's avatar Jeffrey Morgan Committed by GitHub
Browse files

tools: loosen tool parsing to allow for more formats (#11030)

parent 6b04cad7
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the user's location."
}
},
"required": [
"location",
"format"
]
}
}
}
]
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]
{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:
\ No newline at end of file
You are a knowledgeable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]
### Response:
\ No newline at end of file
package tools package tools
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors"
"log/slog"
"strings" "strings"
gotmpl "text/template" "text/template"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
var ( type toolsState int
errInvalidToolCall = errors.New("invalid tool call format")
errAccumulateMore = errors.New("need to accumulate more content") const (
toolsState_LookingForTag toolsState = iota
toolsState_ToolCalling
toolsState_Done
) )
type Parser struct { type Parser struct {
greedyParseJSON bool tag string
prefix string names []string
prefixFound bool properties []string
tmpl gotmpl.Template
sb strings.Builder state toolsState
index int buffer []byte
name string n int
arguments string
} }
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // NewParser creates a new tool call parser from a model's chat
// // template and a list of provided tools.
// Parameters: func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
// - s: The string to parse return NewParserWithTag(tools, parseTag(tmpl))
// - name: The field name from template that identifies the tool call name }
// - arguments: The field name from template that identifies the tool call arguments
// func NewParserWithTag(tools []api.Tool, tag string) *Parser {
// Returns: var p Parser
// - []api.ToolCall: The parsed tool calls if successful for _, t := range tools {
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful p.names = append(p.names, t.Function.Name)
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { for r := range t.Function.Parameters.Properties {
// Check for balanced braces before attempting to parse p.properties = append(p.properties, r)
braceCount := 0 }
squareCount := 0 }
startIndex := -1 p.tag = tag
var rawToolCalls []string return &p
s = strings.TrimSpace(s) }
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case. // Add processes a string input to parse tool calls and content that
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") // should be sent back to the user.
for i, c := range s { func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
switch c { if p.state == toolsState_Done {
case '{': return nil, s
braceCount++ }
if startIndex == -1 {
startIndex = i p.buffer = append(p.buffer, s...)
}
case '}': if p.state == toolsState_LookingForTag {
braceCount-- i, found := p.findTag()
if braceCount == 0 { if i == -1 {
rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) content = string(p.buffer)
startIndex = -1 p.buffer = []byte{}
}
case '[':
if trackSquareBrackets {
squareCount++
}
case ']':
if trackSquareBrackets {
squareCount--
}
}
// Negative means we have an extra closing brace/bracket
if braceCount < 0 || squareCount < 0 {
return nil, errInvalidToolCall
}
}
// If braces/brackets aren't balanced, need more input
if braceCount > 0 || squareCount > 0 {
return nil, errAccumulateMore
}
t := strings.TrimSpace(s)
if len(t) == 0 {
return nil, errAccumulateMore
}
// If the input is a single square bracket, it's not a valid tool call
if t[0] == '[' && len(t) == 1 {
return nil, errAccumulateMore
}
// Attempt full unmarshal of the JSON
var toolCalls []api.ToolCall
for _, rawToolCall := range rawToolCalls {
var resp map[string]any
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
continue
}
// Collect nested objects that could contain tool calls
objs := collect(resp)
if len(objs) == 0 {
continue
}
// Extract tool calls from objects
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
} else { } else {
slog.Debug("No valid tool call found in object.", "object", kv) content = string(p.buffer[:i])
p.buffer = p.buffer[i:]
}
// for models where { or [ are used as tool calling
// tags, we only support parsing tools if the first non-
// whitespace character is { or [
if p.tag == "{" || p.tag == "[" {
if strings.TrimSpace(content) != "" {
p.state = toolsState_Done
return nil, content + string(p.buffer)
}
}
if !found {
return nil, content
} }
p.state = toolsState_ToolCalling
}
for {
call := p.parseToolCall()
if call == nil {
break
} }
calls = append(calls, *call)
} }
// Valid JSON, no tool calls found if p.done() {
if len(toolCalls) == 0 { p.state = toolsState_Done
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) content = string(p.buffer)
return nil, errInvalidToolCall p.buffer = []byte{}
} }
return toolCalls, nil return calls, content
} }
// checkPrefix processes a string to find and handle a prefix pattern. // findTag searches the buffer to find and handle a tool calling tag
// // returning true if the tag was found and false otherwise, and
// Returns: // a string content signaling any content that should be sent back to the user
// - The processed string with prefix removed if found func (p *Parser) findTag() (int, bool) {
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful // First check for complete substring anywhere in s
func (p *Parser) checkPrefix(s string) (string, error) { if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 {
if s == "" || p.prefix == "" { return i, true
return s, nil }
// Then check for partial suffix overlap
max := min(len(p.buffer), len(p.tag))
for i := max; i > 0; i-- {
if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) {
return len(p.buffer) - i, false
}
}
return -1, false
}
// parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall {
var name string
var args map[string]any
var end int = len(p.buffer)
// find tool name
var i int
for _, n := range p.names {
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end {
name = n
end = i + len(n)
}
}
}
if name == "" {
return nil
} }
// Check for prefix at start of string if args, i = p.findArguments(); args == nil {
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { return nil
// Found prefix at start - accumulate for potential tool
p.prefixFound = true
return cut, nil
} }
// Check if prefix overlaps end of string if i > end {
if idx := suffixOverlap(s, p.prefix); idx != -1 { end = i
// Return everything except overlapping portion
p.sb.Reset()
p.sb.WriteString(s[idx:])
return s[:idx], errAccumulateMore
} }
// Check if prefix appears in middle of string tc := &api.ToolCall{
if idx := strings.Index(s, p.prefix); idx != -1 { Function: api.ToolCallFunction{
// Save remainder starting at prefix for next pass Name: name,
p.sb.Reset() Arguments: args,
p.sb.WriteString(strings.TrimSpace(s[idx:])) Index: p.n,
// Return everything before prefix },
return s[:idx], errAccumulateMore
} }
// No partial prefix found p.n++
return s, nil p.buffer = p.buffer[end:]
return tc
} }
// Add processes a string input to parse tool calls and content. // findArguments returns the first object that appears to be
// It handles prefix detection and JSON parsing to extract tool calls. // arguments and the position where the arguments end, returning nil and 0 if
// // an invalid JSON object or non-arguments object is found first
// Returns: func (p *Parser) findArguments() (map[string]any, int) {
// - tools: Any parsed tool calls if len(p.buffer) == 0 {
// - content: Non-tool call content return nil, 0
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
p.sb.WriteString(s)
s = p.sb.String()
// Check for prefix pattern in input
s, err := p.checkPrefix(s)
if err != nil {
// Need more input to complete prefix
return nil, s
} }
// Exit if prefix exists in template, greedy parsing is off, and prefix not found var braces int
if !p.greedyParseJSON && !p.prefixFound { var start int = -1
p.sb.Reset() var end int
return nil, s var object []byte
// find any outer json object
for i, c := range p.buffer {
if c == '{' {
braces++
if start == -1 {
start = i
}
} }
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) if c == '}' {
if err != nil { braces--
if errors.Is(err, errAccumulateMore) { if braces == 0 && start != -1 {
return nil, "" end = i + 1
object = p.buffer[start:end]
break
} }
p.sb.Reset()
// Only do greedy JSON parsing if there is no prefix from template
if p.prefix != "" {
p.greedyParseJSON = false
} }
if p.index != 0 && p.prefix == "" {
return nil, ""
} }
if p.prefixFound {
// Drop tokens since prefix was found if braces > 0 {
return nil, "" return nil, 0
}
var data map[string]any
// not valid json
if err := json.Unmarshal(object, &data); err != nil {
return nil, 0
}
var find func(obj any) map[string]any
find = func(obj any) map[string]any {
switch v := obj.(type) {
case map[string]any:
// check if the object keys are valid tool properties
// TODO (jmorganca): check only sets of properties that
// go together instead of the entire set
for _, prop := range p.properties {
if _, exists := v[prop]; exists {
return v
}
}
for _, value := range v {
if result := find(value); result != nil {
return result
}
}
case []any:
for _, item := range v {
if result := find(item); result != nil {
return result
}
} }
return nil, s
} }
for _, tc := range toolCalls { return nil
tc.Function.Index = p.index
p.index++
} }
p.sb.Reset() result := find(data)
return toolCalls, "" if result != nil {
return result, end
}
return nil, 0
} }
// NewParser creates a new tool call parser from a template. It extracts the tool call format, // done checks if the parser is done parsing by looking
// prefix, and field names from the template to use for parsing tool calls from model output. // for closing tag. currently only } and ] are supported
// // for closing tags as {} or [] pairs may not always
// Returns an error if the template does not contain valid tool call formatting. // represent tool calls and we need to send the content back
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { func (p *Parser) done() bool {
parsed, err := template.Parse(templateToProcess.Root.String()) var open, close rune
if err != nil { switch p.tag {
return nil, err case "{":
open, close = '{', '}'
case "[":
open, close = '[', ']'
default:
return false
} }
tt, err := toolTemplate(parsed) var count int
if err != nil { for _, c := range p.buffer {
return nil, err if c == byte(open) {
count++
} else if c == byte(close) {
count--
if count == 0 {
return true
}
}
} }
tp := toolPrefix(templateToProcess) return false
}
// Content returns any remaining content that
// should be sent to the user. This should be the empty string
// string unless the tag is { or [ and a tool call was not found
func (p *Parser) Content() string {
if p.n > 0 {
return ""
}
name, arguments, err := extractToolArgs(tt) if p.tag == "{" || p.tag == "[" {
if err != nil { return string(p.buffer)
return nil, err
} }
return &Parser{ return ""
tmpl: *tt,
sb: strings.Builder{},
prefix: tp,
greedyParseJSON: true,
name: name,
arguments: arguments,
}, nil
} }
This diff is collapsed.
package tools
import (
"bytes"
"encoding/json"
"errors"
"log/slog"
"slices"
"strings"
gotmpl "text/template"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
//
// Returns:
// - string: The extracted text following the first ".ToolCalls" condition found
// - bool: Whether a ".ToolCalls" condition was found in the template
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("template or tree is nil")
return "", false
}
var result string
var found bool
var walk func(nodes []parse.Node)
walk = func(nodes []parse.Node) {
for _, node := range nodes {
if found {
return
}
switch n := node.(type) {
case *parse.IfNode:
if isToolCallsNode(n) {
// Collect immediate TextNode(s) at start of IfNode's list
var sb strings.Builder
for _, innerNode := range n.List.Nodes {
if tn, ok := innerNode.(*parse.TextNode); ok {
sb.Write(tn.Text)
} else {
// Stop at first non-text node
break
}
}
result = sb.String()
found = true
return
}
// Recurse into child nodes
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.ListNode:
walk(n.Nodes)
case *parse.RangeNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.WithNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
default:
// Continue to next node
continue
}
}
}
walk(tmpl.Tree.Root.Nodes)
return result, found
}
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
func isToolCallsNode(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
if slices.Contains(field.Ident, "ToolCalls") {
return true
}
}
}
}
return false
}
func toolPrefix(tmpl *gotmpl.Template) string {
tokenText, ok := extractToolCallsFormat(tmpl)
if !ok {
return ""
}
tokenText = strings.TrimSpace(tokenText)
tokenText = strings.ReplaceAll(tokenText, "\r", "")
tokenText = strings.ReplaceAll(tokenText, "\n", " ")
return tokenText
}
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
//
// Returns:
// - *gotmpl.Template: The subtree containing the .ToolCalls range
// - error: Error if parsing failed
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
tmpl := t.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, errors.New("failed to find tool template")
}
return tmpl, nil
}
// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
//
// Returns:
// - int: The starting index in s where the suffix overlap begins
func suffixOverlap(s, prefix string) int {
max := min(len(prefix), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, prefix[:i]) {
return len(s) - i
}
}
return -1
}
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
//
// Returns:
// - string: The name of the tool call
// - string: The arguments of the tool call
// - error: Error if parsing failed
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return "", "", err
}
// Extract JSON object between curly braces
// JSON arrays are also valid as they will not be repeated in the template
output := b.String()
start := strings.Index(output, "{")
end := strings.LastIndex(output, "}")
if start == -1 || end == -1 || start > end {
return "", "", errors.New("no valid JSON object found in template output")
}
jsonStr := output[start : end+1]
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return "", "", err
}
// Find name and arguments fields
for k, v := range obj {
if str, ok := v.(string); ok && str == "@@name@@" {
name = k
} else if _, ok := v.(map[string]any); ok {
arguments = k
}
}
if name == "" || arguments == "" {
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
return "", "", errors.New("missing required fields in tool call template")
}
return name, arguments, nil
}
// collect recursively traverses an object to collect all nested maps
//
// Returns:
// - []map[string]any: A slice of all nested maps found in the object
func collect(obj any) []map[string]any {
var all []map[string]any
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
default:
return nil
}
return all
}
package tools
import (
"testing"
gotmpl "text/template"
"github.com/ollama/ollama/template"
)
func TestExtractToolCallsFormat(t *testing.T) {
cases := []struct {
name string
template string
want string
found bool
}{
{
name: "nil template",
template: "",
want: "",
found: false,
},
{
name: "basic tool call with text",
template: "{{if .ToolCalls}}Hello world{{end}}",
want: "Hello world",
found: true,
},
{
name: "tool call with json format",
template: "{{if .ToolCalls}}```json\n{{end}}",
want: "```json\n",
found: true,
},
{
name: "tool call in range",
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
want: "",
found: false,
},
{
name: "tool call with multiple text nodes",
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
want: "First text",
found: true,
},
{
name: "nested if without tool calls",
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
want: "",
found: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tc.template)
if err != nil && tc.template != "" {
t.Fatalf("failed to parse template: %v", err)
}
got, found := extractToolCallsFormat(tmpl)
if got != tc.want {
t.Errorf("got text %q, want %q", got, tc.want)
}
if found != tc.found {
t.Errorf("got found %v, want %v", found, tc.found)
}
})
}
}
func TestToolPrefix(t *testing.T) {
cases := []struct {
name string
template string
want string
}{
{
name: "basic tool call with action prefix",
template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action: ```json",
},
{
name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools[",
},
{
name: "tool call with angle brackets",
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
want: "Hello, world! <tool_call>",
},
{
name: "multiple tool call formats",
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
want: "[tool_call] <tool_call>",
},
{
name: "single angle bracket tool call",
template: "{{if .ToolCalls}}<tool_call>{{end}}",
want: "<tool_call>",
},
{
name: "incomplete angle bracket after tool call",
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
want: "[tool_call] <",
},
{
name: "angle bracket prefix with tool call",
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
want: "> <tool_call>",
},
{
name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL] [",
},
{
name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL][",
},
{
name: "tool call with pipe delimiters",
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
want: "<|tool_call|>",
},
{
name: "tool with no prefix",
template: "{{if .ToolCalls}}{{end}}",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got := toolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}
})
}
}
func TestToolTemplate(t *testing.T) {
cases := []struct {
name string
template string
want bool
}{
{
name: "basic tool call range",
template: "{{range .ToolCalls}}test{{end}}",
want: true,
},
{
name: "no tool calls",
template: "{{range .Other}}test{{end}}",
want: false,
},
{
name: "nested tool calls",
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
want: true,
},
{
name: "empty template",
template: "",
want: false,
},
{
name: "tool calls in if statement",
template: "{{if .ToolCalls}}test{{end}}",
want: false,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
parsed, err := template.Parse(tmpl.Root.String())
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
_, err = toolTemplate(parsed)
if err != nil && tt.want {
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
}
})
}
}
func TestSuffixOverlap(t *testing.T) {
cases := []struct {
name string
s string
d string
want int
}{
{
name: "no overlap",
s: "hello world",
d: "<tool_call>",
want: -1,
},
{
name: "full overlap",
s: "<tool_call>",
d: "<tool_call>",
want: 0,
},
{
name: "partial overlap",
s: "text <tool_call>",
d: "<tool_call>",
want: 5,
},
{
name: "delimiter longer than string",
s: "<tool>",
d: "<tool_call>",
want: -1,
},
{
name: "empty string",
s: "",
d: "<tool_call>",
want: -1,
},
{
name: "empty delimiter",
s: "<tool_call>",
d: "",
want: -1,
},
{
name: "single char overlap",
s: "test<",
d: "<tool_call>",
want: 4,
},
{
name: "partial tool call",
s: "hello <tool_",
d: "<tool_call>",
want: 6,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := suffixOverlap(tt.s, tt.d)
if got != tt.want {
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
}
})
}
}
func TestExtractToolArgs(t *testing.T) {
cases := []struct {
name string
template string
wantName string
wantArgs string
wantErr bool
}{
{
name: "basic tool call",
template: `{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`,
wantName: "name",
wantArgs: "parameters",
wantErr: false,
},
{
name: "tool call with whitespace",
template: `{{range .ToolCalls}}
{"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}}
{{end}}`,
wantName: "name",
wantArgs: "parameters",
wantErr: false,
},
{
name: "tool call with extra content",
template: `Before {{range .ToolCalls}}
{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`,
wantName: "name",
wantArgs: "arguments",
wantErr: false,
},
{
name: "no tool calls",
template: `{{if .Something}}no tools here{{end}}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "empty template",
template: ``,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "prefix within tool call",
template: `{{- if .ToolCalls }}
{{ range .ToolCalls }}
<tool_call>
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
</tool_call>{{ end }}{{- end }}`,
wantName: "name",
wantArgs: "arguments",
wantErr: false,
},
{
name: "JSON array",
template: `{{ range .ToolCalls }}
[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`,
wantName: "name",
wantArgs: "arguments",
wantErr: false,
},
{
name: "invalid JSON",
template: `{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "missing name field",
template: `{{ range .ToolCalls }}
{"parameters": {{ .Function.Arguments }}}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "missing arguments field",
template: `{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}"}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
{
name: "malformed JSON",
template: `{{ range .ToolCalls }}
{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`,
wantName: "",
wantArgs: "",
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
gotName, gotArgs, err := extractToolArgs(tmpl)
if (err != nil) != tt.wantErr {
t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
return
}
if gotName != tt.wantName {
t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName)
}
if gotArgs != tt.wantArgs {
t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs)
}
})
}
}
func TestCollect(t *testing.T) {
cases := []struct {
name string
obj any
want []map[string]any
}{
{
name: "simple map",
obj: map[string]any{
"key": "value",
},
want: []map[string]any{
{"key": "value"},
},
},
{
name: "nested map",
obj: map[string]any{
"outer": map[string]any{
"inner": "value",
},
},
want: []map[string]any{
{"outer": map[string]any{"inner": "value"}},
{"inner": "value"},
},
},
{
name: "array of maps",
obj: []any{
map[string]any{"key1": "val1"},
map[string]any{"key2": "val2"},
},
want: []map[string]any{
{"key1": "val1"},
{"key2": "val2"},
},
},
{
name: "deeply nested",
obj: map[string]any{
"l1": map[string]any{
"l2": map[string]any{
"l3": "value",
},
},
},
want: []map[string]any{
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
{"l2": map[string]any{"l3": "value"}},
{"l3": "value"},
},
},
{
name: "non-map value",
obj: "string",
want: nil,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := collect(tt.obj)
if len(got) != len(tt.want) {
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
return
}
// Compare each map in the result
for i := range tt.want {
if !mapsEqual(got[i], tt.want[i]) {
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
}
}
})
}
}
// mapsEqual compares two maps for deep equality
func mapsEqual(m1, m2 map[string]any) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
v2, ok := m2[k]
if !ok {
return false
}
switch val1 := v1.(type) {
case map[string]any:
val2, ok := v2.(map[string]any)
if !ok || !mapsEqual(val1, val2) {
return false
}
default:
if v1 != v2 {
return false
}
}
}
return true
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment