process_text_spm_test.go 3.84 KB
Newer Older
Patrick Devine's avatar
Patrick Devine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package model

import (
	"log"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)

func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
	if err != nil {
		t.Fatal(err)
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		t.Fatal(err)
	}

	var v Vocabulary

	for _, piece := range spm.GetPieces() {
		v.Values = append(v.Values, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, uint32(t))
		default:
			tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			// todo parse the special tokens file
			//   - this will roundtrip correctly but the <start_of_turn> and
			//     <end_of_turn> tokens aren't processed
			v.Types = append(v.Types, tt)
		}
	}

48
	return NewSentencePieceModel(&v)
Patrick Devine's avatar
Patrick Devine committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
}

func TestSentencePieceEncode(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
71
72
73
74
75
76
77
78
			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
			"Numbers and symbols: 123456789 +- */",
			"Special tokens: <bos> text <eos>",
			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
			"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
Patrick Devine's avatar
Patrick Devine committed
79
80
81
		}

		for _, want := range cases {
Jesse Gross's avatar
Jesse Gross committed
82
			ids, err := tokenizer.Encode(want, true)
Patrick Devine's avatar
Patrick Devine committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})

	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
Jesse Gross's avatar
Jesse Gross committed
107
			ids, err := tokenizer.Encode(want.token, true)
Patrick Devine's avatar
Patrick Devine committed
108
109
110
111
112
113
114
115
116
			if err != nil {
				t.Fatal(err)
			}
			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

// TestSentencePieceModelDecodeByteTokens checks that byte tokens written as
// "<0xHH>" decode to the single raw byte they name. It covers three cases:
//   - a lone byte that is not valid UTF-8 on its own (0xEA);
//   - an ASCII byte (0x41);
//   - two byte tokens that together form one UTF-8 character (0xC3 0xA3 = "ã").
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
	vocab := &Vocabulary{
		Values: []string{
			"normal",
			"<0xEA>",
			"<0x41>",
			"<0xC3>",
			"<0xA3>",
		},
		Types: []uint32{
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_BYTE,
			TOKEN_TYPE_BYTE,
			TOKEN_TYPE_BYTE,
			TOKEN_TYPE_BYTE,
		},
		Scores: []float32{0, 0, 0, 0, 0},
	}

	spm := NewSentencePieceModel(vocab)

	tests := []struct {
		name     string
		ids      []int32
		expected string
	}{
		{
			name:     "single byte token",
			ids:      []int32{1},
			expected: "\xea",
		},
		{
			name:     "ASCII byte token",
			ids:      []int32{2},
			expected: "A",
		},
		{
			name:     "multiple byte tokens forming UTF-8 character",
			ids:      []int32{3, 4},
			expected: "ã",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := spm.Decode(tt.ids)
			if err != nil {
				// Stop here: the result of a failed decode is meaningless, and
				// comparing it would only add a misleading second failure.
				t.Fatalf("Decode(%v): %v", tt.ids, err)
			}
			if result != tt.expected {
				t.Errorf("Decode(%v) = %q, want %q", tt.ids, result, tt.expected)
			}
		})
	}
}