//go:build integration && models

package integration

import (
	"bytes"
	"context"
	"fmt"
	"log/slog"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestQuantization(t *testing.T) {
	sourceModels := []string{
		"qwen2.5:0.5b-instruct-fp16",
	}
	quantizations := []string{
		"Q8_0",
		"Q4_K_S",
		"Q4_K_M",
		"Q4_K",
	}
	softTimeout, hardTimeout := getTimeouts(t)
	started := time.Now()
	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	for _, base := range sourceModels {
		if err := PullIfMissing(ctx, client, base); err != nil {
			t.Fatalf("pull failed %s", err)
		}
		for _, quant := range quantizations {
			newName := fmt.Sprintf("%s__%s", base, quant)
			t.Run(newName, func(t *testing.T) {
				if time.Now().Sub(started) > softTimeout {
					t.Skip("skipping remaining tests to avoid excessive runtime")
				}
				req := &api.CreateRequest{
					Model:        newName,
					Quantization: quant,
					From:         base,
				}
				fn := func(resp api.ProgressResponse) error {
					// fmt.Print(".")
					return nil
				}
				t.Logf("quantizing: %s -> %s", base, quant)
				if err := client.Create(ctx, req, fn); err != nil {
					t.Fatalf("create failed %s", err)
				}
				defer func() {
					req := &api.DeleteRequest{
						Model: newName,
					}
					t.Logf("deleting: %s -> %s", base, quant)
					if err := client.Delete(ctx, req); err != nil {
						t.Logf("failed to clean up %s: %s", req.Model, err)
					}
				}()
				// Check metadata on the model
				resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
				if err != nil {
					t.Fatalf("unable to show model: %s", err)
				}
				if !strings.Contains(resp.Details.QuantizationLevel, quant) {
					t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
				}

				stream := true
77
78
79
80
81
82
83
84
				chatReq := api.ChatRequest{
					Model: newName,
					Messages: []api.Message{
						{
							Role:    "user",
							Content: blueSkyPrompt,
						},
					},
85
86
87
88
89
90
91
92
93
94
95
96
97
98
					KeepAlive: &api.Duration{Duration: 3 * time.Second},
					Options: map[string]any{
						"seed":        42,
						"temperature": 0.0,
					},
					Stream: &stream,
				}
				t.Logf("verifying: %s -> %s", base, quant)

				// Some smaller quantizations can cause models to have poor quality
				// or get stuck in repetition loops, so we stop as soon as we have any matches
				reqCtx, reqCancel := context.WithCancel(ctx)
				atLeastOne := false
				var buf bytes.Buffer
99
100
				chatfn := func(response api.ChatResponse) error {
					buf.Write([]byte(response.Message.Content))
101
					fullResp := strings.ToLower(buf.String())
102
					for _, resp := range blueSkyExpected {
103
104
105
106
107
108
109
110
111
112
113
114
115
						if strings.Contains(fullResp, resp) {
							atLeastOne = true
							t.Log(fullResp)
							reqCancel()
							break
						}
					}
					return nil
				}

				done := make(chan int)
				var genErr error
				go func() {
116
					genErr = client.Chat(reqCtx, &chatReq, chatfn)
117
118
119
120
121
122
					done <- 0
				}()

				select {
				case <-done:
					if genErr != nil && !atLeastOne {
123
						t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content)
124
125
126
127
128
129
130
131
132
133
134
					}
				case <-ctx.Done():
					t.Error("outer test context done while waiting for generate")
				}

				t.Logf("passed")

			})
		}
	}
}