Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ee696a6
Unverified
Commit
3ee696a6
authored
Feb 11, 2025
by
Keyun Tong
Committed by
GitHub
Feb 12, 2025
Browse files
[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)
Signed-off-by:
Keyun Tong
<
tongkeyun@gmail.com
>
parent
72c2b68d
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
343 additions
and
41 deletions
+343
-41
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+3
-2
tests/tokenization/test_tokenizer_registry.py
tests/tokenization/test_tokenizer_registry.py
+123
-0
vllm/config.py
vllm/config.py
+5
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+4
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+19
-12
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-2
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+1
-1
vllm/logits_process.py
vllm/logits_process.py
+1
-1
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+12
-6
vllm/transformers_utils/tokenizer_base.py
vllm/transformers_utils/tokenizer_base.py
+146
-0
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+28
-11
No files found.
benchmarks/benchmark_serving.py
View file @
3ee696a6
...
@@ -1275,11 +1275,12 @@ if __name__ == "__main__":
...
@@ -1275,11 +1275,12 @@ if __name__ == "__main__":
'--tokenizer-mode'
,
'--tokenizer-mode'
,
type
=
str
,
type
=
str
,
default
=
"auto"
,
default
=
"auto"
,
choices
=
[
'auto'
,
'slow'
,
'mistral'
],
choices
=
[
'auto'
,
'slow'
,
'mistral'
,
'custom'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.
\n
* '
'always use the slow tokenizer.
\n
* '
'"mistral" will always use the `mistral_common` tokenizer.'
)
'"mistral" will always use the `mistral_common` tokenizer.
\n
*'
'"custom" will use --tokenizer to select the preregistered tokenizer.'
)
parser
.
add_argument
(
"--served-model-name"
,
parser
.
add_argument
(
"--served-model-name"
,
type
=
str
,
type
=
str
,
...
...
tests/tokenization/test_tokenizer_registry.py
0 → 100644
View file @
3ee696a6
# SPDX-License-Identifier: Apache-2.0
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
TokenizerRegistry
)
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
class TestTokenizer(TokenizerBase):
    """Stub ``TokenizerBase`` implementation used by the registry test.

    Only ``from_pretrained``, ``bos_token_id`` and ``eos_token_id`` return
    values; every other member raises ``NotImplementedError``.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        """Return a fresh ``TestTokenizer``; all arguments are ignored."""
        return TestTokenizer()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError

    @property
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError

    @property
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError

    @property
    def bos_token_id(self) -> int:
        """Fixed value the test asserts against."""
        return 0

    @property
    def eos_token_id(self) -> int:
        """Fixed value the test asserts against."""
        return 1

    @property
    def sep_token(self) -> str:
        raise NotImplementedError

    @property
    def pad_token(self) -> str:
        raise NotImplementedError

    @property
    def is_fast(self) -> bool:
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        raise NotImplementedError

    @property
    def max_token_id(self) -> int:
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Methods (all unimplemented in this stub)
    # ------------------------------------------------------------------

    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError

    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError

    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        raise NotImplementedError

    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError

    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError

    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError
def test_customized_tokenizer():
    """Register ``TestTokenizer`` and fetch it through both entry points.

    The tokenizer is looked up first via ``TokenizerRegistry.get_tokenizer``
    and then via ``get_tokenizer(..., tokenizer_mode="custom")``; each result
    must be a ``TestTokenizer`` exposing the stub's BOS/EOS ids.
    """
    TokenizerRegistry.register(
        "test_tokenizer",
        "tests.tokenization.test_tokenizer_registry",
        "TestTokenizer",
    )

    # Direct lookup through the registry.
    from_registry = TokenizerRegistry.get_tokenizer("test_tokenizer")
    assert isinstance(from_registry, TestTokenizer)
    assert from_registry.bos_token_id == 0
    assert from_registry.eos_token_id == 1

    # Lookup through the generic helper in "custom" mode.
    from_helper = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
    assert isinstance(from_helper, TestTokenizer)
    assert from_helper.bos_token_id == 0
    assert from_helper.eos_token_id == 1
vllm/config.py
View file @
3ee696a6
...
@@ -102,8 +102,9 @@ class ModelConfig:
...
@@ -102,8 +102,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use.
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`.
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
allowed_local_media_path: Allowing API requests to read local images or
...
@@ -467,10 +468,10 @@ class ModelConfig:
...
@@ -467,10 +468,10 @@ class ModelConfig:
def
_verify_tokenizer_mode
(
self
)
->
None
:
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
]:
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]:
raise
ValueError
(
raise
ValueError
(
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
"either 'auto', 'slow'
or
'mistral'."
)
"either 'auto', 'slow'
,
'mistral'
or 'custom'
."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_get_preferred_task
(
def
_get_preferred_task
(
...
...
vllm/engine/arg_utils.py
View file @
3ee696a6
...
@@ -284,11 +284,13 @@ class EngineArgs:
...
@@ -284,11 +284,13 @@ class EngineArgs:
'--tokenizer-mode'
,
'--tokenizer-mode'
,
type
=
str
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_mode
,
default
=
EngineArgs
.
tokenizer_mode
,
choices
=
[
'auto'
,
'slow'
,
'mistral'
],
choices
=
[
'auto'
,
'slow'
,
'mistral'
,
'custom'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.
\n
* '
'always use the slow tokenizer.
\n
* '
'"mistral" will always use the `mistral_common` tokenizer.'
)
'"mistral" will always use the `mistral_common` tokenizer.
\n
* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.'
)
parser
.
add_argument
(
'--trust-remote-code'
,
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Trust remote code from huggingface.'
)
help
=
'Trust remote code from huggingface.'
)
...
...
vllm/entrypoints/llm.py
View file @
3ee696a6
...
@@ -1051,9 +1051,9 @@ class LLM:
...
@@ -1051,9 +1051,9 @@ class LLM:
def
_cross_encoding_score
(
def
_cross_encoding_score
(
self
,
self
,
tokenizer
:
Union
[
AnyTokenizer
]
,
tokenizer
:
AnyTokenizer
,
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]
],
text_1
:
List
[
str
],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]
],
text_2
:
List
[
str
],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
...
@@ -1176,29 +1176,36 @@ class LLM:
...
@@ -1176,29 +1176,36 @@ class LLM:
if
isinstance
(
text_1
,
(
str
,
dict
)):
if
isinstance
(
text_1
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
text_1
=
[
text_1
]
text_1
=
[
text_1
]
text_1
=
[
ensure_str
(
t
)
for
t
in
text_1
]
input_text_1
:
List
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_1
]
if
isinstance
(
text_2
,
(
str
,
dict
)):
if
isinstance
(
text_2
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
text_2
=
[
text_2
]
text_2
=
[
text_2
]
text_2
=
[
ensure_str
(
t
)
for
t
in
text_2
]
input_text_2
:
List
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_2
]
if
len
(
text_1
)
>
1
and
len
(
text_1
)
!=
len
(
text_2
):
if
len
(
input_
text_1
)
>
1
and
len
(
input_
text_1
)
!=
len
(
input_
text_2
):
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
if
len
(
text_1
)
==
0
:
if
len
(
input_
text_1
)
==
0
:
raise
ValueError
(
"At least one text element must be given"
)
raise
ValueError
(
"At least one text element must be given"
)
if
len
(
text_2
)
==
0
:
if
len
(
input_
text_2
)
==
0
:
raise
ValueError
(
"At least one text_pair element must be given"
)
raise
ValueError
(
"At least one text_pair element must be given"
)
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
return
self
.
_cross_encoding_score
(
tokenizer
,
text_1
,
text_2
,
return
self
.
_cross_encoding_score
(
tokenizer
,
input_text_1
,
input_text_2
,
truncate_prompt_tokens
,
use_tqdm
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
lora_request
,
prompt_adapter_request
)
prompt_adapter_request
)
else
:
else
:
return
self
.
_embedding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
return
self
.
_embedding_score
(
lora_request
,
prompt_adapter_request
)
tokenizer
,
input_text_1
,
# type: ignore[arg-type]
input_text_2
,
# type: ignore[arg-type]
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
def
start_profile
(
self
)
->
None
:
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
self
.
llm_engine
.
start_profile
()
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
3ee696a6
...
@@ -400,8 +400,7 @@ class OpenAIServing:
...
@@ -400,8 +400,7 @@ class OpenAIServing:
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
request_prompt
:
Union
[
str
,
List
[
int
]]
request_prompt
:
Union
[
str
,
List
[
int
]]
is_mistral_tokenizer
=
isinstance
(
tokenizer
,
MistralTokenizer
)
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
is_mistral_tokenizer
:
request_prompt
=
apply_mistral_chat_template
(
request_prompt
=
apply_mistral_chat_template
(
tokenizer
,
tokenizer
,
messages
=
messages
,
messages
=
messages
,
...
...
vllm/entrypoints/openai/serving_score.py
View file @
3ee696a6
...
@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
...
@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
executor
=
self
.
_tokenizer_executor
)
executor
=
self
.
_tokenizer_executor
)
prompt_inputs
=
await
tokenize_async
(
text
=
q
,
prompt_inputs
=
await
tokenize_async
(
q
,
text_pair
=
t
,
text_pair
=
t
,
**
tokenization_kwargs
)
**
tokenization_kwargs
)
...
...
vllm/logits_process.py
View file @
3ee696a6
...
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
...
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# Mistral tokenizers should not add special tokens
# Mistral tokenizers should not add special tokens
prompt_token_ids
=
tokenizer
.
encode
(
promp
t
=
prompt
)
prompt_token_ids
=
tokenizer
.
encode
(
tex
t
=
prompt
)
else
:
else
:
prompt_token_ids
=
tokenizer
.
encode
(
text
=
prompt
,
prompt_token_ids
=
tokenizer
.
encode
(
text
=
prompt
,
add_special_tokens
=
False
)
add_special_tokens
=
False
)
...
...
vllm/transformers_utils/tokenizer.py
View file @
3ee696a6
...
@@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
...
@@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
TokenizerRegistry
)
from
vllm.transformers_utils.tokenizers
import
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
MistralTokenizer
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
...
@@ -21,7 +23,7 @@ from vllm.utils import make_async
...
@@ -21,7 +23,7 @@ from vllm.utils import make_async
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
Mistral
Tokenizer
]
Tokenizer
Base
]
def
decode_tokens
(
def
decode_tokens
(
...
@@ -47,11 +49,7 @@ def encode_tokens(
...
@@ -47,11 +49,7 @@ def encode_tokens(
Backend-agnostic equivalent of HF's
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
"""
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
add_special_tokens
is
not
None
:
return
tokenizer
.
tokenizer
.
encode
(
text
,
bos
=
add_special_tokens
,
eos
=
add_special_tokens
)
elif
add_special_tokens
is
not
None
:
return
tokenizer
.
encode
(
text
,
add_special_tokens
=
add_special_tokens
)
return
tokenizer
.
encode
(
text
,
add_special_tokens
=
add_special_tokens
)
return
tokenizer
.
encode
(
text
)
return
tokenizer
.
encode
(
text
)
...
@@ -183,9 +181,17 @@ def get_tokenizer(
...
@@ -183,9 +181,17 @@ def get_tokenizer(
'encoding and decoding.'
,
'encoding and decoding.'
,
FutureWarning
,
FutureWarning
,
stacklevel
=
2
)
stacklevel
=
2
)
tokenizer
:
AnyTokenizer
if
tokenizer_mode
==
"mistral"
:
if
tokenizer_mode
==
"mistral"
:
tokenizer
=
MistralTokenizer
.
from_pretrained
(
str
(
tokenizer_name
),
tokenizer
=
MistralTokenizer
.
from_pretrained
(
str
(
tokenizer_name
),
revision
=
revision
)
revision
=
revision
)
elif
tokenizer_mode
==
"custom"
:
tokenizer
=
TokenizerRegistry
.
get_tokenizer
(
str
(
tokenizer_name
),
*
args
,
revision
=
revision
,
download_dir
=
download_dir
,
**
kwargs
)
else
:
else
:
try
:
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer
=
AutoTokenizer
.
from_pretrained
(
...
...
vllm/transformers_utils/tokenizer_base.py
0 → 100644
View file @
3ee696a6
# SPDX-License-Identifier: Apache-2.0
import
importlib
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
class TokenizerBase(ABC):
    """Abstract base class declaring the tokenizer interface.

    Every property and method below except ``__len__`` is abstract, so a
    subclass must override all of them before it can be instantiated.
    """

    # ------------------------------------------------------------------
    # Abstract read-only properties
    # ------------------------------------------------------------------

    @property
    @abstractmethod
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def sep_token(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def pad_token(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        raise NotImplementedError

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def max_token_id(self) -> int:
        raise NotImplementedError

    def __len__(self) -> int:
        """Return ``self.vocab_size``."""
        return self.vocab_size

    # ------------------------------------------------------------------
    # Abstract methods
    # ------------------------------------------------------------------

    @abstractmethod
    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    @abstractmethod
    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    @abstractmethod
    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError

    @abstractmethod
    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        raise NotImplementedError

    @abstractmethod
    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError

    @abstractmethod
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError

    @abstractmethod
    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError

    @abstractmethod
    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError
class TokenizerRegistry:
    """Name-based registry of tokenizer classes.

    Entries store only a module path and a class name; the module is imported
    when the tokenizer is requested, not when it is registered.
    """

    # Tokenizer name -> (tokenizer module, tokenizer class)
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        """Record ``(module, class_name)`` under ``name``.

        Registering an existing name overwrites the previous entry.
        """
        TokenizerRegistry.REGISTRY[name] = (module, class_name)

    @staticmethod
    def get_tokenizer(
        tokenizer_name: str,
        *args,
        **kwargs,
    ) -> TokenizerBase:
        """Instantiate the tokenizer registered as ``tokenizer_name``.

        The registered module is imported, the class is fetched from it by
        name, and ``*args`` / ``**kwargs`` are forwarded to that class's
        ``from_pretrained``.

        Raises:
            ValueError: if ``tokenizer_name`` has not been registered.
        """
        entry = TokenizerRegistry.REGISTRY.get(tokenizer_name)
        if entry is None:
            raise ValueError(f"Tokenizer {tokenizer_name} not found.")

        module_path, cls_name = entry
        loaded_module = importlib.import_module(module_path)
        tokenizer_type = getattr(loaded_module, cls_name)
        return tokenizer_type.from_pretrained(*args, **kwargs)
vllm/transformers_utils/tokenizers/mistral.py
View file @
3ee696a6
...
@@ -10,6 +10,7 @@ import huggingface_hub
...
@@ -10,6 +10,7 @@ import huggingface_hub
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_base
import
TokenizerBase
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
...
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
tools
=
tools
)
# type: ignore[type-var]
tools
=
tools
)
# type: ignore[type-var]
class
MistralTokenizer
:
class
MistralTokenizer
(
TokenizerBase
)
:
def
__init__
(
self
,
tokenizer
:
"PublicMistralTokenizer"
)
->
None
:
def
__init__
(
self
,
tokenizer
:
"PublicMistralTokenizer"
)
->
None
:
self
.
mistral
=
tokenizer
self
.
mistral
=
tokenizer
...
@@ -251,6 +252,14 @@ class MistralTokenizer:
...
@@ -251,6 +252,14 @@ class MistralTokenizer:
def
eos_token_id
(
self
)
->
int
:
def
eos_token_id
(
self
)
->
int
:
return
self
.
tokenizer
.
eos_id
return
self
.
tokenizer
.
eos_id
@
property
def
sep_token
(
self
)
->
str
:
raise
NotImplementedError
()
@
property
def
pad_token
(
self
)
->
str
:
raise
NotImplementedError
()
@
property
@
property
def
is_fast
(
self
)
->
bool
:
def
is_fast
(
self
)
->
bool
:
return
True
return
True
...
@@ -268,25 +277,26 @@ class MistralTokenizer:
...
@@ -268,25 +277,26 @@ class MistralTokenizer:
def
__call__
(
def
__call__
(
self
,
self
,
prompt
:
Union
[
str
,
List
[
str
],
List
[
int
]],
text
:
Union
[
str
,
List
[
str
],
List
[
int
]],
text_pair
:
Optional
[
str
]
=
None
,
add_special_tokens
:
bool
=
False
,
add_special_tokens
:
bool
=
False
,
truncation
:
bool
=
False
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
max_length
:
Optional
[
int
]
=
None
,
):
):
input_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]]
input_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]]
# For List[str], original prompt text
# For List[str], original prompt text
if
is_list_of
(
promp
t
,
str
):
if
is_list_of
(
tex
t
,
str
):
input_ids_
:
List
[
List
[
int
]]
=
[]
input_ids_
:
List
[
List
[
int
]]
=
[]
for
p
in
promp
t
:
for
p
in
tex
t
:
each_input_ids
=
self
.
encode_one
(
p
,
truncation
,
max_length
)
each_input_ids
=
self
.
encode_one
(
p
,
truncation
,
max_length
)
input_ids_
.
append
(
each_input_ids
)
input_ids_
.
append
(
each_input_ids
)
input_ids
=
input_ids_
input_ids
=
input_ids_
# For List[int], apply chat template output, already tokens.
# For List[int], apply chat template output, already tokens.
elif
is_list_of
(
promp
t
,
int
):
elif
is_list_of
(
tex
t
,
int
):
input_ids
=
promp
t
input_ids
=
tex
t
# For str, single prompt text
# For str, single prompt text
else
:
else
:
input_ids
=
self
.
encode_one
(
promp
t
,
truncation
,
max_length
)
input_ids
=
self
.
encode_one
(
tex
t
,
truncation
,
max_length
)
return
Encoding
(
input_ids
=
input_ids
)
return
Encoding
(
input_ids
=
input_ids
)
def
get_vocab
(
self
)
->
Dict
[
str
,
int
]:
def
get_vocab
(
self
)
->
Dict
[
str
,
int
]:
...
@@ -300,22 +310,29 @@ class MistralTokenizer:
...
@@ -300,22 +310,29 @@ class MistralTokenizer:
def
encode_one
(
def
encode_one
(
self
,
self
,
promp
t
:
str
,
tex
t
:
str
,
truncation
:
bool
=
False
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
max_length
:
Optional
[
int
]
=
None
,
)
->
List
[
int
]:
)
->
List
[
int
]:
# Mistral Tokenizers should not add special tokens
# Mistral Tokenizers should not add special tokens
input_ids
=
self
.
encode
(
promp
t
)
input_ids
=
self
.
encode
(
tex
t
)
if
truncation
:
if
truncation
:
input_ids
=
input_ids
[:
max_length
]
input_ids
=
input_ids
[:
max_length
]
return
input_ids
return
input_ids
def
encode
(
self
,
prompt
:
str
)
->
List
[
int
]:
def
encode
(
self
,
text
:
str
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
# `encode` should only be used for prompt completion
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
# For chat completion use `apply_chat_template`
return
self
.
tokenizer
.
encode
(
prompt
,
bos
=
True
,
eos
=
False
)
if
add_special_tokens
is
not
None
:
return
self
.
tokenizer
.
encode
(
text
,
bos
=
add_special_tokens
,
eos
=
add_special_tokens
)
else
:
return
self
.
tokenizer
.
encode
(
text
,
bos
=
True
,
eos
=
False
)
def
apply_chat_template
(
self
,
def
apply_chat_template
(
self
,
messages
:
List
[
"ChatCompletionMessageParam"
],
messages
:
List
[
"ChatCompletionMessageParam"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment