Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6fc4e6e0
Unverified
Commit
6fc4e6e0
authored
Aug 27, 2024
by
Patrick von Platen
Committed by
GitHub
Aug 27, 2024
Browse files
[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)
parent
9606c719
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
275 additions
and
60 deletions
+275
-60
docs/requirements-docs.txt
docs/requirements-docs.txt
+1
-0
requirements-common.txt
requirements-common.txt
+1
-0
tests/models/test_mistral.py
tests/models/test_mistral.py
+3
-1
vllm/config.py
vllm/config.py
+4
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-2
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+1
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+9
-3
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+18
-8
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+1
-1
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+58
-36
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+2
-3
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+174
-0
No files found.
docs/requirements-docs.txt
View file @
6fc4e6e0
...
...
@@ -11,4 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
requirements-common.txt
View file @
6fc4e6e0
...
...
@@ -26,3 +26,4 @@ librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
tests/models/test_mistral.py
View file @
6fc4e6e0
...
...
@@ -30,9 +30,11 @@ def test_models(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
tokenizer_mode
=
"mistral"
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
...
...
vllm/config.py
View file @
6fc4e6e0
...
...
@@ -61,7 +61,8 @@ class ModelConfig:
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option
...
...
@@ -246,10 +247,10 @@ class ModelConfig:
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
]:
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
]:
raise
ValueError
(
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
"either 'auto'
or 'slow
'."
)
"either 'auto'
, 'slow' or 'mistral
'."
)
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_embedding_mode
(
self
)
->
None
:
...
...
vllm/engine/arg_utils.py
View file @
6fc4e6e0
...
...
@@ -198,10 +198,11 @@ class EngineArgs:
'--tokenizer-mode'
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_mode
,
choices
=
[
'auto'
,
'slow'
],
choices
=
[
'auto'
,
'slow'
,
'mistral'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.'
)
'always use the slow tokenizer.
\n
* '
'"mistral" will always use the `mistral_common` tokenizer.'
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'Trust remote code from huggingface.'
)
...
...
vllm/entrypoints/chat_utils.py
View file @
6fc4e6e0
...
...
@@ -267,7 +267,7 @@ def apply_chat_template(
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
)
->
str
:
)
->
Union
[
str
,
List
[
int
]]
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
...
...
@@ -280,6 +280,4 @@ def apply_chat_template(
tokenize
=
tokenize
,
**
kwargs
,
)
assert
isinstance
(
prompt
,
str
)
return
prompt
vllm/entrypoints/llm.py
View file @
6fc4e6e0
...
...
@@ -390,15 +390,21 @@ class LLM:
conversations
,
_
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
prompt
s
=
apply_chat_template
(
prompt
=
apply_chat_template
(
tokenizer
,
conversations
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
)
inputs
:
PromptInputs
if
isinstance
(
prompt
,
list
)
and
isinstance
(
prompt
[
0
],
int
):
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
else
:
inputs
=
TextPrompt
(
prompt
=
prompt
)
return
self
.
generate
(
promp
ts
,
sampling_params
,
inpu
ts
,
sampling_params
=
sampling_params
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
6fc4e6e0
...
...
@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
,
PromptAdapterPath
)
PromptAdapterPath
,
TextTokensPrompt
)
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
...
...
@@ -130,13 +131,22 @@ class OpenAIServingChat(OpenAIServing):
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
if
isinstance
(
prompt
,
str
):
prompt_inputs
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
prompt
,
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
,
add_special_tokens
=
request
.
add_special_tokens
,
)
else
:
assert
isinstance
(
prompt
,
list
)
and
isinstance
(
prompt
[
0
],
int
),
"Prompt has to be either a string or a list of token ids"
prompt_inputs
=
TextTokensPrompt
(
prompt
=
tokenizer
.
decode
(
prompt
),
prompt_token_ids
=
prompt
)
assert
prompt_inputs
is
not
None
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
...
...
vllm/transformers_utils/detokenizer.py
View file @
6fc4e6e0
...
...
@@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
prefix_offset
=
max
(
read_offset
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
,
0
)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty
(
new_tokens
)
_replace_none_with_empty
(
new_tokens
)
# type: ignore[arg-type]
return
new_tokens
,
prefix_offset
,
read_offset
...
...
vllm/transformers_utils/tokenizer.py
View file @
6fc4e6e0
import
os
import
warnings
from
pathlib
import
Path
from
typing
import
Optional
,
Union
...
...
@@ -9,12 +10,14 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizers
import
BaichuanTokenizer
from
vllm.transformers_utils.tokenizers
import
(
BaichuanTokenizer
,
MistralTokenizer
)
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
MistralTokenizer
]
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
...
...
@@ -99,45 +102,64 @@ def get_tokenizer(
kwargs
[
"gguf_file"
]
=
Path
(
tokenizer_name
).
name
tokenizer_name
=
Path
(
tokenizer_name
).
parent
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
**
kwargs
)
except
ValueError
as
e
:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if
(
not
trust_remote_code
and
(
"does not exist or is not currently imported."
in
str
(
e
)
or
"requires you to execute the tokenizer file"
in
str
(
e
))):
err_msg
=
(
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
except
AttributeError
as
e
:
if
"BaichuanTokenizer"
in
str
(
e
):
# This is for the error "'BaichuanTokenizer' object has no
# attribute 'sp_model'".
tokenizer
=
BaichuanTokenizer
.
from_pretrained
(
# if tokenizer is from official mistral org
is_from_mistral_org
=
str
(
tokenizer_name
).
split
(
"/"
)[
0
]
==
"mistralai"
if
is_from_mistral_org
and
tokenizer_mode
!=
"mistral"
:
warnings
.
warn
(
'It is strongly recommended to run mistral models with '
'`--tokenizer_mode "mistral"` to ensure correct '
'encoding and decoding.'
,
FutureWarning
,
stacklevel
=
2
)
if
tokenizer_mode
==
"mistral"
:
tokenizer
=
MistralTokenizer
.
from_pretrained
(
str
(
tokenizer_name
),
revision
=
revision
)
else
:
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
**
kwargs
)
else
:
raise
e
**
kwargs
,
)
except
ValueError
as
e
:
# If the error pertains to the tokenizer class not existing or not
# currently being imported,
# suggest using the --trust-remote-code flag.
if
not
trust_remote_code
and
(
"does not exist or is not currently imported."
in
str
(
e
)
or
"requires you to execute the tokenizer file"
in
str
(
e
)):
err_msg
=
(
"Failed to load the tokenizer. If the tokenizer "
"is a custom tokenizer not yet available in the "
"HuggingFace transformers library, consider "
"setting `trust_remote_code=True` in LLM or using "
"the `--trust-remote-code` flag in the CLI."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
except
AttributeError
as
e
:
if
"BaichuanTokenizer"
in
str
(
e
):
# This is for the error "'BaichuanTokenizer' object has no
# attribute 'sp_model'".
tokenizer
=
BaichuanTokenizer
.
from_pretrained
(
tokenizer_name
,
*
args
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
**
kwargs
,
)
else
:
raise
e
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
logger
.
warning
(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
logger
.
warning
(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
return
get_cached_tokenizer
(
tokenizer
)
return
tokenizer
def
get_lora_tokenizer
(
lora_request
:
LoRARequest
,
*
args
,
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
6fc4e6e0
from
vllm.transformers_utils.tokenizers.baichuan
import
BaichuanTokenizer
from
vllm.transformers_utils.tokenizers.mistral
import
MistralTokenizer
__all__
=
[
"BaichuanTokenizer"
,
]
__all__
=
[
"BaichuanTokenizer"
,
"MistralTokenizer"
]
vllm/transformers_utils/tokenizers/mistral.py
0 → 100644
View file @
6fc4e6e0
import
os
import
re
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
huggingface_hub
import
HfApi
,
hf_hub_download
# yapf: disable
from
mistral_common.tokens.tokenizers.mistral
import
ChatCompletionRequest
from
mistral_common.tokens.tokenizers.mistral
import
(
MistralTokenizer
as
PublicMistralTokenizer
)
# yapf: enable
from
mistral_common.tokens.tokenizers.sentencepiece
import
(
SentencePieceTokenizer
)
from
mistral_common.tokens.tokenizers.tekken
import
(
SpecialTokenPolicy
,
Tekkenizer
)
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ConversationMessage
@
dataclass
class
Encoding
:
input_ids
:
List
[
int
]
def
find_tokenizer_file
(
files
:
List
[
str
]):
file_pattern
=
re
.
compile
(
r
"^tokenizer\.model\.v.*$|^tekken\.json$"
)
matched_files
=
[
file
for
file
in
files
if
file_pattern
.
match
(
file
)]
if
len
(
matched_files
)
>
1
:
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
"pattern: {matched_files}. Make sure only one Mistral "
"tokenizer is present in {tokenizer_name}."
)
elif
len
(
matched_files
)
==
0
:
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
"pattern: {matched_files}. Make sure that a Mistral "
"tokenizer is present in {tokenizer_name}."
)
return
matched_files
[
0
]
class
MistralTokenizer
:
def
__init__
(
self
,
tokenizer
:
PublicMistralTokenizer
)
->
None
:
self
.
mistral
=
tokenizer
self
.
instruct
=
tokenizer
.
instruct_tokenizer
self
.
tokenizer
=
tokenizer
.
instruct_tokenizer
.
tokenizer
self
.
vocab_size
=
len
(
self
.
tokenizer
.
vocab
())
assert
isinstance
(
self
.
tokenizer
,
(
Tekkenizer
,
SentencePieceTokenizer
)),
type
(
self
.
tokenizer
)
self
.
_is_tekken
=
isinstance
(
self
.
tokenizer
,
Tekkenizer
)
if
self
.
_is_tekken
:
# Make sure special tokens will not raise
self
.
tokenizer
.
special_token_policy
=
SpecialTokenPolicy
.
IGNORE
# the following attributes are set to fit VLLM's design
self
.
is_fast
=
True
self
.
chat_template
=
True
self
.
all_special_ids
:
List
[
Any
]
=
[]
self
.
all_special_tokens
:
List
[
Any
]
=
[]
self
.
all_special_tokens_extended
:
List
[
Any
]
=
[]
@
classmethod
def
from_pretrained
(
cls
,
path_or_repo_id
:
str
,
*
,
revision
:
Optional
[
str
]
=
None
)
->
"MistralTokenizer"
:
if
not
Path
(
path_or_repo_id
).
exists
():
assert
len
(
path_or_repo_id
.
split
(
"/"
))
==
2
,
(
"You have either provided a non-existent path: "
"{path_or_repo_id} or an invalid HF Hub repo id."
)
tokenizer_file
=
cls
.
_download_mistral_tokenizer_from_hf
(
path_or_repo_id
,
revision
)
elif
Path
(
path_or_repo_id
).
is_dir
():
tokenizer_file_name
=
find_tokenizer_file
(
os
.
listdir
(
path_or_repo_id
))
tokenizer_file
=
str
(
Path
(
path_or_repo_id
)
/
tokenizer_file_name
)
else
:
assert
Path
(
path_or_repo_id
).
is_file
(),
f
"Invalid path:
{
path_or_repo_id
}
"
mistral_tokenizer
=
PublicMistralTokenizer
.
from_file
(
tokenizer_file
)
return
cls
(
mistral_tokenizer
)
@
staticmethod
def
_download_mistral_tokenizer_from_hf
(
tokenizer_name
:
str
,
revision
:
Optional
[
str
])
->
str
:
api
=
HfApi
()
repo_info
=
api
.
model_info
(
tokenizer_name
)
files
=
[
s
.
rfilename
for
s
in
repo_info
.
siblings
]
filename
=
find_tokenizer_file
(
files
)
tokenizer_file
=
hf_hub_download
(
tokenizer_name
,
filename
=
filename
,
revision
=
revision
)
return
tokenizer_file
def
__call__
(
self
,
prompt
:
str
,
add_special_tokens
:
bool
=
False
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
):
# Mistral Tokenizers should not add special tokens
input_ids
=
self
.
encode
(
prompt
)
if
truncation
:
input_ids
=
input_ids
[:
max_length
]
return
Encoding
(
input_ids
=
input_ids
)
def
get_added_vocab
(
self
)
->
List
[
str
]:
# Mistral tokenizers have no added vocabulary
return
[]
def
encode
(
self
,
prompt
:
str
)
->
List
[
int
]:
# `encode ` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
return
self
.
tokenizer
.
encode
(
prompt
,
bos
=
True
,
eos
=
False
)
def
apply_chat_template
(
self
,
conversation
:
List
[
"ConversationMessage"
],
tools
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
)
->
List
[
int
]:
assert
tools
is
None
,
"`tools` are not yet supported."
request
=
ChatCompletionRequest
(
messages
=
conversation
)
# type: ignore[type-var]
encoded
=
self
.
mistral
.
encode_chat_completion
(
request
)
# encode-decode to get clean prompt
return
encoded
.
tokens
def
convert_tokens_to_string
(
self
,
tokens
:
List
[
str
])
->
str
:
if
self
.
_is_tekken
:
return
""
.
join
(
tokens
)
else
:
return
self
.
tokenizer
.
decode
(
tokens
)
# type: ignore[arg-type]
def
decode
(
self
,
ids
:
Union
[
List
[
int
],
int
])
->
str
:
if
isinstance
(
ids
,
int
):
ids
=
[
ids
]
return
self
.
tokenizer
.
decode
(
ids
)
@
property
def
eos_token_id
(
self
):
return
self
.
tokenizer
.
eos_id
def
convert_ids_to_tokens
(
self
,
ids
:
List
[
int
],
skip_special_tokens
:
Optional
[
bool
]
=
True
)
->
List
[
str
]:
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert
(
skip_special_tokens
),
"Skipping special tokens is not supported for Mistral tokenizers."
assert
isinstance
(
self
.
tokenizer
,
(
Tekkenizer
,
SentencePieceTokenizer
)),
type
(
self
.
tokenizer
)
tokens
=
[
self
.
tokenizer
.
id_to_piece
(
id
)
for
id
in
ids
]
return
tokens
def
__len__
(
self
):
return
self
.
vocab_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment