Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93a126fb
Unverified
Commit
93a126fb
authored
Apr 27, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 27, 2025
Browse files
[Misc] Make cached tokenizer pickle-compatible (#17048)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
8e4b351a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
80 additions
and
56 deletions
+80
-56
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+8
-6
tests/tokenization/test_cached_tokenizer.py
tests/tokenization/test_cached_tokenizer.py
+31
-12
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+19
-16
vllm/transformers_utils/tokenizer_base.py
vllm/transformers_utils/tokenizer_base.py
+17
-17
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+5
-5
No files found.
benchmarks/benchmark_prefix_caching.py
View file @
93a126fb
...
@@ -63,14 +63,16 @@ class Request:
...
@@ -63,14 +63,16 @@ class Request:
output_len
:
int
output_len
:
int
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
str
:
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
list
[
int
]:
vocab
=
tokenizer
.
get_vocab
()
vocab
=
tokenizer
.
get_vocab
()
all_special_ids
=
set
(
tokenizer
.
all_special_ids
)
# Remove the special tokens.
# Remove the special tokens.
vocab
=
{
return
random
.
choices
(
k
:
v
[
v
for
k
,
v
in
vocab
.
items
()
if
k
not
in
all_special_ids
],
for
k
,
v
in
vocab
.
items
()
if
k
not
in
tokenizer
.
all_special_ids
k
=
length
,
}
)
return
random
.
choices
(
list
(
vocab
.
values
()),
k
=
length
)
def
sample_requests_from_dataset
(
def
sample_requests_from_dataset
(
...
...
tests/tokenization/test_cached_tokenizer.py
View file @
93a126fb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
copy
import
deepcopy
from
copy
import
deepcopy
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm.transformers_utils.tokenizer
import
get_cached_tokenizer
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
get_cached_tokenizer
)
def
test_cached_tokenizer
():
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"gpt2"
,
"THUDM/chatglm3-6b"
])
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
def
test_cached_tokenizer
(
model_id
:
str
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_id
,
trust_remote_code
=
True
)
reference_tokenizer
.
add_special_tokens
({
"cls_token"
:
"<CLS>"
})
reference_tokenizer
.
add_special_tokens
({
"cls_token"
:
"<CLS>"
})
reference_tokenizer
.
add_special_tokens
(
reference_tokenizer
.
add_special_tokens
(
{
"additional_special_tokens"
:
[
"<SEP>"
]})
{
"additional_special_tokens"
:
[
"<SEP>"
]})
cached_tokenizer
=
get_cached_tokenizer
(
deepcopy
(
reference_tokenizer
))
cached_tokenizer
=
get_cached_tokenizer
(
deepcopy
(
reference_tokenizer
))
_check_consistency
(
cached_tokenizer
,
reference_tokenizer
)
pickled_tokenizer
=
pickle
.
dumps
(
cached_tokenizer
)
unpickled_tokenizer
=
pickle
.
loads
(
pickled_tokenizer
)
_check_consistency
(
unpickled_tokenizer
,
reference_tokenizer
)
def
_check_consistency
(
target
:
AnyTokenizer
,
expected
:
AnyTokenizer
):
assert
isinstance
(
target
,
type
(
expected
))
# Cached attributes
assert
target
.
all_special_ids
==
expected
.
all_special_ids
assert
target
.
all_special_tokens
==
expected
.
all_special_tokens
assert
(
target
.
all_special_tokens_extended
==
expected
.
all_special_tokens_extended
)
assert
target
.
get_vocab
()
==
expected
.
get_vocab
()
assert
len
(
target
)
==
len
(
expected
)
# Other attributes
assert
getattr
(
target
,
"padding_side"
,
None
)
==
getattr
(
expected
,
"padding_side"
,
None
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
cached_tokenizer
.
encode
(
assert
target
.
encode
(
"prompt"
)
==
expected
.
encode
(
"prompt"
)
"prompt"
)
assert
set
(
reference_tokenizer
.
all_special_ids
)
==
set
(
cached_tokenizer
.
all_special_ids
)
assert
set
(
reference_tokenizer
.
all_special_tokens
)
==
set
(
cached_tokenizer
.
all_special_tokens
)
assert
set
(
reference_tokenizer
.
all_special_tokens_extended
)
==
set
(
cached_tokenizer
.
all_special_tokens_extended
)
vllm/transformers_utils/tokenizer.py
View file @
93a126fb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
contextlib
import
contextlib
import
copy
import
os
import
os
import
warnings
import
warnings
from
functools
import
lru_cache
from
functools
import
lru_cache
...
@@ -70,18 +71,17 @@ def encode_tokens(
...
@@ -70,18 +71,17 @@ def encode_tokens(
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
"""Get tokenizer with cached properties.
"""
This will patch the tokenizer object in place.
By default, transformers will recompute multiple tokenizer properties
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown. This
each time they are called, leading to a significant slowdown.
function caches these properties for faster access."""
This proxy caches these properties for faster access.
"""
cached_tokenizer
=
copy
.
copy
(
tokenizer
)
tokenizer_all_special_ids
=
set
(
tokenizer
.
all_special_ids
)
tokenizer_all_special_ids
=
tokenizer
.
all_special_ids
tokenizer_all_special_tokens
=
tokenizer
.
all_special_tokens
tokenizer_all_special_tokens_extended
=
(
tokenizer_all_special_tokens_extended
=
(
tokenizer
.
all_special_tokens_extended
)
tokenizer
.
all_special_tokens_extended
)
tokenizer_all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
tokenizer_vocab
=
tokenizer
.
get_vocab
()
tokenizer_vocab
=
tokenizer
.
get_vocab
()
tokenizer_len
=
len
(
tokenizer
)
tokenizer_len
=
len
(
tokenizer
)
...
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
...
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
class
CachedTokenizer
(
tokenizer
.
__class__
):
# type: ignore
class
CachedTokenizer
(
tokenizer
.
__class__
):
# type: ignore
@
property
@
property
def
all_special_ids
(
self
):
def
all_special_ids
(
self
)
->
list
[
int
]
:
return
tokenizer_all_special_ids
return
tokenizer_all_special_ids
@
property
@
property
def
all_special_tokens
(
self
):
def
all_special_tokens
(
self
)
->
list
[
str
]
:
return
tokenizer_all_special_tokens
return
tokenizer_all_special_tokens
@
property
@
property
def
all_special_tokens_extended
(
self
):
def
all_special_tokens_extended
(
self
)
->
list
[
str
]
:
return
tokenizer_all_special_tokens_extended
return
tokenizer_all_special_tokens_extended
@
property
@
property
def
max_token_id
(
self
):
def
max_token_id
(
self
)
->
int
:
return
max_token_id
return
max_token_id
def
get_vocab
(
self
):
def
get_vocab
(
self
)
->
dict
[
str
,
int
]
:
return
tokenizer_vocab
return
tokenizer_vocab
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
tokenizer_len
return
tokenizer_len
def
__reduce__
(
self
):
return
get_cached_tokenizer
,
(
tokenizer
,
)
CachedTokenizer
.
__name__
=
f
"Cached
{
tokenizer
.
__class__
.
__name__
}
"
CachedTokenizer
.
__name__
=
f
"Cached
{
tokenizer
.
__class__
.
__name__
}
"
tokenizer
.
__class__
=
CachedTokenizer
cached_
tokenizer
.
__class__
=
CachedTokenizer
return
tokenizer
return
cached_
tokenizer
def
patch_padding_side
(
tokenizer
:
PreTrainedTokenizer
)
->
None
:
def
patch_padding_side
(
tokenizer
:
PreTrainedTokenizer
)
->
None
:
...
...
vllm/transformers_utils/tokenizer_base.py
View file @
93a126fb
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
importlib
import
importlib
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
...
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
...
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
@
property
@
property
@
abstractmethod
@
abstractmethod
def
all_special_tokens_extended
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens_extended
(
self
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
property
@
property
@
abstractmethod
@
abstractmethod
def
all_special_tokens
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens
(
self
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
property
@
property
@
abstractmethod
@
abstractmethod
def
all_special_ids
(
self
)
->
L
ist
[
int
]:
def
all_special_ids
(
self
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
property
@
property
...
@@ -66,7 +66,7 @@ class TokenizerBase(ABC):
...
@@ -66,7 +66,7 @@ class TokenizerBase(ABC):
@
abstractmethod
@
abstractmethod
def
__call__
(
def
__call__
(
self
,
self
,
text
:
Union
[
str
,
L
ist
[
str
],
L
ist
[
int
]],
text
:
Union
[
str
,
l
ist
[
str
],
l
ist
[
int
]],
text_pair
:
Optional
[
str
]
=
None
,
text_pair
:
Optional
[
str
]
=
None
,
add_special_tokens
:
bool
=
False
,
add_special_tokens
:
bool
=
False
,
truncation
:
bool
=
False
,
truncation
:
bool
=
False
,
...
@@ -75,11 +75,11 @@ class TokenizerBase(ABC):
...
@@ -75,11 +75,11 @@ class TokenizerBase(ABC):
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
get_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_vocab
(
self
)
->
d
ict
[
str
,
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
get_added_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_added_vocab
(
self
)
->
d
ict
[
str
,
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
...
@@ -88,44 +88,44 @@ class TokenizerBase(ABC):
...
@@ -88,44 +88,44 @@ class TokenizerBase(ABC):
text
:
str
,
text
:
str
,
truncation
:
bool
=
False
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
max_length
:
Optional
[
int
]
=
None
,
)
->
L
ist
[
int
]:
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
encode
(
self
,
def
encode
(
self
,
text
:
str
,
text
:
str
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
L
ist
[
int
]:
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
apply_chat_template
(
self
,
def
apply_chat_template
(
self
,
messages
:
L
ist
[
"ChatCompletionMessageParam"
],
messages
:
l
ist
[
"ChatCompletionMessageParam"
],
tools
:
Optional
[
L
ist
[
D
ict
[
str
,
Any
]]]
=
None
,
tools
:
Optional
[
l
ist
[
d
ict
[
str
,
Any
]]]
=
None
,
**
kwargs
)
->
L
ist
[
int
]:
**
kwargs
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
convert_tokens_to_string
(
self
,
tokens
:
L
ist
[
str
])
->
str
:
def
convert_tokens_to_string
(
self
,
tokens
:
l
ist
[
str
])
->
str
:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
decode
(
self
,
def
decode
(
self
,
ids
:
Union
[
L
ist
[
int
],
int
],
ids
:
Union
[
l
ist
[
int
],
int
],
skip_special_tokens
:
bool
=
True
)
->
str
:
skip_special_tokens
:
bool
=
True
)
->
str
:
raise
NotImplementedError
()
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
convert_ids_to_tokens
(
def
convert_ids_to_tokens
(
self
,
self
,
ids
:
L
ist
[
int
],
ids
:
l
ist
[
int
],
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
)
->
L
ist
[
str
]:
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
raise
NotImplementedError
()
class
TokenizerRegistry
:
class
TokenizerRegistry
:
# Tokenizer name -> (tokenizer module, tokenizer class)
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY
:
D
ict
[
str
,
T
uple
[
str
,
str
]]
=
{}
REGISTRY
:
d
ict
[
str
,
t
uple
[
str
,
str
]]
=
{}
@
staticmethod
@
staticmethod
def
register
(
name
:
str
,
module
:
str
,
class_name
:
str
)
->
None
:
def
register
(
name
:
str
,
module
:
str
,
class_name
:
str
)
->
None
:
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
93a126fb
...
@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
...
@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
# the following attributes are set to fit vLLM's design and are used
# the following attributes are set to fit vLLM's design and are used
# by the guided structured output backends.
# by the guided structured output backends.
@
property
@
property
def
all_special_tokens_extended
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens_extended
(
self
)
->
l
ist
[
str
]:
from
mistral_common.tokens.tokenizers.base
import
SpecialTokens
from
mistral_common.tokens.tokenizers.base
import
SpecialTokens
# tekken defines its own extended special tokens list
# tekken defines its own extended special tokens list
...
@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
...
@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
]
]
@
property
@
property
def
all_special_tokens
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens
(
self
)
->
l
ist
[
str
]:
return
self
.
all_special_tokens_extended
return
self
.
all_special_tokens_extended
@
property
@
property
def
all_special_ids
(
self
)
->
L
ist
[
int
]:
def
all_special_ids
(
self
)
->
l
ist
[
int
]:
return
[
return
[
self
.
all_special_tokens
.
index
(
t
)
for
t
in
self
.
all_special_tokens
self
.
all_special_tokens
.
index
(
t
)
for
t
in
self
.
all_special_tokens
]
]
...
@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
...
@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
input_ids
=
self
.
encode_one
(
text
,
truncation
,
max_length
)
input_ids
=
self
.
encode_one
(
text
,
truncation
,
max_length
)
return
Encoding
(
input_ids
=
input_ids
)
return
Encoding
(
input_ids
=
input_ids
)
def
get_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_vocab
(
self
)
->
d
ict
[
str
,
int
]:
# NB: the dictionary form of the vocabulary collapses token ids that map
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
# to the same string but have different bytes
return
self
.
_vocab_dict
return
self
.
_vocab_dict
def
get_added_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_added_vocab
(
self
)
->
d
ict
[
str
,
int
]:
# Mistral tokenizers have no added vocabulary
# Mistral tokenizers have no added vocabulary
return
{}
return
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment