Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c054b7a
Unverified
Commit
8c054b7a
authored
Sep 11, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 10, 2024
Browse files
[Frontend] Clean up type annotations for mistral tokenizer (#8314)
parent
6234385f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
114 additions
and
59 deletions
+114
-59
tests/async_engine/test_chat_template.py
tests/async_engine/test_chat_template.py
+3
-2
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+41
-20
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+18
-8
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+30
-18
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+18
-7
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+4
-4
No files found.
tests/async_engine/test_chat_template.py
View file @
8c054b7a
import
pytest
import
pytest
from
vllm.entrypoints.chat_utils
import
apply_chat_template
,
load_chat_template
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
load_chat_template
)
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
...
@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt
=
add_generation_prompt
)
add_generation_prompt
=
add_generation_prompt
)
# Call the function and get the result
# Call the function and get the result
result
=
apply_chat_template
(
result
=
apply_
hf_
chat_template
(
tokenizer
,
tokenizer
,
conversation
=
mock_request
.
messages
,
conversation
=
mock_request
.
messages
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
...
...
vllm/entrypoints/chat_utils.py
View file @
8c054b7a
...
@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
...
@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
# yapf: enable
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
# pydantic needs the TypedDict from typing_extensions
from
pydantic
import
ConfigDict
from
pydantic
import
ConfigDict
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
...
@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
,
async_get_and_parse_image
,
get_and_parse_audio
,
get_and_parse_image
)
get_and_parse_audio
,
get_and_parse_image
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
...
@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
mm_parser
.
parse_audio
(
audio_url
[
"url"
])
mm_parser
.
parse_audio
(
audio_url
[
"url"
])
elif
part_type
==
"refusal"
:
text
=
_RefusalParser
(
part
)[
"refusal"
]
texts
.
append
(
text
)
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
...
@@ -433,6 +437,21 @@ def _parse_chat_message_content(
...
@@ -433,6 +437,21 @@ def _parse_chat_message_content(
return
result
return
result
def
_postprocess_messages
(
messages
:
List
[
ConversationMessage
])
->
None
:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for
message
in
messages
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)):
for
item
in
message
[
"tool_calls"
]:
item
[
"function"
][
"arguments"
]
=
json
.
loads
(
item
[
"function"
][
"arguments"
])
def
parse_chat_messages
(
def
parse_chat_messages
(
messages
:
List
[
ChatCompletionMessageParam
],
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
@@ -446,6 +465,8 @@ def parse_chat_messages(
...
@@ -446,6 +465,8 @@ def parse_chat_messages(
conversation
.
extend
(
sub_messages
)
conversation
.
extend
(
sub_messages
)
_postprocess_messages
(
conversation
)
return
conversation
,
mm_tracker
.
all_mm_data
()
return
conversation
,
mm_tracker
.
all_mm_data
()
...
@@ -462,41 +483,41 @@ def parse_chat_messages_futures(
...
@@ -462,41 +483,41 @@ def parse_chat_messages_futures(
conversation
.
extend
(
sub_messages
)
conversation
.
extend
(
sub_messages
)
_postprocess_messages
(
conversation
)
return
conversation
,
mm_tracker
.
all_mm_data
()
return
conversation
,
mm_tracker
.
all_mm_data
()
def
apply_chat_template
(
def
apply_
hf_
chat_template
(
tokenizer
:
Any
Tokenizer
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrained
Tokenizer
Fast
]
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
*
,
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
Union
[
str
,
List
[
int
]]
:
)
->
str
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
raise
ValueError
(
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
"does not define one."
)
# per the Transformers docs & maintainers, tool call arguments in
return
tokenizer
.
apply_chat_template
(
# assistant-role messages with tool_calls need to be dicts not JSON str -
conversation
=
conversation
,
# type: ignore[arg-type]
# this is how tool-use chat templates will expect them moving forwards
chat_template
=
chat_template
,
# so, for messages that have tool_calls, parse the string (which we get
tokenize
=
tokenize
,
# from openAI format) to dict
**
kwargs
,
for
message
in
conversation
:
)
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)):
for
i
in
range
(
len
(
message
[
"tool_calls"
])):
args
:
str
=
message
[
"tool_calls"
][
i
][
"function"
][
"arguments"
]
parsed_args
:
Dict
=
json
.
loads
(
args
)
message
[
"tool_calls"
][
i
][
"function"
][
"arguments"
]
=
parsed_args
prompt
=
tokenizer
.
apply_chat_template
(
def
apply_mistral_chat_template
(
conversation
=
conversation
,
tokenizer
:
MistralTokenizer
,
messages
:
List
[
ChatCompletionMessageParam
],
chat_template
:
Optional
[
str
],
**
kwargs
:
Any
,
)
->
List
[
int
]:
return
tokenizer
.
apply_chat_template
(
messages
=
messages
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
tokenize
=
tokenize
,
**
kwargs
,
**
kwargs
,
)
)
return
prompt
vllm/entrypoints/llm.py
View file @
8c054b7a
...
@@ -6,7 +6,8 @@ from tqdm import tqdm
...
@@ -6,7 +6,8 @@ from tqdm import tqdm
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
parse_chat_messages
)
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
...
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
...
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -393,12 +394,21 @@ class LLM:
...
@@ -393,12 +394,21 @@ class LLM:
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
tokenizer
)
prompt
=
apply_chat_template
(
prompt
:
Union
[
str
,
List
[
int
]]
tokenizer
,
if
isinstance
(
tokenizer
,
MistralTokenizer
):
conversation
,
prompt
=
apply_mistral_chat_template
(
chat_template
=
chat_template
,
tokenizer
,
add_generation_prompt
=
add_generation_prompt
,
messages
=
messages
,
)
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
)
inputs
:
PromptInputs
inputs
:
PromptInputs
if
is_list_of
(
prompt
,
int
):
if
is_list_of
(
prompt
,
int
):
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
8c054b7a
...
@@ -11,7 +11,8 @@ from fastapi import Request
...
@@ -11,7 +11,8 @@ from fastapi import Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages_futures
)
parse_chat_messages_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
...
@@ -35,7 +36,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -121,15 +122,27 @@ class OpenAIServingChat(OpenAIServing):
tool
.
model_dump
()
for
tool
in
request
.
tools
tool
.
model_dump
()
for
tool
in
request
.
tools
]
]
prompt
=
apply_chat_template
(
prompt
:
Union
[
str
,
List
[
int
]]
tokenizer
,
if
isinstance
(
tokenizer
,
MistralTokenizer
):
conversation
=
conversation
,
prompt
=
apply_mistral_chat_template
(
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
tokenizer
,
add_generation_prompt
=
request
.
add_generation_prompt
,
messages
=
request
.
messages
,
tools
=
tool_dicts
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
documents
=
request
.
documents
,
add_generation_prompt
=
request
.
add_generation_prompt
,
**
(
request
.
chat_template_kwargs
or
{}),
tools
=
tool_dicts
,
)
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -307,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -307,11 +320,10 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# Send response to echo the input portion of the
# last message
# last message
if
request
.
echo
:
if
request
.
echo
:
last_msg_content
:
Optional
[
str
]
=
""
last_msg_content
:
str
=
""
if
conversation
and
conversation
[
-
1
].
get
(
if
conversation
and
"content"
in
conversation
[
"content"
)
and
conversation
[
-
1
].
get
(
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
last_msg_content
=
conversation
[
-
1
][
"content"
]
if
last_msg_content
:
if
last_msg_content
:
for
i
in
range
(
num_choices
):
for
i
in
range
(
num_choices
):
...
@@ -659,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -659,8 +671,8 @@ class OpenAIServingChat(OpenAIServing):
if
request
.
echo
:
if
request
.
echo
:
last_msg_content
=
""
last_msg_content
=
""
if
conversation
and
conversation
[
-
1
]
.
get
(
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
"content"
)
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
-
1
].
get
(
"role"
)
==
role
:
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
last_msg_content
=
conversation
[
-
1
][
"content"
]
or
""
for
choice
in
choices
:
for
choice
in
choices
:
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
8c054b7a
...
@@ -2,7 +2,8 @@ from typing import List, Optional, Union
...
@@ -2,7 +2,8 @@ from typing import List, Optional, Union
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages_futures
)
parse_chat_messages_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
...
@@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
...
@@ -18,6 +19,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -66,6 +68,7 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
...
@@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -77,12 +80,20 @@ class OpenAIServingTokenization(OpenAIServing):
logger
.
warning
(
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
"Multi-modal inputs are ignored during tokenization"
)
prompt
=
apply_chat_template
(
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tokenizer
,
prompt
=
apply_mistral_chat_template
(
conversation
=
conversation
,
tokenizer
,
chat_template
=
self
.
chat_template
,
messages
=
request
.
messages
,
add_generation_prompt
=
request
.
add_generation_prompt
,
chat_template
=
self
.
chat_template
,
)
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
prompt
=
apply_hf_chat_template
(
tokenizer
,
conversation
=
conversation
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
else
:
prompt
=
request
.
prompt
prompt
=
request
.
prompt
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
8c054b7a
...
@@ -16,7 +16,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
...
@@ -16,7 +16,7 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer
)
Tekkenizer
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
C
onversa
tionMessage
from
vllm.entrypoints.chat_utils
import
C
hatComple
tionMessage
Param
@
dataclass
@
dataclass
...
@@ -122,19 +122,19 @@ class MistralTokenizer:
...
@@ -122,19 +122,19 @@ class MistralTokenizer:
return
[]
return
[]
def
encode
(
self
,
prompt
:
str
)
->
List
[
int
]:
def
encode
(
self
,
prompt
:
str
)
->
List
[
int
]:
# `encode
` should only be used for prompt completion
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
# For chat completion use `apply_chat_template`
return
self
.
tokenizer
.
encode
(
prompt
,
bos
=
True
,
eos
=
False
)
return
self
.
tokenizer
.
encode
(
prompt
,
bos
=
True
,
eos
=
False
)
def
apply_chat_template
(
self
,
def
apply_chat_template
(
self
,
conversation
:
List
[
"Conversa
tionMessage"
],
messages
:
List
[
"ChatComple
tionMessage
Param
"
],
tools
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
tools
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
)
->
List
[
int
]:
**
kwargs
)
->
List
[
int
]:
assert
tools
is
None
,
"`tools` are not yet supported."
assert
tools
is
None
,
"`tools` are not yet supported."
request
=
ChatCompletionRequest
(
request
=
ChatCompletionRequest
(
messages
=
conversation
)
# type: ignore[type-var]
messages
=
messages
)
# type: ignore[type-var]
encoded
=
self
.
mistral
.
encode_chat_completion
(
request
)
encoded
=
self
.
mistral
.
encode_chat_completion
(
request
)
# encode-decode to get clean prompt
# encode-decode to get clean prompt
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment