Unverified Commit 73444b7b authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

Performance fix MistralTokenizer: cache special ids and tokens (#27925)


Signed-off-by: default avatarJulien Denize <julien.denize@mistral.ai>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 853a8eb5
...@@ -191,6 +191,12 @@ class MistralTokenizer(TokenizerBase): ...@@ -191,6 +191,12 @@ class MistralTokenizer(TokenizerBase):
# Sort the dict for convenience # Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1])) self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids()
self._special_token_ids_set = set(self._special_token_ids)
self._special_tokens = self._get_special_tokens(self._special_token_ids)
self._special_tokens_set = set(self._special_tokens)
# Vocab sorted by token id. # Vocab sorted by token id.
self._vocab = self.tokenizer._vocab self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1 self._max_token_id = self.vocab_size - 1
...@@ -210,23 +216,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -210,23 +216,7 @@ class MistralTokenizer(TokenizerBase):
) )
) )
# the following attributes are set to fit vLLM's design and are used def _get_special_token_ids(self) -> list[int]:
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
return self.all_special_tokens
@property
def all_special_tokens(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in self.all_special_ids
]
@property
def all_special_ids(self) -> list[int]:
from mistral_common.tokens.tokenizers.sentencepiece import ( from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer, SentencePieceTokenizer,
) )
...@@ -244,6 +234,28 @@ class MistralTokenizer(TokenizerBase): ...@@ -244,6 +234,28 @@ class MistralTokenizer(TokenizerBase):
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids) return sorted(special_ids)
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in all_special_ids
]
# the following attributes are set to fit vLLM's design and are used
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
return self.all_special_tokens
@property
def all_special_tokens(self) -> list[str]:
return self._special_tokens
@property
def all_special_ids(self) -> list[int]:
return self._special_token_ids
@property @property
def bos_token_id(self) -> int: def bos_token_id(self) -> int:
return self.tokenizer.bos_id return self.tokenizer.bos_id
...@@ -277,21 +289,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -277,21 +289,7 @@ class MistralTokenizer(TokenizerBase):
raise NotImplementedError() raise NotImplementedError()
def _is_special_token_id(self, token_id: int) -> bool: def _is_special_token_id(self, token_id: int) -> bool:
from mistral_common.tokens.tokenizers.sentencepiece import ( return token_id in self._special_token_ids_set
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
return token_id in self.tokenizer._control_tokens
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
return token_id < self.tokenizer.num_special_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
def __len__(self) -> int: def __len__(self) -> int:
return self.vocab_size return self.vocab_size
...@@ -405,7 +403,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -405,7 +403,7 @@ class MistralTokenizer(TokenizerBase):
tokens = [ tokens = [
t t
for t in tokens for t in tokens
if (t in to_decode_special_tokens or t not in self.all_special_tokens) if (t in to_decode_special_tokens or t not in self._special_tokens_set)
] ]
if any(isinstance(t, bytes) for t in tokens): if any(isinstance(t, bytes) for t in tokens):
...@@ -489,7 +487,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -489,7 +487,7 @@ class MistralTokenizer(TokenizerBase):
# We filtered unwanted special tokens so we can decode the rest. # We filtered unwanted special tokens so we can decode the rest.
tokens = [ tokens = [
self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP) self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
if token_id not in self.all_special_ids if token_id not in self._special_token_ids_set
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP) else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
for token_id in ids_kept for token_id in ids_kept
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment