Commit b70fc4d5, authored by Jesse Gross and committed by Jesse Gross
Browse files

model: Don't unconditionally add special tokens

We sometimes tokenize partial strings. For example, with
multimodal inputs, we split the input string around the images
and then tokenize each piece. In these cases, we should only add
the special tokens on the first piece.
parent e2252d0f
...@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) ...@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
return s.llamaModel.Tokenize(content, false, true) return s.llamaModel.Tokenize(content, false, true)
} }
if s.textProcessor != nil { if s.textProcessor != nil {
tokens, err := s.textProcessor.Encode(content) tokens, err := s.textProcessor.Encode(content, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -19,7 +19,7 @@ const ( ...@@ -19,7 +19,7 @@ const (
) )
type TextProcessor interface { type TextProcessor interface {
Encode(string) ([]int32, error) Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error) Decode([]int32) (string, error)
Is(int32, Special) bool Is(int32, Special) bool
} }
...@@ -144,7 +144,7 @@ type merge struct { ...@@ -144,7 +144,7 @@ type merge struct {
runes []rune runes []rune
} }
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}} fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() { for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently // TODO: process special tokens concurrently
...@@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { ...@@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
} }
} }
if len(ids) > 0 { if addSpecial && len(ids) > 0 {
if bpe.vocab.AddBOS { if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS { if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
......
...@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) { ...@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
t.Parallel() t.Parallel()
ids, err := tokenizer.Encode("hello world") ids, err := tokenizer.Encode("hello world", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) { ...@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
t.Errorf("got %q, want hello world", s) t.Errorf("got %q, want hello world", s)
} }
ids, err = tokenizer.Encode("hello <|end_of_text|>") ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) { ...@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
} }
for s, want := range cases { for s, want := range cases {
ids, err := tokenizer.Encode(s) ids, err := tokenizer.Encode(s, true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) { ...@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
} }
for _, want := range cases { for _, want := range cases {
ids, err := tokenizer.Encode(want) ids, err := tokenizer.Encode(want, true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) { ...@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
} }
for s, want := range cases { for s, want := range cases {
ids, err := tokenizer.Encode(s) ids, err := tokenizer.Encode(s, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { ...@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for range b.N { for range b.N {
_, err := tokenizer.Encode(string(bts)) _, err := tokenizer.Encode(string(bts), true)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
...@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { ...@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}) })
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts)) ids, err := tokenizer.Encode(string(bts), true)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
......
...@@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { ...@@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
for i, part := range parts { for i, part := range parts {
// text - tokenize // text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part) tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment