tools_test.go 3.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//go:build integration

package integration

import (
	"context"
	"fmt"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// TestAPIToolCalling checks, for every model in libraryToolsModels, that a
// streamed chat request offering a single get_weather tool produces at least
// one tool call whose last entry names get_weather and carries a "location"
// argument. Models listed in minVRAM are passed to skipUnderMinVRAM before
// being pulled (if missing) and exercised.
func TestAPIToolCalling(t *testing.T) {
	initialTimeout := 60 * time.Second
	streamTimeout := 60 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Minimum VRAM per model; the unit is whatever skipUnderMinVRAM expects
	// (presumably GiB — confirm against its definition).
	minVRAM := map[string]uint64{
		"qwen3-vl":      16,
		"gpt-oss:20b":   16,
		"gpt-oss:120b":  70,
		"qwen3":         6,
		"llama3.1":      8,
		"llama3.2":      4,
		"mistral":       6,
		"qwen2.5":       6,
		"qwen2":         6,
		"mistral-nemo":  9,
		"mistral-small": 16,
		"mixtral:8x22b": 80,
		"qwq":           20,
		"granite3.3":    7,
	}

	for _, model := range libraryToolsModels {
		t.Run(model, func(t *testing.T) {
			if v, ok := minVRAM[model]; ok {
				skipUnderMinVRAM(t, v)
			}

			if err := PullIfMissing(ctx, client, model); err != nil {
				t.Fatalf("pull failed: %v", err)
			}

			tools := []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:     "object",
							Required: []string{"location"},
							Properties: map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "The city and state, e.g. San Francisco, CA",
								},
							},
						},
					},
				},
			}

			req := api.ChatRequest{
				Model: model,
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "Call get_weather with location set to San Francisco.",
					},
				},
				Tools: tools,
				Options: map[string]any{
					"temperature": 0,
				},
			}

			// chatCtx lets the subtest abort the in-flight request once it
			// stops waiting for it (stall, outer timeout, or early return),
			// so the streaming goroutine below always terminates.
			chatCtx, chatCancel := context.WithCancel(ctx)
			defer chatCancel()

			stallTimer := time.NewTimer(initialTimeout)
			defer stallTimer.Stop()

			// started is written by the streaming callback and read by this
			// goroutine in the stall branch, so it must be atomic.
			var started atomic.Bool
			// gotToolCall and lastToolCall are only read after the chat
			// goroutine has delivered its result on done, which orders the
			// callback's writes before those reads.
			var gotToolCall bool
			var lastToolCall api.ToolCall

			fn := func(response api.ChatResponse) error {
				started.Store(true)
				if len(response.Message.ToolCalls) > 0 {
					gotToolCall = true
					lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
				}
				// Reset reports false once the timer has fired or been
				// stopped, i.e. the test has already given up on this stream.
				if !stallTimer.Reset(streamTimeout) {
					return fmt.Errorf("stall was detected while streaming response, aborting")
				}
				return nil
			}

			stream := true
			req.Stream = &stream

			// Buffered so the goroutine can always deliver its result and
			// exit, even after the select below has stopped listening.
			done := make(chan error, 1)
			go func() {
				done <- client.Chat(chatCtx, &req, fn)
			}()

			select {
			case <-stallTimer.C:
				if started.Load() {
					t.Errorf("tool-calling chat stalled mid-stream. No response within: %s", streamTimeout.String())
				} else {
					t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
				}
			case genErr := <-done:
				if genErr != nil {
					t.Fatalf("chat failed: %v", genErr)
				}

				if !gotToolCall {
					t.Fatalf("expected at least one tool call, got none")
				}

				if lastToolCall.Function.Name != "get_weather" {
					t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
				}

				if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
					t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
				}
			case <-ctx.Done():
				t.Error("outer test context done while waiting for tool-calling chat")
			}
		})
	}
}