library_models_test.go 2.36 KB
Newer Older
1
2
3
4
5
6
//go:build integration && library

package integration

import (
	"context"
7
	"fmt"
8
	"log/slog"
9
	"os"
10
11
12
13
14
15
16
17
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// First run of this scenario on a target system will take a long time to download
// ~1.5TB of models.  Set a sufficiently large -timeout for your network speed
18
func TestLibraryModelsChat(t *testing.T) {
19
20
21
22
23
24
	softTimeout, hardTimeout := getTimeouts(t)
	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
25
	targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
26
27
28
29
30
31
32
33
34
35

	chatModels := libraryChatModels
	for _, model := range chatModels {
		t.Run(model, func(t *testing.T) {
			if time.Now().Sub(started) > softTimeout {
				t.Skip("skipping remaining tests to avoid excessive runtime")
			}
			if err := PullIfMissing(ctx, client, model); err != nil {
				t.Fatalf("pull failed %s", err)
			}
36
37
38
39
40
41
42
43
44
45
			if targetArch != "" {
				resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
				if err != nil {
					t.Fatalf("unable to show model: %s", err)
				}
				arch := resp.ModelInfo["general.architecture"].(string)
				if arch != targetArch {
					t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
				}
			}
46
47
48
49
50
51
52
53
			req := api.ChatRequest{
				Model: model,
				Messages: []api.Message{
					{
						Role:    "user",
						Content: blueSkyPrompt,
					},
				},
54
55
56
57
58
59
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
				Options: map[string]interface{}{
					"temperature": 0.1,
					"seed":        123,
				},
			}
60
			anyResp := blueSkyExpected
61
62
63
64
65
			// Special cases
			if model == "duckdb-nsql" {
				anyResp = []string{"select", "from"}
			} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
				anyResp = []string{"yes", "no", "safe", "unsafe"}
66
			} else if model == "openthinker" {
67
68
				anyResp = []string{"plugin", "im_sep", "components", "function call"}
			} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
69
				req.Messages[0].Content = "def fibonacci():"
70
71
				anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
			}
72
			DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
73
74
75
		})
	}
}