quantization_test.go 3.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//go:build integration && models

package integration

import (
	"bytes"
	"context"
	"fmt"
	"log/slog"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// TestQuantization pulls an fp16 source model, quantizes it to several
// target levels, verifies the reported quantization metadata, and then
// runs a short deterministic generation to confirm the quantized model
// still produces sensible output. Requires a running server (integration
// + models build tags).
func TestQuantization(t *testing.T) {
	sourceModels := []string{
		"qwen2.5:0.5b-instruct-fp16",
	}
	quantizations := []string{
		"Q8_0",
		"Q4_K_S",
		"Q4_K_M",
		"Q4_K",
	}
	softTimeout, hardTimeout := getTimeouts(t)
	started := time.Now()
	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	for _, base := range sourceModels {
		if err := PullIfMissing(ctx, client, base); err != nil {
			t.Fatalf("pull failed %s", err)
		}
		for _, quant := range quantizations {
			newName := fmt.Sprintf("%s__%s", base, quant)
			t.Run(newName, func(t *testing.T) {
				// Skip remaining combinations once the soft deadline passes
				// so the suite degrades gracefully instead of hitting the
				// hard context timeout mid-test.
				if time.Since(started) > softTimeout {
					t.Skip("skipping remaining tests to avoid excessive runtime")
				}
				req := &api.CreateRequest{
					Model:        newName,
					Quantization: quant,
					From:         base,
				}
				// Progress callback: output intentionally suppressed.
				fn := func(resp api.ProgressResponse) error {
					return nil
				}
				t.Logf("quantizing: %s -> %s", base, quant)
				if err := client.Create(ctx, req, fn); err != nil {
					t.Fatalf("create failed %s", err)
				}
				// Best-effort cleanup of the quantized model; failure to
				// delete is logged but does not fail the test.
				defer func() {
					req := &api.DeleteRequest{
						Model: newName,
					}
					t.Logf("deleting: %s -> %s", base, quant)
					if err := client.Delete(ctx, req); err != nil {
						t.Logf("failed to clean up %s: %s", req.Model, err)
					}
				}()
				// Check metadata on the model: the reported quantization
				// level must contain the level we requested.
				resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
				if err != nil {
					t.Fatalf("unable to show model: %s", err)
				}
				if !strings.Contains(resp.Details.QuantizationLevel, quant) {
					t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
				}

				stream := true
				genReq := api.GenerateRequest{
					Model:     newName,
					Prompt:    "why is the sky blue?",
					KeepAlive: &api.Duration{Duration: 3 * time.Second},
					Options: map[string]any{
						"seed":        42,
						"temperature": 0.0,
					},
					Stream: &stream,
				}
				t.Logf("verifying: %s -> %s", base, quant)

				// Some smaller quantizations can cause models to have poor quality
				// or get stuck in repetition loops, so we stop as soon as we have any matches
				anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
				reqCtx, reqCancel := context.WithCancel(ctx)
				// Always release the derived context; previously it was only
				// canceled on a keyword match, leaking it on the no-match path.
				defer reqCancel()
				atLeastOne := false
				var buf bytes.Buffer
				genfn := func(response api.GenerateResponse) error {
					// WriteString avoids the []byte copy that Write([]byte(...)) makes.
					buf.WriteString(response.Response)
					fullResp := strings.ToLower(buf.String())
					for _, want := range anyResp {
						if strings.Contains(fullResp, want) {
							atLeastOne = true
							t.Log(fullResp)
							// Cancel the in-flight generate: any single match is enough.
							reqCancel()
							break
						}
					}
					return nil
				}

				// Buffered so the generate goroutine can always complete its
				// send and exit, even if the outer context wins the select
				// below and nobody receives (prevents a goroutine leak).
				done := make(chan int, 1)
				var genErr error
				go func() {
					genErr = client.Generate(reqCtx, &genReq, genfn)
					done <- 0
				}()

				select {
				case <-done:
					// A cancellation-induced error is expected when a keyword
					// matched; only fail if we got an error with no match.
					if genErr != nil && !atLeastOne {
						t.Fatalf("failed with %s request prompt %q: %s", genReq.Model, genReq.Prompt, genErr)
					}
				case <-ctx.Done():
					t.Error("outer test context done while waiting for generate")
				}

				t.Logf("passed")
			})
		}
	}
}