convert_test.go 2.51 KB
Newer Older
mashun1's avatar
v1  
mashun1 committed
1
2
3
package convert

import (
xuxzh1's avatar
init  
xuxzh1 committed
4
5
6
7
8
9
10
11
12
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"math"
mashun1's avatar
v1  
mashun1 committed
13
14
	"os"
	"path/filepath"
xuxzh1's avatar
init  
xuxzh1 committed
15
	"slices"
mashun1's avatar
v1  
mashun1 committed
16
17
	"testing"

xuxzh1's avatar
init  
xuxzh1 committed
18
19
	"golang.org/x/exp/maps"

mashun1's avatar
v1  
mashun1 committed
20
21
22
	"github.com/ollama/ollama/llm"
)

xuxzh1's avatar
init  
xuxzh1 committed
23
func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
mashun1's avatar
v1  
mashun1 committed
24
25
	t.Helper()

xuxzh1's avatar
init  
xuxzh1 committed
26
	f, err := os.CreateTemp(t.TempDir(), "f16")
mashun1's avatar
v1  
mashun1 committed
27
28
29
	if err != nil {
		t.Fatal(err)
	}
xuxzh1's avatar
init  
xuxzh1 committed
30
	defer f.Close()
mashun1's avatar
v1  
mashun1 committed
31

xuxzh1's avatar
init  
xuxzh1 committed
32
	if err := Convert(fsys, f); err != nil {
mashun1's avatar
v1  
mashun1 committed
33
34
35
		t.Fatal(err)
	}

xuxzh1's avatar
init  
xuxzh1 committed
36
	r, err := os.Open(f.Name())
mashun1's avatar
v1  
mashun1 committed
37
38
39
	if err != nil {
		t.Fatal(err)
	}
xuxzh1's avatar
init  
xuxzh1 committed
40
	t.Cleanup(func() { r.Close() })
mashun1's avatar
v1  
mashun1 committed
41

xuxzh1's avatar
init  
xuxzh1 committed
42
	m, _, err := llm.DecodeGGML(r, math.MaxInt)
mashun1's avatar
v1  
mashun1 committed
43
44
45
46
	if err != nil {
		t.Fatal(err)
	}

xuxzh1's avatar
init  
xuxzh1 committed
47
	if _, err := r.Seek(0, io.SeekStart); err != nil {
mashun1's avatar
v1  
mashun1 committed
48
49
50
		t.Fatal(err)
	}

xuxzh1's avatar
init  
xuxzh1 committed
51
52
	return r, m.KV(), m.Tensors()
}
mashun1's avatar
v1  
mashun1 committed
53

xuxzh1's avatar
init  
xuxzh1 committed
54
55
56
57
58
59
func TestMain(m *testing.M) {
	var level slog.Level
	flag.TextVar(&level, "level", slog.LevelInfo, "log level")
	flag.Parse()
	slog.SetLogLoggerLevel(level)
	os.Exit(m.Run())
mashun1's avatar
v1  
mashun1 committed
60
61
62
}

func TestConvertFull(t *testing.T) {
xuxzh1's avatar
init  
xuxzh1 committed
63
64
65
66
67
	cases := []string{
		"Meta-Llama-3-8B-Instruct",
		"Mistral-7B-Instruct-v0.2",
		"Mixtral-8x7B-Instruct-v0.1",
		"gemma-2b-it",
mashun1's avatar
v1  
mashun1 committed
68
69
	}

xuxzh1's avatar
init  
xuxzh1 committed
70
71
72
73
74
75
76
77
78
	for i := range cases {
		tt := cases[i]
		t.Run(tt, func(t *testing.T) {
			t.Parallel()

			p := filepath.Join("testdata", tt)
			if testing.Short() {
				t.Skip("skipping in short mode")
			} else if _, err := os.Stat(p); err != nil {
mashun1's avatar
v1  
mashun1 committed
79
80
81
				t.Skipf("%s not found", p)
			}

xuxzh1's avatar
init  
xuxzh1 committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
			f, kv, tensors := convertFull(t, os.DirFS(p))
			actual := make(map[string]string)
			for k, v := range kv {
				if s, ok := v.(json.Marshaler); !ok {
					actual[k] = fmt.Sprintf("%v", v)
				} else {
					bts, err := json.Marshal(s)
					if err != nil {
						t.Fatal(err)
					}

					actual[k] = fmt.Sprintf("%x", sha256.Sum256(bts))
				}
			}

			for _, tensor := range tensors.Items {
				sha256sum := sha256.New()
				sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
				if _, err := io.Copy(sha256sum, sr); err != nil {
					t.Fatal(err)
				}
mashun1's avatar
v1  
mashun1 committed
103

xuxzh1's avatar
init  
xuxzh1 committed
104
				actual[tensor.Name] = hex.EncodeToString(sha256sum.Sum(nil))
mashun1's avatar
v1  
mashun1 committed
105
106
			}

xuxzh1's avatar
init  
xuxzh1 committed
107
108
109
			expectFile, err := os.Open(filepath.Join("testdata", fmt.Sprintf("%s.json", tt)))
			if err != nil {
				t.Fatal(err)
mashun1's avatar
v1  
mashun1 committed
110
111
			}

xuxzh1's avatar
init  
xuxzh1 committed
112
113
114
			var expect map[string]string
			if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
				t.Fatal(err)
mashun1's avatar
v1  
mashun1 committed
115
116
			}

xuxzh1's avatar
init  
xuxzh1 committed
117
118
119
120
121
122
123
124
			keys := maps.Keys(expect)
			slices.Sort(keys)
			for _, k := range keys {
				if v, ok := actual[k]; !ok {
					t.Errorf("missing %s", k)
				} else if v != expect[k] {
					t.Errorf("unexpected %s: want %s, got %s", k, expect[k], v)
				}
mashun1's avatar
v1  
mashun1 committed
125
126
127
128
			}
		})
	}
}