Commit 7916f550 authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

vocab: Use int32 for special tokens

Special tokens are currently read as uint32 from the model metadata.
However, all other parts of the system (including the tokenizer) use
int32 to represent tokens so it is impossible to represent the high
portion of the unsigned range. For consistency and to avoid casts,
we should just use int32 everywhere.
parent d650ad39
...@@ -35,8 +35,8 @@ func New(c ml.Config) (model.Model, error) { ...@@ -35,8 +35,8 @@ func New(c ml.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: c.Uint("tokenizer.ggml.eos_token_id"), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
}, },
), ),
Layers: make([]Layer, c.Uint("block_count")), Layers: make([]Layer, c.Uint("block_count")),
......
...@@ -26,8 +26,8 @@ func New(c ml.Config) (model.Model, error) { ...@@ -26,8 +26,8 @@ func New(c ml.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: c.Uint("tokenizer.ggml.eos_token_id"), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
}, },
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
......
...@@ -21,7 +21,7 @@ const ( ...@@ -21,7 +21,7 @@ const (
type TextProcessor interface { type TextProcessor interface {
Encode(string) ([]int32, error) Encode(string) ([]int32, error)
Decode([]int32) (string, error) Decode([]int32) (string, error)
Is(uint32, Special) bool Is(int32, Special) bool
} }
type Vocabulary struct { type Vocabulary struct {
...@@ -30,7 +30,7 @@ type Vocabulary struct { ...@@ -30,7 +30,7 @@ type Vocabulary struct {
Scores []uint32 Scores []uint32
Merges []string Merges []string
BOS, EOS uint32 BOS, EOS int32
specialOnce sync.Once specialOnce sync.Once
special []string special []string
...@@ -42,7 +42,7 @@ type Vocabulary struct { ...@@ -42,7 +42,7 @@ type Vocabulary struct {
merge map[string]int32 merge map[string]int32
} }
func (v *Vocabulary) Is(id uint32, special Special) bool { func (v *Vocabulary) Is(id int32, special Special) bool {
switch special { switch special {
case SpecialBOS: case SpecialBOS:
return id == v.BOS return id == v.BOS
...@@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { ...@@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
} }
} }
func (bpe BytePairEncoding) Is(id uint32, special Special) bool { func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special) return bpe.vocab.Is(id, special)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment