Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93a126fb
Unverified
Commit
93a126fb
authored
Apr 27, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 27, 2025
Browse files
[Misc] Make cached tokenizer pickle-compatible (#17048)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
8e4b351a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
80 additions
and
56 deletions
+80
-56
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+8
-6
tests/tokenization/test_cached_tokenizer.py
tests/tokenization/test_cached_tokenizer.py
+31
-12
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+19
-16
vllm/transformers_utils/tokenizer_base.py
vllm/transformers_utils/tokenizer_base.py
+17
-17
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+5
-5
No files found.
benchmarks/benchmark_prefix_caching.py
View file @
93a126fb
...
...
@@ -63,14 +63,16 @@ class Request:
output_len
:
int
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
str
:
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
list
[
int
]:
vocab
=
tokenizer
.
get_vocab
()
all_special_ids
=
set
(
tokenizer
.
all_special_ids
)
# Remove the special tokens.
vocab
=
{
k
:
v
for
k
,
v
in
vocab
.
items
()
if
k
not
in
tokenizer
.
all_special_ids
}
return
random
.
choices
(
list
(
vocab
.
values
()),
k
=
length
)
return
random
.
choices
(
[
v
for
k
,
v
in
vocab
.
items
()
if
k
not
in
all_special_ids
],
k
=
length
,
)
def
sample_requests_from_dataset
(
...
...
tests/tokenization/test_cached_tokenizer.py
View file @
93a126fb
# SPDX-License-Identifier: Apache-2.0
import
pickle
from
copy
import
deepcopy
import
pytest
from
transformers
import
AutoTokenizer
from
vllm.transformers_utils.tokenizer
import
get_cached_tokenizer
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
get_cached_tokenizer
)
def
test_cached_tokenizer
():
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"gpt2"
,
"THUDM/chatglm3-6b"
])
def
test_cached_tokenizer
(
model_id
:
str
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_id
,
trust_remote_code
=
True
)
reference_tokenizer
.
add_special_tokens
({
"cls_token"
:
"<CLS>"
})
reference_tokenizer
.
add_special_tokens
(
{
"additional_special_tokens"
:
[
"<SEP>"
]})
cached_tokenizer
=
get_cached_tokenizer
(
deepcopy
(
reference_tokenizer
))
_check_consistency
(
cached_tokenizer
,
reference_tokenizer
)
pickled_tokenizer
=
pickle
.
dumps
(
cached_tokenizer
)
unpickled_tokenizer
=
pickle
.
loads
(
pickled_tokenizer
)
_check_consistency
(
unpickled_tokenizer
,
reference_tokenizer
)
def
_check_consistency
(
target
:
AnyTokenizer
,
expected
:
AnyTokenizer
):
assert
isinstance
(
target
,
type
(
expected
))
# Cached attributes
assert
target
.
all_special_ids
==
expected
.
all_special_ids
assert
target
.
all_special_tokens
==
expected
.
all_special_tokens
assert
(
target
.
all_special_tokens_extended
==
expected
.
all_special_tokens_extended
)
assert
target
.
get_vocab
()
==
expected
.
get_vocab
()
assert
len
(
target
)
==
len
(
expected
)
# Other attributes
assert
getattr
(
target
,
"padding_side"
,
None
)
==
getattr
(
expected
,
"padding_side"
,
None
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
cached_tokenizer
.
encode
(
"prompt"
)
assert
set
(
reference_tokenizer
.
all_special_ids
)
==
set
(
cached_tokenizer
.
all_special_ids
)
assert
set
(
reference_tokenizer
.
all_special_tokens
)
==
set
(
cached_tokenizer
.
all_special_tokens
)
assert
set
(
reference_tokenizer
.
all_special_tokens_extended
)
==
set
(
cached_tokenizer
.
all_special_tokens_extended
)
assert
target
.
encode
(
"prompt"
)
==
expected
.
encode
(
"prompt"
)
vllm/transformers_utils/tokenizer.py
View file @
93a126fb
# SPDX-License-Identifier: Apache-2.0
import
contextlib
import
copy
import
os
import
warnings
from
functools
import
lru_cache
...
...
@@ -70,18 +71,17 @@ def encode_tokens(
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
"""Get tokenizer with cached properties.
This will patch the tokenizer object in place.
"""
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown. This
function caches these properties for faster access."""
each time they are called, leading to a significant slowdown.
This proxy caches these properties for faster access.
"""
cached_tokenizer
=
copy
.
copy
(
tokenizer
)
tokenizer_all_special_ids
=
set
(
tokenizer
.
all_special_ids
)
tokenizer_all_special_ids
=
tokenizer
.
all_special_ids
tokenizer_all_special_tokens
=
tokenizer
.
all_special_tokens
tokenizer_all_special_tokens_extended
=
(
tokenizer
.
all_special_tokens_extended
)
tokenizer_all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
tokenizer_vocab
=
tokenizer
.
get_vocab
()
tokenizer_len
=
len
(
tokenizer
)
...
...
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
class
CachedTokenizer
(
tokenizer
.
__class__
):
# type: ignore
@
property
def
all_special_ids
(
self
):
def
all_special_ids
(
self
)
->
list
[
int
]
:
return
tokenizer_all_special_ids
@
property
def
all_special_tokens
(
self
):
def
all_special_tokens
(
self
)
->
list
[
str
]
:
return
tokenizer_all_special_tokens
@
property
def
all_special_tokens_extended
(
self
):
def
all_special_tokens_extended
(
self
)
->
list
[
str
]
:
return
tokenizer_all_special_tokens_extended
@
property
def
max_token_id
(
self
):
def
max_token_id
(
self
)
->
int
:
return
max_token_id
def
get_vocab
(
self
):
def
get_vocab
(
self
)
->
dict
[
str
,
int
]
:
return
tokenizer_vocab
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
tokenizer_len
def
__reduce__
(
self
):
return
get_cached_tokenizer
,
(
tokenizer
,
)
CachedTokenizer
.
__name__
=
f
"Cached
{
tokenizer
.
__class__
.
__name__
}
"
tokenizer
.
__class__
=
CachedTokenizer
return
tokenizer
cached_
tokenizer
.
__class__
=
CachedTokenizer
return
cached_
tokenizer
def
patch_padding_side
(
tokenizer
:
PreTrainedTokenizer
)
->
None
:
...
...
vllm/transformers_utils/tokenizer_base.py
View file @
93a126fb
...
...
@@ -2,7 +2,7 @@
import
importlib
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
...
...
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
@
property
@
abstractmethod
def
all_special_tokens_extended
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens_extended
(
self
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
@
property
@
abstractmethod
def
all_special_tokens
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens
(
self
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
@
property
@
abstractmethod
def
all_special_ids
(
self
)
->
L
ist
[
int
]:
def
all_special_ids
(
self
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
@
property
...
...
@@ -66,7 +66,7 @@ class TokenizerBase(ABC):
@
abstractmethod
def
__call__
(
self
,
text
:
Union
[
str
,
L
ist
[
str
],
L
ist
[
int
]],
text
:
Union
[
str
,
l
ist
[
str
],
l
ist
[
int
]],
text_pair
:
Optional
[
str
]
=
None
,
add_special_tokens
:
bool
=
False
,
truncation
:
bool
=
False
,
...
...
@@ -75,11 +75,11 @@ class TokenizerBase(ABC):
raise
NotImplementedError
()
@
abstractmethod
def
get_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_vocab
(
self
)
->
d
ict
[
str
,
int
]:
raise
NotImplementedError
()
@
abstractmethod
def
get_added_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_added_vocab
(
self
)
->
d
ict
[
str
,
int
]:
raise
NotImplementedError
()
@
abstractmethod
...
...
@@ -88,44 +88,44 @@ class TokenizerBase(ABC):
text
:
str
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
)
->
L
ist
[
int
]:
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
@
abstractmethod
def
encode
(
self
,
text
:
str
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
L
ist
[
int
]:
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
@
abstractmethod
def
apply_chat_template
(
self
,
messages
:
L
ist
[
"ChatCompletionMessageParam"
],
tools
:
Optional
[
L
ist
[
D
ict
[
str
,
Any
]]]
=
None
,
**
kwargs
)
->
L
ist
[
int
]:
messages
:
l
ist
[
"ChatCompletionMessageParam"
],
tools
:
Optional
[
l
ist
[
d
ict
[
str
,
Any
]]]
=
None
,
**
kwargs
)
->
l
ist
[
int
]:
raise
NotImplementedError
()
@
abstractmethod
def
convert_tokens_to_string
(
self
,
tokens
:
L
ist
[
str
])
->
str
:
def
convert_tokens_to_string
(
self
,
tokens
:
l
ist
[
str
])
->
str
:
raise
NotImplementedError
()
@
abstractmethod
def
decode
(
self
,
ids
:
Union
[
L
ist
[
int
],
int
],
ids
:
Union
[
l
ist
[
int
],
int
],
skip_special_tokens
:
bool
=
True
)
->
str
:
raise
NotImplementedError
()
@
abstractmethod
def
convert_ids_to_tokens
(
self
,
ids
:
L
ist
[
int
],
ids
:
l
ist
[
int
],
skip_special_tokens
:
bool
=
True
,
)
->
L
ist
[
str
]:
)
->
l
ist
[
str
]:
raise
NotImplementedError
()
class
TokenizerRegistry
:
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY
:
D
ict
[
str
,
T
uple
[
str
,
str
]]
=
{}
REGISTRY
:
d
ict
[
str
,
t
uple
[
str
,
str
]]
=
{}
@
staticmethod
def
register
(
name
:
str
,
module
:
str
,
class_name
:
str
)
->
None
:
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
93a126fb
...
...
@@ -257,7 +257,7 @@ class MistralTokenizer(TokenizerBase):
# the following attributes are set to fit vLLM's design and are used
# by the guided structured output backends.
@
property
def
all_special_tokens_extended
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens_extended
(
self
)
->
l
ist
[
str
]:
from
mistral_common.tokens.tokenizers.base
import
SpecialTokens
# tekken defines its own extended special tokens list
...
...
@@ -271,11 +271,11 @@ class MistralTokenizer(TokenizerBase):
]
@
property
def
all_special_tokens
(
self
)
->
L
ist
[
str
]:
def
all_special_tokens
(
self
)
->
l
ist
[
str
]:
return
self
.
all_special_tokens_extended
@
property
def
all_special_ids
(
self
)
->
L
ist
[
int
]:
def
all_special_ids
(
self
)
->
l
ist
[
int
]:
return
[
self
.
all_special_tokens
.
index
(
t
)
for
t
in
self
.
all_special_tokens
]
...
...
@@ -335,12 +335,12 @@ class MistralTokenizer(TokenizerBase):
input_ids
=
self
.
encode_one
(
text
,
truncation
,
max_length
)
return
Encoding
(
input_ids
=
input_ids
)
def
get_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_vocab
(
self
)
->
d
ict
[
str
,
int
]:
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
return
self
.
_vocab_dict
def
get_added_vocab
(
self
)
->
D
ict
[
str
,
int
]:
def
get_added_vocab
(
self
)
->
d
ict
[
str
,
int
]:
# Mistral tokenizers have no added vocabulary
return
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment