// convert_test.go
package convert

import (
4
5
6
7
8
9
10
	"crypto/sha256"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"math"
Michael Yang's avatar
Michael Yang committed
11
12
	"os"
	"path/filepath"
13
	"slices"
Michael Yang's avatar
Michael Yang committed
14
15
16
	"testing"

	"github.com/ollama/ollama/llm"
17
	"golang.org/x/exp/maps"
Michael Yang's avatar
Michael Yang committed
18
19
)

20
func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
Michael Yang's avatar
Michael Yang committed
21
22
23
24
25
26
27
28
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), "f16")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

Michael Yang's avatar
Michael Yang committed
29
	if err := Convert(d, f); err != nil {
Michael Yang's avatar
Michael Yang committed
30
31
32
33
34
35
36
		t.Fatal(err)
	}

	r, err := os.Open(f.Name())
	if err != nil {
		t.Fatal(err)
	}
37
	t.Cleanup(func() { r.Close() })
Michael Yang's avatar
Michael Yang committed
38

39
	m, _, err := llm.DecodeGGML(r, math.MaxInt)
Michael Yang's avatar
Michael Yang committed
40
41
42
43
	if err != nil {
		t.Fatal(err)
	}

44
45
46
47
48
49
50
51
52
53
54
55
56
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	return r, m.KV(), m.Tensors()
}

func TestMain(m *testing.M) {
	var level slog.Level
	flag.TextVar(&level, "level", slog.LevelInfo, "log level")
	flag.Parse()
	slog.SetLogLoggerLevel(level)
	os.Exit(m.Run())
Michael Yang's avatar
Michael Yang committed
57
58
59
}

func TestConvertFull(t *testing.T) {
60
61
62
63
64
	cases := []string{
		"Meta-Llama-3-8B-Instruct",
		"Mistral-7B-Instruct-v0.2",
		"Mixtral-8x7B-Instruct-v0.1",
		"gemma-2b-it",
Michael Yang's avatar
Michael Yang committed
65
66
	}

67
68
69
70
71
72
73
74
75
	for i := range cases {
		tt := cases[i]
		t.Run(tt, func(t *testing.T) {
			t.Parallel()

			p := filepath.Join("testdata", tt)
			if testing.Short() {
				t.Skip("skipping in short mode")
			} else if _, err := os.Stat(p); err != nil {
Michael Yang's avatar
Michael Yang committed
76
77
78
				t.Skipf("%s not found", p)
			}

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
			f, kv, tensors := convertFull(t, p)
			actual := make(map[string]string)
			for k, v := range kv {
				if s, ok := v.(json.Marshaler); !ok {
					actual[k] = fmt.Sprintf("%v", v)
				} else {
					bts, err := json.Marshal(s)
					if err != nil {
						t.Fatal(err)
					}

					actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
				}
			}

			for _, tensor := range tensors.Items {
				sha256sum := sha256.New()
				sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
				if _, err := io.Copy(sha256sum, sr); err != nil {
					t.Fatal(err)
				}
Michael Yang's avatar
Michael Yang committed
100

101
				actual[tensor.Name] = fmt.Sprintf("%x", sha256sum.Sum(nil))
Michael Yang's avatar
Michael Yang committed
102
103
			}

104
105
106
			expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
			if err != nil {
				t.Fatal(err)
Michael Yang's avatar
Michael Yang committed
107
108
			}

109
110
111
			var expect map[string]string
			if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
				t.Fatal(err)
Michael Yang's avatar
Michael Yang committed
112
113
			}

114
115
116
117
118
119
120
121
			keys := maps.Keys(expect)
			slices.Sort(keys)
			for _, k := range keys {
				if v, ok := actual[k]; !ok {
					t.Errorf("missing %s", k)
				} else if v != expect[k] {
					t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
				}
Michael Yang's avatar
Michael Yang committed
122
123
124
125
			}
		})
	}
}