Unverified commit 0535e5fe, authored by Patrick von Platen, committed by GitHub
Browse files

Fix edge case Mistral tokenizer (#10152)

parent b489fc3c
...@@ -72,11 +72,12 @@ class MistralTokenizer: ...@@ -72,11 +72,12 @@ class MistralTokenizer:
self.instruct = tokenizer.instruct_tokenizer self.instruct = tokenizer.instruct_tokenizer
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
if isinstance(tokenizer_, Tekkenizer): self.is_tekken = isinstance(tokenizer_, Tekkenizer)
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
if self.is_tekken:
# Make sure special tokens will not raise # Make sure special tokens will not raise
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
elif self.is_spm:
elif isinstance(tokenizer_, SentencePieceTokenizer):
pass pass
else: else:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
...@@ -218,7 +219,7 @@ class MistralTokenizer: ...@@ -218,7 +219,7 @@ class MistralTokenizer:
return encoded.tokens return encoded.tokens
def convert_tokens_to_string(self, tokens: List[str]) -> str: def convert_tokens_to_string(self, tokens: List[str]) -> str:
if isinstance(self.tokenizer, Tekkenizer): if self.is_tekken:
tokens = [ tokens = [
t for t in tokens t for t in tokens
if t not in self.tokenizer._all_special_tokens if t not in self.tokenizer._all_special_tokens
...@@ -270,21 +271,20 @@ class MistralTokenizer: ...@@ -270,21 +271,20 @@ class MistralTokenizer:
skip_special_tokens skip_special_tokens
), "skip_special_tokens=False is not supported for Mistral tokenizers." ), "skip_special_tokens=False is not supported for Mistral tokenizers."
assert isinstance(self.tokenizer, assert self.is_tekken or self.is_spm, type(self.tokenizer)
(Tekkenizer, SentencePieceTokenizer)), type(
self.tokenizer)
if isinstance(self.tokenizer, Tekkenizer): if self.is_tekken:
# skip special tokens # skip special tokens
ids = [i for i in ids if i > self.tokenizer.num_special_tokens] ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
tokens = [self.tokenizer.id_to_piece(id) for id in ids] tokens = [self.tokenizer.id_to_piece(id) for id in ids]
if any("�" in t for t in tokens): if any("�" in t for t in tokens) and self.is_tekken:
# if a decoded token contains the replacement character, then the # if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes # token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640 # See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625 # https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizer is sentencepiece, we just add "�"
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
return tokens return tokens
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment