Unverified Commit e51dead6 authored by Devon Rifkin's avatar Devon Rifkin Committed by GitHub
Browse files

preserve tool definition and call JSON ordering (#13525)

* preserve tool definition and call JSON ordering

This is another iteration of
<https://github.com/ollama/ollama/pull/12518>, but this time we've
simplified things by relaxing the competing requirements of being
compatible AND order-preserving with templates (vs. renderers). We
maintain backwards compatibility at the cost of not guaranteeing order
for templates. We plan on moving more and more models to renderers,
which have been updated to use these new data types, and additionally
we could add an opt-in way of templates getting an order-preserved list
(e.g., via sibling template vars)

* orderedmap_test: remove testify
parent d087e46b
...@@ -29,12 +29,12 @@ func getTestTools() []api.Tool { ...@@ -29,12 +29,12 @@ func getTestTools() []api.Tool {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: map[string]api.ToolProperty{ Properties: testPropsMap(map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA", Description: "The city and state, e.g. San Francisco, CA",
}, },
}, }),
}, },
}, },
}, },
...@@ -46,12 +46,12 @@ func getTestTools() []api.Tool { ...@@ -46,12 +46,12 @@ func getTestTools() []api.Tool {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"expression"}, Required: []string{"expression"},
Properties: map[string]api.ToolProperty{ Properties: testPropsMap(map[string]api.ToolProperty{
"expression": { "expression": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The mathematical expression to calculate", Description: "The mathematical expression to calculate",
}, },
}, }),
}, },
}, },
}, },
...@@ -185,9 +185,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { ...@@ -185,9 +185,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "San Francisco", "location": "San Francisco",
}, }),
}, },
}, },
}, },
...@@ -211,9 +211,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { ...@@ -211,9 +211,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"expression": "2+2", "expression": "2+2",
}, }),
}, },
}, },
}, },
......
...@@ -272,8 +272,8 @@ func (t *Template) Execute(w io.Writer, v Values) error { ...@@ -272,8 +272,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
} else if !v.forceLegacy && slices.Contains(vars, "messages") { } else if !v.forceLegacy && slices.Contains(vars, "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": messages, "Messages": convertMessagesForTemplate(messages),
"Tools": v.Tools, "Tools": convertToolsForTemplate(v.Tools),
"Response": "", "Response": "",
"Think": v.Think, "Think": v.Think,
"ThinkLevel": v.ThinkLevel, "ThinkLevel": v.ThinkLevel,
...@@ -373,6 +373,118 @@ func collate(msgs []api.Message) (string, []*api.Message) { ...@@ -373,6 +373,118 @@ func collate(msgs []api.Message) (string, []*api.Message) {
return strings.Join(system, "\n\n"), collated return strings.Join(system, "\n\n"), collated
} }
// templateTools is a slice of templateTool that marshals to JSON.
type templateTools []templateTool
func (t templateTools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
Function templateToolFunction `json:"function"`
}
type templateToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters templateToolFunctionParameters `json:"parameters"`
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
// with Arguments as a regular map for template ranging.
type templateToolCall struct {
ID string
Function templateToolCallFunction
}
type templateToolCallFunction struct {
Index int
Name string
Arguments map[string]any
}
// templateMessage is a template-compatible representation of api.Message
// with ToolCalls converted for template use.
type templateMessage struct {
Role string
Content string
Thinking string
Images []api.ImageData
ToolCalls []templateToolCall
ToolName string
ToolCallID string
}
// convertToolsForTemplate converts Tools to template-compatible format.
func convertToolsForTemplate(tools api.Tools) templateTools {
if tools == nil {
return nil
}
result := make(templateTools, len(tools))
for i, tool := range tools {
result[i] = templateTool{
Type: tool.Type,
Items: tool.Items,
Function: templateToolFunction{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: templateToolFunctionParameters{
Type: tool.Function.Parameters.Type,
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: tool.Function.Parameters.Properties.ToMap(),
},
},
}
}
return result
}
// convertMessagesForTemplate converts Messages to template-compatible format.
func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
if messages == nil {
return nil
}
result := make([]*templateMessage, len(messages))
for i, msg := range messages {
var toolCalls []templateToolCall
for _, tc := range msg.ToolCalls {
toolCalls = append(toolCalls, templateToolCall{
ID: tc.ID,
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments.ToMap(),
},
})
}
result[i] = &templateMessage{
Role: msg.Role,
Content: msg.Content,
Thinking: msg.Thinking,
Images: msg.Images,
ToolCalls: toolCalls,
ToolName: msg.ToolName,
ToolCallID: msg.ToolCallID,
}
}
return result
}
// Identifiers walks the node tree returning any identifiers it finds along the way // Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) ([]string, error) { func Identifiers(n parse.Node) ([]string, error) {
switch n := n.(type) { switch n := n.(type) {
......
...@@ -124,16 +124,21 @@ func (p *Parser) parseToolCall() *api.ToolCall { ...@@ -124,16 +124,21 @@ func (p *Parser) parseToolCall() *api.ToolCall {
return nil return nil
} }
var args map[string]any var argsMap map[string]any
if found, i := findArguments(tool, p.buffer); found == nil { if found, i := findArguments(tool, p.buffer); found == nil {
return nil return nil
} else { } else {
args = found argsMap = found
if i > end { if i > end {
end = i end = i
} }
} }
args := api.NewToolCallFunctionArguments()
for k, v := range argsMap {
args.Set(k, v)
}
tc := &api.ToolCall{ tc := &api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: tool.Function.Name, Name: tool.Function.Name,
......
...@@ -9,6 +9,29 @@ import ( ...@@ -9,6 +9,29 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value (order-insensitive)
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
return cmp.Equal(a.ToMap(), b.ToMap())
})
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`) qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
if err != nil { if err != nil {
...@@ -44,7 +67,7 @@ func TestParser(t *testing.T) { ...@@ -44,7 +67,7 @@ func TestParser(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"city"}, Required: []string{"city"},
Properties: map[string]api.ToolProperty{ Properties: testPropsMap(map[string]api.ToolProperty{
"format": { "format": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The format to return the temperature in", Description: "The format to return the temperature in",
...@@ -54,7 +77,7 @@ func TestParser(t *testing.T) { ...@@ -54,7 +77,7 @@ func TestParser(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city to get the temperature for", Description: "The city to get the temperature for",
}, },
}, }),
}, },
}, },
}, },
...@@ -65,12 +88,12 @@ func TestParser(t *testing.T) { ...@@ -65,12 +88,12 @@ func TestParser(t *testing.T) {
Description: "Retrieve the current weather conditions for a given location", Description: "Retrieve the current weather conditions for a given location",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: map[string]api.ToolProperty{ Properties: testPropsMap(map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The location to get the weather conditions for", Description: "The location to get the weather conditions for",
}, },
}, }),
}, },
}, },
}, },
...@@ -95,12 +118,12 @@ func TestParser(t *testing.T) { ...@@ -95,12 +118,12 @@ func TestParser(t *testing.T) {
Description: "Get the address of a given location", Description: "Get the address of a given location",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: map[string]api.ToolProperty{ Properties: testPropsMap(map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The location to get the address for", Description: "The location to get the address for",
}, },
}, }),
}, },
}, },
}, },
...@@ -111,7 +134,7 @@ func TestParser(t *testing.T) { ...@@ -111,7 +134,7 @@ func TestParser(t *testing.T) {
Description: "Add two numbers", Description: "Add two numbers",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: map[string]api.ToolProperty{ Properties: testPropsMap(map[string]api.ToolProperty{
"a": { "a": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The first number to add", Description: "The first number to add",
...@@ -120,7 +143,7 @@ func TestParser(t *testing.T) { ...@@ -120,7 +143,7 @@ func TestParser(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The second number to add", Description: "The second number to add",
}, },
}, }),
}, },
}, },
}, },
...@@ -157,9 +180,9 @@ func TestParser(t *testing.T) { ...@@ -157,9 +180,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "San Francisco", "location": "San Francisco",
}, }),
}, },
}, },
}, },
...@@ -174,7 +197,7 @@ func TestParser(t *testing.T) { ...@@ -174,7 +197,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
}, },
...@@ -189,9 +212,9 @@ func TestParser(t *testing.T) { ...@@ -189,9 +212,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "New York", "city": "New York",
}, }),
}, },
}, },
}, },
...@@ -213,19 +236,19 @@ func TestParser(t *testing.T) { ...@@ -213,19 +236,19 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "London", "city": "London",
"format": "fahrenheit", "format": "fahrenheit",
}, }),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -240,19 +263,19 @@ func TestParser(t *testing.T) { ...@@ -240,19 +263,19 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "London", "city": "London",
"format": "fahrenheit", "format": "fahrenheit",
}, }),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -267,17 +290,17 @@ func TestParser(t *testing.T) { ...@@ -267,17 +290,17 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello", Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "London", "city": "London",
"format": "fahrenheit", "format": "fahrenheit",
}, }),
}, },
}, },
}, },
...@@ -292,16 +315,16 @@ func TestParser(t *testing.T) { ...@@ -292,16 +315,16 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -316,9 +339,9 @@ func TestParser(t *testing.T) { ...@@ -316,9 +339,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "Tokyo", "city": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -347,9 +370,9 @@ func TestParser(t *testing.T) { ...@@ -347,9 +370,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "Tokyo", "city": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -371,9 +394,9 @@ func TestParser(t *testing.T) { ...@@ -371,9 +394,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "Tokyo", "city": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -453,18 +476,18 @@ func TestParser(t *testing.T) { ...@@ -453,18 +476,18 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"city": "London", "city": "London",
}, }),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -486,9 +509,9 @@ func TestParser(t *testing.T) { ...@@ -486,9 +509,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -528,9 +551,9 @@ func TestParser(t *testing.T) { ...@@ -528,9 +551,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, }),
}, },
}, },
}, },
...@@ -563,7 +586,7 @@ func TestParser(t *testing.T) { ...@@ -563,7 +586,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
}, },
...@@ -591,14 +614,14 @@ func TestParser(t *testing.T) { ...@@ -591,14 +614,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "say_hello", Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
}, },
...@@ -624,14 +647,14 @@ func TestParser(t *testing.T) { ...@@ -624,14 +647,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello", Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
}, },
...@@ -648,7 +671,7 @@ func TestParser(t *testing.T) { ...@@ -648,7 +671,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello", Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
}, },
...@@ -665,7 +688,7 @@ func TestParser(t *testing.T) { ...@@ -665,7 +688,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{}, Arguments: api.NewToolCallFunctionArguments(),
}, },
}, },
}, },
...@@ -687,9 +710,9 @@ func TestParser(t *testing.T) { ...@@ -687,9 +710,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_address", Name: "get_address",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "London", "location": "London",
}, }),
}, },
}, },
}, },
...@@ -706,9 +729,9 @@ func TestParser(t *testing.T) { ...@@ -706,9 +729,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_address", Name: "get_address",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"location": "London", "location": "London",
}, }),
}, },
}, },
}, },
...@@ -725,10 +748,10 @@ func TestParser(t *testing.T) { ...@@ -725,10 +748,10 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "add", Name: "add",
Arguments: api.ToolCallFunctionArguments{ Arguments: testArgs(map[string]any{
"a": "5", "a": "5",
"b": "10", "b": "10",
}, }),
}, },
}, },
}, },
...@@ -756,7 +779,7 @@ func TestParser(t *testing.T) { ...@@ -756,7 +779,7 @@ func TestParser(t *testing.T) {
} }
for i, want := range tt.calls { for i, want := range tt.calls {
if diff := cmp.Diff(calls[i], want); diff != "" { if diff := cmp.Diff(calls[i], want, argsComparer); diff != "" {
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff) t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
} }
} }
...@@ -1316,7 +1339,7 @@ func TestFindArguments(t *testing.T) { ...@@ -1316,7 +1339,7 @@ func TestFindArguments(t *testing.T) {
got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer) got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)
if diff := cmp.Diff(got, tt.want); diff != "" { if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff) t.Errorf("findArguments() args mismatch (-got +want):\n%s", diff)
} }
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment