package model

import (
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)

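// loadSentencePieceVocab reads the Gemma 2 tokenizer.model protobuf from
// testdata and builds a SentencePieceModel from its pieces.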
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
	if err != nil {
		t.Fatal(err)
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		t.Fatal(err)
	}

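	// Pre-tokenizer regex: splits input into contraction suffixes, letter
	// runs, 1-3 digit groups, punctuation runs, and whitespace before
	// piece-level encoding.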
	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

	var v Vocabulary

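	// Copy each piece's surface form, score, and type into the vocabulary.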
	for _, piece := range spm.GetPieces() {
		v.Values = append(v.Values, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
		switch typ := piece.GetType(); typ {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, uint32(typ))
		default:
			// TODO: parse the special tokens file
			//   - this will roundtrip correctly, but the <start_of_turn> and
			//     <end_of_turn> tokens aren't processed
			v.Types = append(v.Types, uint32(sentencepiece.ModelProto_SentencePiece_NORMAL))
		}
	}

	return NewSentencePieceModel(preTokenizer, &v)
}

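// TestSentencePieceEncode verifies that encoding round-trips through decoding
// for a variety of inputs, and that special tokens map to their expected IDs.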
func TestSentencePieceEncode(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
			"Numbers and symbols: 123456789 +- */",
			"Special tokens: <bos> text <eos>",
			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
			"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
		}

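		// Each input must decode back to exactly the string that was encoded.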
		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})

	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

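		// Expected token IDs for the Gemma 2 test vocabulary: <bos> is 2,
		// <eos> is 1.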
		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want.token, true)
			if err != nil {
				t.Fatal(err)
			}
			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}