Unverified Commit 93a126fb authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Make cached tokenizer pickle-compatible (#17048)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8e4b351a
...@@ -63,14 +63,16 @@ class Request: ...@@ -63,14 +63,16 @@ class Request:
output_len: int output_len: int
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: def sample_tokens(tokenizer: PreTrainedTokenizerBase,
length: int) -> list[int]:
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids)
# Remove the special tokens. # Remove the special tokens.
vocab = { return random.choices(
k: v [v for k, v in vocab.items() if k not in all_special_ids],
for k, v in vocab.items() if k not in tokenizer.all_special_ids k=length,
} )
return random.choices(list(vocab.values()), k=length)
def sample_requests_from_dataset( def sample_requests_from_dataset(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pickle
from copy import deepcopy from copy import deepcopy
import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
def test_cached_tokenizer(): @pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") def test_cached_tokenizer(model_id: str):
reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
trust_remote_code=True)
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"}) reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
reference_tokenizer.add_special_tokens( reference_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<SEP>"]}) {"additional_special_tokens": ["<SEP>"]})
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
_check_consistency(cached_tokenizer, reference_tokenizer)
pickled_tokenizer = pickle.dumps(cached_tokenizer)
unpickled_tokenizer = pickle.loads(pickled_tokenizer)
_check_consistency(unpickled_tokenizer, reference_tokenizer)
def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
assert isinstance(target, type(expected))
# Cached attributes
assert target.all_special_ids == expected.all_special_ids
assert target.all_special_tokens == expected.all_special_tokens
assert (target.all_special_tokens_extended ==
expected.all_special_tokens_extended)
assert target.get_vocab() == expected.get_vocab()
assert len(target) == len(expected)
# Other attributes
assert getattr(target, "padding_side",
None) == getattr(expected, "padding_side", None)
assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode( assert target.encode("prompt") == expected.encode("prompt")
"prompt")
assert set(reference_tokenizer.all_special_ids) == set(
cached_tokenizer.all_special_ids)
assert set(reference_tokenizer.all_special_tokens) == set(
cached_tokenizer.all_special_tokens)
assert set(reference_tokenizer.all_special_tokens_extended) == set(
cached_tokenizer.all_special_tokens_extended)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import contextlib import contextlib
import copy
import os import os
import warnings import warnings
from functools import lru_cache from functools import lru_cache
...@@ -70,18 +71,17 @@ def encode_tokens( ...@@ -70,18 +71,17 @@ def encode_tokens(
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties. """
This will patch the tokenizer object in place.
By default, transformers will recompute multiple tokenizer properties By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown. This each time they are called, leading to a significant slowdown.
function caches these properties for faster access.""" This proxy caches these properties for faster access.
"""
cached_tokenizer = copy.copy(tokenizer)
tokenizer_all_special_ids = set(tokenizer.all_special_ids) tokenizer_all_special_ids = tokenizer.all_special_ids
tokenizer_all_special_tokens = tokenizer.all_special_tokens
tokenizer_all_special_tokens_extended = ( tokenizer_all_special_tokens_extended = (
tokenizer.all_special_tokens_extended) tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_vocab = tokenizer.get_vocab() tokenizer_vocab = tokenizer.get_vocab()
tokenizer_len = len(tokenizer) tokenizer_len = len(tokenizer)
...@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: ...@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
class CachedTokenizer(tokenizer.__class__): # type: ignore class CachedTokenizer(tokenizer.__class__): # type: ignore
@property @property
def all_special_ids(self): def all_special_ids(self) -> list[int]:
return tokenizer_all_special_ids return tokenizer_all_special_ids
@property @property
def all_special_tokens(self): def all_special_tokens(self) -> list[str]:
return tokenizer_all_special_tokens return tokenizer_all_special_tokens
@property @property
def all_special_tokens_extended(self): def all_special_tokens_extended(self) -> list[str]:
return tokenizer_all_special_tokens_extended return tokenizer_all_special_tokens_extended
@property @property
def max_token_id(self): def max_token_id(self) -> int:
return max_token_id return max_token_id
def get_vocab(self): def get_vocab(self) -> dict[str, int]:
return tokenizer_vocab return tokenizer_vocab
def __len__(self): def __len__(self) -> int:
return tokenizer_len return tokenizer_len
def __reduce__(self):
return get_cached_tokenizer, (tokenizer, )
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
tokenizer.__class__ = CachedTokenizer cached_tokenizer.__class__ = CachedTokenizer
return tokenizer return cached_tokenizer
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import importlib import importlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
...@@ -12,17 +12,17 @@ class TokenizerBase(ABC): ...@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
@property @property
@abstractmethod @abstractmethod
def all_special_tokens_extended(self) -> List[str]: def all_special_tokens_extended(self) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
@property @property
@abstractmethod @abstractmethod
def all_special_tokens(self) -> List[str]: def all_special_tokens(self) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
@property @property
@abstractmethod @abstractmethod
def all_special_ids(self) -> List[int]: def all_special_ids(self) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
@property @property
...@@ -66,7 +66,7 @@ class TokenizerBase(ABC): ...@@ -66,7 +66,7 @@ class TokenizerBase(ABC):
@abstractmethod @abstractmethod
def __call__( def __call__(
self, self,
text: Union[str, List[str], List[int]], text: Union[str, list[str], list[int]],
text_pair: Optional[str] = None, text_pair: Optional[str] = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
truncation: bool = False, truncation: bool = False,
...@@ -75,11 +75,11 @@ class TokenizerBase(ABC): ...@@ -75,11 +75,11 @@ class TokenizerBase(ABC):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def get_vocab(self) -> Dict[str, int]: def get_vocab(self) -> dict[str, int]:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def get_added_vocab(self) -> Dict[str, int]: def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
...@@ -88,44 +88,44 @@ class TokenizerBase(ABC): ...@@ -88,44 +88,44 @@ class TokenizerBase(ABC):
text: str, text: str,
truncation: bool = False, truncation: bool = False,
max_length: Optional[int] = None, max_length: Optional[int] = None,
) -> List[int]: ) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def encode(self, def encode(self,
text: str, text: str,
add_special_tokens: Optional[bool] = None) -> List[int]: add_special_tokens: Optional[bool] = None) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def apply_chat_template(self, def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],
tools: Optional[List[Dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
**kwargs) -> List[int]: **kwargs) -> list[int]:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def convert_tokens_to_string(self, tokens: List[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def decode(self, def decode(self,
ids: Union[List[int], int], ids: Union[list[int], int],
skip_special_tokens: bool = True) -> str: skip_special_tokens: bool = True) -> str:
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def convert_ids_to_tokens( def convert_ids_to_tokens(
self, self,
ids: List[int], ids: list[int],
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
) -> List[str]: ) -> list[str]:
raise NotImplementedError() raise NotImplementedError()
class TokenizerRegistry: class TokenizerRegistry:
# Tokenizer name -> (tokenizer module, tokenizer class) # Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY: Dict[str, Tuple[str, str]] = {} REGISTRY: dict[str, tuple[str, str]] = {}
@staticmethod @staticmethod
def register(name: str, module: str, class_name: str) -> None: def register(name: str, module: str, class_name: str) -> None:
......
...@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase): ...@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
# the following attributes are set to fit vLLM's design and are used # the following attributes are set to fit vLLM's design and are used
# by the guided structured output backends. # by the guided structured output backends.
@property @property
def all_special_tokens_extended(self) -> List[str]: def all_special_tokens_extended(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens from mistral_common.tokens.tokenizers.base import SpecialTokens
# tekken defines its own extended special tokens list # tekken defines its own extended special tokens list
...@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase): ...@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
] ]
@property @property
def all_special_tokens(self) -> List[str]: def all_special_tokens(self) -> list[str]:
return self.all_special_tokens_extended return self.all_special_tokens_extended
@property @property
def all_special_ids(self) -> List[int]: def all_special_ids(self) -> list[int]:
return [ return [
self.all_special_tokens.index(t) for t in self.all_special_tokens self.all_special_tokens.index(t) for t in self.all_special_tokens
] ]
...@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase): ...@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
input_ids = self.encode_one(text, truncation, max_length) input_ids = self.encode_one(text, truncation, max_length)
return Encoding(input_ids=input_ids) return Encoding(input_ids=input_ids)
def get_vocab(self) -> Dict[str, int]: def get_vocab(self) -> dict[str, int]:
# NB: the dictionary form of the vocabulary collapses token ids that map # NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes # to the same string but have different bytes
return self._vocab_dict return self._vocab_dict
def get_added_vocab(self) -> Dict[str, int]: def get_added_vocab(self) -> dict[str, int]:
# Mistral tokenizers have no added vocabulary # Mistral tokenizers have no added vocabulary
return {} return {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment