// convert_test.go (committed by Michael Yang)

package convert

import (
4
5
6
7
8
	"crypto/sha256"
	"encoding/json"
	"flag"
	"fmt"
	"io"
9
	"io/fs"
10
11
	"log/slog"
	"math"
Michael Yang's avatar
Michael Yang committed
12
13
	"os"
	"path/filepath"
14
	"slices"
Michael Yang's avatar
Michael Yang committed
15
16
17
	"testing"

	"github.com/ollama/ollama/llm"
18
	"golang.org/x/exp/maps"
Michael Yang's avatar
Michael Yang committed
19
20
)

21
func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
Michael Yang's avatar
Michael Yang committed
22
23
24
25
26
27
28
29
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), "f16")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

30
	if err := Convert(fsys, f); err != nil {
Michael Yang's avatar
Michael Yang committed
31
32
33
34
35
36
37
		t.Fatal(err)
	}

	r, err := os.Open(f.Name())
	if err != nil {
		t.Fatal(err)
	}
38
	t.Cleanup(func() { r.Close() })
Michael Yang's avatar
Michael Yang committed
39

40
	m, _, err := llm.DecodeGGML(r, math.MaxInt)
Michael Yang's avatar
Michael Yang committed
41
42
43
44
	if err != nil {
		t.Fatal(err)
	}

45
46
47
48
49
50
51
52
53
54
55
56
57
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	return r, m.KV(), m.Tensors()
}

func TestMain(m *testing.M) {
	var level slog.Level
	flag.TextVar(&level, "level", slog.LevelInfo, "log level")
	flag.Parse()
	slog.SetLogLoggerLevel(level)
	os.Exit(m.Run())
Michael Yang's avatar
Michael Yang committed
58
59
60
}

func TestConvertFull(t *testing.T) {
61
62
63
64
65
	cases := []string{
		"Meta-Llama-3-8B-Instruct",
		"Mistral-7B-Instruct-v0.2",
		"Mixtral-8x7B-Instruct-v0.1",
		"gemma-2b-it",
Michael Yang's avatar
Michael Yang committed
66
67
	}

68
69
70
71
72
73
74
75
76
	for i := range cases {
		tt := cases[i]
		t.Run(tt, func(t *testing.T) {
			t.Parallel()

			p := filepath.Join("testdata", tt)
			if testing.Short() {
				t.Skip("skipping in short mode")
			} else if _, err := os.Stat(p); err != nil {
Michael Yang's avatar
Michael Yang committed
77
78
79
				t.Skipf("%s not found", p)
			}

80
			f, kv, tensors := convertFull(t, os.DirFS(p))
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
			actual := make(map[string]string)
			for k, v := range kv {
				if s, ok := v.(json.Marshaler); !ok {
					actual[k] = fmt.Sprintf("%v", v)
				} else {
					bts, err := json.Marshal(s)
					if err != nil {
						t.Fatal(err)
					}

					actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
				}
			}

			for _, tensor := range tensors.Items {
				sha256sum := sha256.New()
				sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
				if _, err := io.Copy(sha256sum, sr); err != nil {
					t.Fatal(err)
				}
Michael Yang's avatar
Michael Yang committed
101

102
				actual[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
Michael Yang's avatar
Michael Yang committed
103
104
			}

105
106
107
			expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
			if err != nil {
				t.Fatal(err)
Michael Yang's avatar
Michael Yang committed
108
109
			}

110
111
112
			var expect map[string]string
			if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
				t.Fatal(err)
Michael Yang's avatar
Michael Yang committed
113
114
			}

115
116
117
118
119
120
121
122
			keys := maps.Keys(expect)
			slices.Sort(keys)
			for _, k := range keys {
				if v, ok := actual[k]; !ok {
					t.Errorf("missing %s", k)
				} else if v != expect[k] {
					t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
				}
Michael Yang's avatar
Michael Yang committed
123
124
125
126
			}
		})
	}
}