Unverified Commit 6e2d04e4 authored by Joshua Lochner, committed by GitHub

Fix slow GemmaTokenizer and improve SPM slow -> fast conversion process (#32191)

* Remove user-defined tokens which can be obtained through merges

* Remove debug line

* formatting

* Refactor spm slow -> fast converter

* revert unnecessary refactor

* set comprehension

* remove test files

* Use `vocab_scores`

* Always replace spiece underline with space in decode

* we no longer need token filtering

* Add save fast load slow unit test

* Remove tokenizers version check

* Remove duplicate code

* Make `<start_of_turn>` and `<end_of_turn>` special tokens

* Bias merge priority with length if score is the same

* Add unit test for merge priority

* CI
parent 026a173a
@@ -53,6 +53,25 @@ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
    return prepend_scheme
def generate_merges(vocab, vocab_scores):
reverse = vocab_scores is not None
vocab_scores = dict(vocab_scores) if reverse else vocab
merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)
merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
return merges
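For context on the new tie-break (the "Bias merge priority with length if score is the same" item in the commit message), here is a quick illustrative call to the helper above with a toy vocabulary; the inputs and expected output are made up for illustration and are not part of the diff:

# Toy check of generate_merges (not from the PR): both candidate merges share
# the same score, so the pair with the longer left piece is ranked first.
toy_vocab = {"a": 0, "b": 1, "ab": 2, "abb": 3}
toy_scores = [("a", -1.0), ("b", -1.0), ("ab", -2.0), ("abb", -2.0)]

print(generate_merges(toy_vocab, toy_scores))
# expected: [('ab', 'b'), ('a', 'b')], i.e. the longer merge wins the tie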
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
@@ -73,24 +92,8 @@ class SentencePieceExtractor:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        if vocab_scores is not None:
vocab_scores, reverse = dict(vocab_scores), True
else:
vocab_scores, reverse = vocab, False
# Merges
merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
        merges = generate_merges(vocab, vocab_scores)

        return vocab, merges
@@ -107,24 +110,7 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor):
        # "<0x09>" is the bytefallback for `\t`
        vocab["\t"] = vocab.get("<0x09>")

        if vocab_scores is not None:
vocab_scores, reverse = dict(vocab_scores), True
else:
vocab_scores, reverse = vocab, False
# Merges
merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
        merges = generate_merges(vocab, vocab_scores)

        return vocab, merges
@@ -544,6 +530,10 @@ class DebertaConverter(Converter):
class SpmConverter(Converter):
handle_byte_fallback = False
SpmExtractor = SentencePieceExtractor
special_tokens = {}
    def __init__(self, *args):
        requires_backends(self, "protobuf")
@@ -557,14 +547,13 @@ class SpmConverter(Converter):
            m.ParseFromString(f.read())
        self.proto = m
        if self.proto.trainer_spec.byte_fallback:
            if not getattr(self, "handle_byte_fallback", None):
                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

        if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
            warnings.warn(
                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                "unknown tokens into a sequence of byte tokens matching the original piece of text."
            )
    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]
@@ -575,12 +564,18 @@ class SpmConverter(Converter):
    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)
unk_id = self.unk_id(proto)
        if model_type == 1:
            tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
            tokenizer = Tokenizer(
Unigram(
vocab_scores,
unk_id=self.unk_id(proto),
byte_fallback=self.handle_byte_fallback,
)
)
        elif model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
@@ -588,13 +583,53 @@ class SpmConverter(Converter):
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
byte_fallback=self.handle_byte_fallback,
dropout=None,
                )
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
            )
# control tokens are special
# user defined symbols are not
# both user and control tokens are AddedTokens
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
spm_added_tokens = [
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
for id, p in enumerate(proto.pieces)
if p.type in [3, 4]
]
tokens_to_add = [
AddedToken(token, normalized=False, special=special)
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
]
if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
# individual tokens would repeatedly rebuild a trie, which can be slow.
is_last_special = None
tokens = []
for token in tokens_to_add:
is_special = token.special
if is_last_special is None or is_last_special == is_special:
tokens.append(token)
else:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
tokens = [token]
is_last_special = is_special
if tokens:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
        return tokenizer
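The token-adding loop above is run-length grouping of consecutive AddedTokens by their special flag, so that add_special_tokens / add_tokens are each called once per run instead of once per token (which would rebuild the trie repeatedly). As a standalone sketch, the same grouping could also be written with itertools.groupby; this is an illustration of the pattern, not the code the PR uses, and it assumes tokenizer and tokens_to_add already exist:

from itertools import groupby

# Group consecutive tokens by their `special` flag and add each run in one call.
# `tokenizer` is a tokenizers.Tokenizer and `tokens_to_add` a list of AddedToken,
# both assumed to be defined as in the method above.
for is_special, run in groupby(tokens_to_add, key=lambda t: t.special):
    batch = list(run)
    if is_special:
        tokenizer.add_special_tokens(batch)
    else:
        tokenizer.add_tokens(batch)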
    def normalizer(self, proto):
@@ -622,40 +657,6 @@ class SpmConverter(Converter):
    def converted(self) -> Tokenizer:
        tokenizer = self.tokenizer(self.proto)
# control tokens are special
# user defined symbols are not
# both user and control tokens are AddedTokens
# Add user defined symbols (type == 4) from sentnecepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
tokens_to_add = {
id: AddedToken(token, normalized=False, special=special)
for id, token, special in [
(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]
]
}
tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])]
if len(tokens_to_add) > 0:
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
# individual tokens would repeatedly rebuild a trie, which can be slow.
is_last_special = None
tokens = []
for token in tokens_to_add:
is_special = token.special
if is_last_special is None or is_last_special == is_special:
tokens.append(token)
else:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
tokens = [token]
is_last_special = is_special
if tokens:
if is_last_special:
tokenizer.add_special_tokens(tokens)
else:
tokenizer.add_tokens(tokens)
        # Tokenizer assemble
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
@@ -1283,6 +1284,9 @@ class XGLMConverter(SpmConverter):
class GemmaConvert(SpmConverter):
    handle_byte_fallback = True
SpmExtractor = GemmaSentencePieceExtractor
# start and end of turn tokens must be marked as special
special_tokens = {"<start_of_turn>", "<end_of_turn>"}
"""" """"
split_by_unicode_script: true split_by_unicode_script: true
...@@ -1327,45 +1331,6 @@ class GemmaConvert(SpmConverter): ...@@ -1327,45 +1331,6 @@ class GemmaConvert(SpmConverter):
] ]
) )
def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
if model_type == 1:
import tokenizers
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
else:
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
elif model_type == 2:
_, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
byte_fallback=True,
dropout=None,
)
)
tokenizer.add_special_tokens(
[
AddedToken("<pad>", normalized=False, special=True),
AddedToken("<eos>", normalized=False, special=True),
AddedToken("<bos>", normalized=False, special=True),
AddedToken("<unk>", normalized=False, special=True),
]
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)
return tokenizer
class LlamaConverter(SpmConverter):
    handle_byte_fallback = True
@@ -1393,37 +1358,6 @@ class LlamaConverter(SpmConverter):
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)
def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
if model_type == 1:
import tokenizers
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
else:
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
elif model_type == 2:
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
)
tokenizer.add_special_tokens(
[
AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
]
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)
return tokenizer
    def normalizer(self, proto):
        if getattr(self.original_tokenizer, "legacy", True):
            sequence = []
...
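For orientation, the converter edits above are what run when a SentencePiece-based slow tokenizer is converted to a fast backend, with GemmaConvert now reusing the shared SpmConverter.tokenizer() and only overriding class attributes (SpmExtractor, special_tokens, handle_byte_fallback). A minimal, hypothetical usage sketch follows; the tokenizer.model path is a placeholder:

from transformers import GemmaTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

# Placeholder path to a Gemma SentencePiece model file.
slow = GemmaTokenizer("/path/to/tokenizer.model")
# Dispatches to GemmaConvert (a SpmConverter subclass) and returns a tokenizers.Tokenizer.
fast_backend = convert_slow_tokenizer(slow)
# The chat turn markers should now be registered as special added tokens.
print(fast_backend.encode("<start_of_turn>user hello").ids)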
@@ -198,7 +198,7 @@ class GemmaTokenizer(PreTrainedTokenizer):
        else:
            sub_texts = "".join(sub_texts)
        return sub_texts
        return sub_texts.replace(SPIECE_UNDERLINE, " ")
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
...
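The one-line decode change above makes the slow GemmaTokenizer always render the SentencePiece word-boundary marker as a plain space. A tiny illustration of what the replacement does (the decoded string is made up; SPIECE_UNDERLINE is the "▁" character, U+2581):

SPIECE_UNDERLINE = "\u2581"  # "▁", the SentencePiece word-boundary marker

decoded = "<bos>\u2581Hello\u2581world"
print(decoded.replace(SPIECE_UNDERLINE, " "))
# prints "<bos> Hello world"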
@@ -222,6 +222,17 @@ class GemmaIntegrationTest(unittest.TestCase):
        self.tokenizer.add_eos_token = False
        self.rust_tokenizer.add_eos_token = False
def test_fast_merge_priority(self):
slow_tokenizer = self.tokenizer
fast_tokenizer = self.rust_tokenizer
text = " "
target = [168, 153]
slow = slow_tokenizer.encode(text, add_special_tokens=False)
assert slow == target
fast = fast_tokenizer.encode(text, add_special_tokens=False)
assert fast == target
    @unittest.skip(reason="Not super important and always failing. Let's skip it")
    @slow
    def test_conversion(self):
@@ -442,6 +453,30 @@ class GemmaIntegrationTest(unittest.TestCase):
        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
            self.assertListEqual(tokenized_chat, expected_tokens)
def test_save_fast_load_slow(self):
# Ensure that we can save a fast tokenizer and load it as a slow tokenizer
slow_tokenizer = self.tokenizer
text = "a "
target_encoded = [2, 235250, 139]
slow = slow_tokenizer.encode(text, add_special_tokens=True)
assert slow == target_encoded
slow_decoded = slow_tokenizer.decode(slow, skip_special_tokens=True)
assert slow_decoded == text
with tempfile.TemporaryDirectory() as dirname:
# Save fast tokenizer
self.rust_tokenizer.save_pretrained(dirname)
# Load slow tokenizer with fast files present in the directory
slow_tokenizer_from_fast = GemmaTokenizer.from_pretrained(dirname)
slow_from_fast = slow_tokenizer_from_fast.encode(text, add_special_tokens=True)
assert slow_from_fast == target_encoded
slow_from_fast_decoded = slow_tokenizer_from_fast.decode(slow, skip_special_tokens=True)
assert slow_from_fast_decoded == text
@require_sentencepiece
@require_tokenizers
...