tokenizer_test.go
package convert

import (
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	for k, v := range files {
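		// Write each file inside a closure so the deferred Close runs as
		// soon as the file is written, rather than piling up until the
		// surrounding helper returns.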
		if err := func() error {
			f, err := os.Create(filepath.Join(dir, k))
			if err != nil {
				return err
			}
			defer f.Close()

			if _, err := io.Copy(f, v); err != nil {
				return err
			}

			return nil
		}(); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	return os.DirFS(dir)
}

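// TestParseTokenizer checks that parseTokenizer handles both forms of the
// chat_template field in tokenizer_config.json: a plain string and a list
// of named templates.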
func TestParseTokenizer(t *testing.T) {
	cases := []struct {
		name              string
		fsys              fs.FS
		specialTokenTypes []string
		want              *Tokenizer
	}{
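		// A chat_template given as a plain string is used directly.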
		{
			name: "string chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": "<default template>"
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
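		// A chat_template given as a list of named templates: the entry
		// named "default" is the one that should be selected.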
		{
			name: "list chat template",
			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
				"tokenizer.json": strings.NewReader(`{}`),
				"tokenizer_config.json": strings.NewReader(`{
					"chat_template": [
						{
							"name": "default",
							"template": "<default template>"
						},
						{
							"name": "tools",
							"template": "<tools template>"
						}
					]
				}`),
			}),
			want: &Tokenizer{
				Vocabulary: &Vocabulary{Model: "gpt2"},
				Pre:        "default",
				Template:   "<default template>",
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
				t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
			}
		})
	}
}